From 2157333905f58252c88014bf23103d0bd62836aa Mon Sep 17 00:00:00 2001
From: David Koski
Date: Fri, 1 Mar 2024 14:47:43 -0800
Subject: [PATCH] swift-format!

---
 Libraries/LLM/Load.swift     |  3 ++-
 Libraries/LLM/Models.swift   | 51 ++++++++++++++++++------------------
 Tools/llm-tool/LLMTool.swift | 10 ++++---
 3 files changed, 34 insertions(+), 30 deletions(-)

diff --git a/Libraries/LLM/Load.swift b/Libraries/LLM/Load.swift
index 329ed1d..a4236af 100644
--- a/Libraries/LLM/Load.swift
+++ b/Libraries/LLM/Load.swift
@@ -14,7 +14,8 @@ struct LLMError: Error {
 
 /// Load and return the model and tokenizer
 public func load(
-    hub: HubApi = HubApi(), configuration: ModelConfiguration, progressHandler: @escaping (Progress) -> Void = { _ in }
+    hub: HubApi = HubApi(), configuration: ModelConfiguration,
+    progressHandler: @escaping (Progress) -> Void = { _ in }
 ) async throws -> (LLMModel, Tokenizer) {
     // note: this doesn't have a way to pass the HubApi
     let tokenizer = try await loadTokenizer(configuration: configuration)
diff --git a/Libraries/LLM/Models.swift b/Libraries/LLM/Models.swift
index 5db31ba..a731add 100644
--- a/Libraries/LLM/Models.swift
+++ b/Libraries/LLM/Models.swift
@@ -10,35 +10,37 @@ import Foundation
 /// implementation, if needed.
 public struct ModelConfiguration {
     public let id: String
-    
+
     /// overrides for TokenizerModel/knownTokenizers -- useful before swift-transformers is updated
     public let overrideTokenizer: String?
-    
+
     /// custom preparation logic for the prompt. custom tokenizers provide more capability, but this
     /// allows some minor formtting changes, e.g. wrapping the user input in the expected prompt
     /// format
    private let preparePrompt: ((String) -> String)?
-    
-    public init(id: String, overrideTokenizer: String? = nil, preparePrompt: ((String) -> String)? = nil) {
+
+    public init(
+        id: String, overrideTokenizer: String? = nil, preparePrompt: ((String) -> String)? = nil
+    ) {
         self.id = id
         self.overrideTokenizer = overrideTokenizer
         self.preparePrompt = preparePrompt
     }
-    
+
     public func prepare(prompt: String) -> String {
         preparePrompt?(prompt) ?? prompt
     }
-    
-    public static var registry = [String:ModelConfiguration]()
-    
+
+    public static var registry = [String: ModelConfiguration]()
+
     public static func register(configurations: [ModelConfiguration]) {
         bootstrap()
-        
+
         for c in configurations {
             registry[c.id] = c
         }
     }
-    
+
     public static func configuration(id: String) -> ModelConfiguration {
         bootstrap()
 
@@ -51,40 +53,39 @@ public struct ModelConfiguration {
 }
 
 extension ModelConfiguration {
-    
+
     static let mistral7B4bit = ModelConfiguration(id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx")
 
     static let codeLlama13b4bit = ModelConfiguration(
         id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX",
-        overrideTokenizer: "PreTrainedTokenizer")
-    { prompt in
+        overrideTokenizer: "PreTrainedTokenizer"
+    ) { prompt in
         // given the prompt: func sortArray(_ array: [Int]) -> String { <FILL_ME> }
         // the python code produces this (via its custom tokenizer):
         // <PRE> func sortArray(_ array: [Int]) -> String { <SUF>   } <MID>
-        
-        "<PRE> " +
-        prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") +
-        " <MID>"
+
+        "<PRE> " + prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") + " <MID>"
     }
-    
+
     static let phi4bit = ModelConfiguration(id: "mlx-community/phi-2-hf-4bit-mlx") { prompt in
         "Instruct: \(prompt). Output: "
     }
-    
+
     static let gemma2bQuantized = ModelConfiguration(
         id: "mlx-community/quantized-gemma-2b-it",
-        overrideTokenizer: "PreTrainedTokenizer") { prompt in
-            "user \(prompt)model"
-        }
+        overrideTokenizer: "PreTrainedTokenizer"
+    ) { prompt in
+        "user \(prompt)model"
+    }
 
     private enum BootstrapState {
         case idle
         case bootstrapping
         case bootstrapped
     }
-    
+
     static private var bootstrapState = BootstrapState.idle
-    
+
     static func bootstrap() {
         switch bootstrapState {
         case .idle:
@@ -99,7 +100,7 @@ extension ModelConfiguration {
 
         case .bootstrapping:
             break
-            
+
         case .bootstrapped:
             break
         }
diff --git a/Tools/llm-tool/LLMTool.swift b/Tools/llm-tool/LLMTool.swift
index 9a79a41..9e9fcea 100644
--- a/Tools/llm-tool/LLMTool.swift
+++ b/Tools/llm-tool/LLMTool.swift
@@ -42,7 +42,7 @@ struct SyncGenerator: AsyncParsableCommand {
 
         let modelConfiguration = ModelConfiguration.configuration(id: model)
         let (model, tokenizer) = try await load(configuration: modelConfiguration)
-        
+
         let prompt = modelConfiguration.prepare(prompt: self.prompt)
         let promptTokens = tokenizer.encode(text: prompt)
 
@@ -57,7 +57,8 @@ struct SyncGenerator: AsyncParsableCommand {
         var tokens = [Int]()
         var printed = 0
 
-        for token in TokenIterator(prompt: MLXArray(promptTokens), model: model, temp: temperature) {
+        for token in TokenIterator(prompt: MLXArray(promptTokens), model: model, temp: temperature)
+        {
             if tokens.isEmpty {
                 eval(token)
                 let now = Date.timeIntervalSinceReferenceDate
@@ -130,7 +131,7 @@ struct AsyncGenerator: AsyncParsableCommand {
 
         let modelConfiguration = ModelConfiguration.configuration(id: model)
         let (model, tokenizer) = try await load(configuration: modelConfiguration)
-        
+
         let prompt = modelConfiguration.prepare(prompt: self.prompt)
         let promptTokens = tokenizer.encode(text: prompt)
 
@@ -145,7 +146,8 @@ struct AsyncGenerator: AsyncParsableCommand {
         var tokens = [Int]()
         var printed = 0
 
-        let (task, channel) = generate(prompt: MLXArray(promptTokens), model: model, temp: temperature)
+        let (task, channel) = generate(
+            prompt: MLXArray(promptTokens), model: model, temp: temperature)
 
         for await token in channel {
             if tokens.isEmpty {