Add OpenAI pipeline + some prompts

This commit is contained in:
Thomas Ricouard 2023-01-13 18:43:02 +01:00
parent 2fdf5fe239
commit 6b210aec4f
7 changed files with 201 additions and 0 deletions

View file

@ -42,6 +42,7 @@
9F7335EF29674F7100AFF0BA /* QuickLook.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 9F7335EE29674F7100AFF0BA /* QuickLook.framework */; };
9F7335F22967608F00AFF0BA /* AddRemoteTimelineVIew.swift in Sources */ = {isa = PBXBuildFile; fileRef = 9F7335F12967608F00AFF0BA /* AddRemoteTimelineVIew.swift */; };
9F7335F92968576500AFF0BA /* DisplaySettingsView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 9F7335F82968576500AFF0BA /* DisplaySettingsView.swift */; };
9FAD85832971BF7200496AB1 /* Secret.plist in Resources */ = {isa = PBXBuildFile; fileRef = 9FAD85822971BF7200496AB1 /* Secret.plist */; };
9FAE4ACB293783B000772766 /* SettingsTab.swift in Sources */ = {isa = PBXBuildFile; fileRef = 9FAE4ACA293783B000772766 /* SettingsTab.swift */; };
9FAE4ACE29379A5A00772766 /* KeychainSwift in Frameworks */ = {isa = PBXBuildFile; productRef = 9FAE4ACD29379A5A00772766 /* KeychainSwift */; };
9FBFE63D292A715500C250E9 /* IceCubesApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 9FBFE63C292A715500C250E9 /* IceCubesApp.swift */; };
@ -112,6 +113,7 @@
9F7335EE29674F7100AFF0BA /* QuickLook.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = QuickLook.framework; path = Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS16.2.sdk/System/Library/Frameworks/QuickLook.framework; sourceTree = DEVELOPER_DIR; };
9F7335F12967608F00AFF0BA /* AddRemoteTimelineVIew.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AddRemoteTimelineVIew.swift; sourceTree = "<group>"; };
9F7335F82968576500AFF0BA /* DisplaySettingsView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = DisplaySettingsView.swift; sourceTree = "<group>"; };
9FAD85822971BF7200496AB1 /* Secret.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist; path = Secret.plist; sourceTree = "<group>"; };
9FAE4AC8293774FF00772766 /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist; path = Info.plist; sourceTree = "<group>"; };
9FAE4ACA293783B000772766 /* SettingsTab.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SettingsTab.swift; sourceTree = "<group>"; };
9FBFE639292A715500C250E9 /* IceCubesApp.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = IceCubesApp.app; sourceTree = BUILT_PRODUCTS_DIR; };
@ -254,6 +256,7 @@
9F398AB429360A5800A889F2 /* App */,
9FBFE642292A715600C250E9 /* IceCubesApp.entitlements */,
9F398AB529360A6100A889F2 /* Resources */,
9FAD85822971BF7200496AB1 /* Secret.plist */,
);
path = IceCubesApp;
sourceTree = "<group>";
@ -399,6 +402,7 @@
9F2A542C296B1177009B2D7C /* glass.caf in Resources */,
9FD34823293D06E800DB0EE9 /* Assets.xcassets in Resources */,
9F24EEB829360C330042359D /* Preview Assets.xcassets in Resources */,
9FAD85832971BF7200496AB1 /* Secret.plist in Resources */,
9F2A542E296B1CC0009B2D7C /* glass.wav in Resources */,
);
runOnlyForDeploymentPostprocessing = 0;

8
IceCubesApp/Secret.plist Normal file
View file

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>OPENAI_SECRET</key>
<string>NICE_TRY</string>
</dict>
</plist>

View file

@ -0,0 +1,97 @@
import Foundation
public struct OpenAIClient {
private let endpoint: URL = URL(string: "https://api.openai.com/v1/completions")!
private var APIKey: String {
if let path = Bundle.main.path(forResource: "Secret", ofType: "plist") {
let secret = NSDictionary(contentsOfFile: path)
return secret?["OPENAI_SECRET"] as? String ?? ""
}
return ""
}
private var authorizationHeaderValue: String {
"Bearer \(APIKey)"
}
private var encoder: JSONEncoder {
let encoder = JSONEncoder()
encoder.keyEncodingStrategy = .convertToSnakeCase
return encoder
}
private var decoder: JSONDecoder {
let decoder = JSONDecoder()
decoder.keyDecodingStrategy = .convertFromSnakeCase
return decoder
}
public struct Request: Encodable {
let model = "text-davinci-003"
let topP: Int = 1
let frequencyPenalty: Int = 0
let presencePenalty: Int = 0
let prompt: String
let temperature: Double
let maxTokens: Int
public init(prompt: String, temperature: Double, maxTokens: Int) {
self.prompt = prompt
self.temperature = temperature
self.maxTokens = maxTokens
}
}
public enum Prompts {
case correct(input: String)
case shorten(input: String)
case emphasize(input: String)
var request: Request {
switch self {
case let .correct(input):
return Request(prompt: "Correct this to standard English:\(input)",
temperature: 0,
maxTokens: 500)
case let .shorten(input):
return Request(prompt: "Make a summary of this paragraph:\(input)",
temperature: 0.7,
maxTokens: 100)
case let .emphasize(input):
return Request(prompt: "Make this paragraph catchy, more fun:\(input)",
temperature: 0.8,
maxTokens: 500)
}
}
}
public struct Response: Decodable {
public struct Choice: Decodable {
public let text: String
}
public let id: String
public let object: String
public let model: String
public let choices: [Choice]
}
public init() { }
public func request(_ prompt: Prompts) async throws -> Response {
do {
let jsonData = try encoder.encode(prompt.request)
var request = URLRequest(url: endpoint)
request.httpMethod = "POST"
request.setValue(authorizationHeaderValue, forHTTPHeaderField: "Authorization")
request.setValue("application/json", forHTTPHeaderField: "Content-Type")
request.httpBody = jsonData
let (result, _) = try await URLSession.shared.data(for: request)
let response = try decoder.decode(Response.self, from: result)
return response
} catch let error {
throw error
}
}
}

View file

@ -0,0 +1,30 @@
import Foundation
import SwiftUI
import Network
enum StatusEditorAIPrompts: CaseIterable {
case correct, fit, emphasize
@ViewBuilder
var label: some View {
switch self {
case .correct:
Label("Correct text", systemImage: "text.badge.checkmark")
case .fit:
Label("Shorten text", systemImage: "text.badge.minus")
case .emphasize:
Label("Emphasize text", systemImage: "text.badge.star")
}
}
func toRequestPrompt(text: String) -> OpenAIClient.Prompts {
switch self {
case .correct:
return .correct(input: text)
case .fit:
return .shorten(input: text)
case .emphasize:
return .emphasize(input: text)
}
}
}

View file

@ -20,6 +20,7 @@ public struct StatusEditorView: View {
@FocusState private var isSpoilerTextFocused: Bool
@State private var isDismissAlertPresented: Bool = false
@State private var isLoadingAIRequest: Bool = false
public init(mode: StatusEditorViewModel.Mode) {
_viewModel = StateObject(wrappedValue: .init(mode: mode))
@ -77,6 +78,10 @@ public struct StatusEditorView: View {
.navigationTitle(viewModel.mode.title)
.navigationBarTitleDisplayMode(.inline)
.toolbar {
ToolbarItem(placement: .navigationBarTrailing) {
AIMenu
.disabled(!viewModel.canPost)
}
ToolbarItem(placement: .navigationBarTrailing) {
Button {
Task {
@ -176,4 +181,35 @@ public struct StatusEditorView: View {
)
}
}
private var AIMenu: some View {
Menu {
ForEach(StatusEditorAIPrompts.allCases, id: \.self) { prompt in
Button {
Task {
isLoadingAIRequest = true
await viewModel.runOpenAI(prompt: prompt.toRequestPrompt(text: viewModel.statusText.string))
isLoadingAIRequest = false
}
} label: {
prompt.label
}
}
if let backup = viewModel.backupStatustext {
Button {
viewModel.replaceTextWith(text: backup.string)
viewModel.backupStatustext = nil
} label: {
Label("Restore previous text", systemImage: "arrow.uturn.right")
}
}
} label: {
if isLoadingAIRequest {
ProgressView()
} else {
Image(systemName: "faxmachine")
}
}
}
}

View file

@ -27,6 +27,7 @@ public class StatusEditorViewModel: ObservableObject {
checkEmbed()
}
}
@Published var backupStatustext: NSAttributedString?
@Published var showPoll: Bool = false
@Published var pollVotingFrequency = PollVotingFrequency.oneVote
@ -88,6 +89,11 @@ public class StatusEditorViewModel: ObservableObject {
statusText = string
selectedRange = NSRange(location: inRange.location + text.utf16.count, length: 0)
}
func replaceTextWith(text: String) {
statusText = .init(string: text)
selectedRange = .init(location: text.utf16.count, length: 0)
}
private func getPollOptionsForAPI() -> [String] {
pollOptions.filter { !$0.trimmingCharacters(in: .whitespaces).isEmpty }
@ -298,6 +304,20 @@ public class StatusEditorViewModel: ObservableObject {
}
}
// MARK: - OpenAI Prompt
func runOpenAI(prompt: OpenAIClient.Prompts) async {
do {
let client = OpenAIClient()
let response = try await client.request(prompt)
if var text = response.choices.first?.text {
text.removeFirst()
text.removeFirst()
backupStatustext = statusText
replaceTextWith(text: text)
}
} catch { }
}
// MARK: - Media related function
private func indexOf(container: ImageContainer) -> Int? {

View file

@ -0,0 +1,6 @@
#!/bin/sh
cd ../IceCubesApp/
plutil -replace OPENAI_SECRET -string $OPENAI_SECRET Secret.plist
plutil -p Secret.plist
exit 0