From 82f6a969d45589751b5687ea3829ea50945a5554 Mon Sep 17 00:00:00 2001 From: David Koski Date: Fri, 1 Mar 2024 14:46:32 -0800 Subject: [PATCH] llm improvements - document the tokenizer used (https://github.com/huggingface/swift-transformers) - provide a hook for tokenizer configuration, prompt augmentation - this isn't as rich as the python equivalents but it helps a little --- Libraries/LLM/Load.swift | 6 +- Libraries/LLM/Models.swift | 107 ++++++++++++++++++ Libraries/LLM/README.md | 17 ++- Libraries/LLM/Tokenizer.swift | 4 +- Tools/llm-tool/LLMTool.swift | 32 +++--- mlx-swift-examples.xcodeproj/project.pbxproj | 4 + .../xcshareddata/swiftpm/Package.resolved | 2 +- .../xcshareddata/xcschemes/llm-tool.xcscheme | 100 ++++++++++++++++ 8 files changed, 250 insertions(+), 22 deletions(-) create mode 100644 Libraries/LLM/Models.swift create mode 100644 mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme diff --git a/Libraries/LLM/Load.swift b/Libraries/LLM/Load.swift index add4552..329ed1d 100644 --- a/Libraries/LLM/Load.swift +++ b/Libraries/LLM/Load.swift @@ -14,13 +14,13 @@ struct LLMError: Error { /// Load and return the model and tokenizer public func load( - hub: HubApi = HubApi(), name: String, progressHandler: @escaping (Progress) -> Void = { _ in } + hub: HubApi = HubApi(), configuration: ModelConfiguration, progressHandler: @escaping (Progress) -> Void = { _ in } ) async throws -> (LLMModel, Tokenizer) { // note: this doesn't have a way to pass the HubApi - let tokenizer = try await loadTokenizer(name: name) + let tokenizer = try await loadTokenizer(configuration: configuration) // download the model weights and config - let repo = Hub.Repo(id: name) + let repo = Hub.Repo(id: configuration.id) let modelFiles = ["config.json", "*.safetensors"] let modelDirectory = try await hub.snapshot( from: repo, matching: modelFiles, progressHandler: progressHandler) diff --git a/Libraries/LLM/Models.swift b/Libraries/LLM/Models.swift new file mode 100644 
index 0000000..5db31ba --- /dev/null +++ b/Libraries/LLM/Models.swift @@ -0,0 +1,107 @@ +// Copyright © 2024 Apple Inc. + +import Foundation + +/// Registry of models and any overrides that go with them, e.g. prompt augmentation. +/// If asked for an unknown configuration this will use the model/tokenizer as-is. +/// +/// The python tokenizers have a very rich set of implementations and configuration. The +/// swift-tokenizers code handles a good chunk of that and this is a place to augment that +/// implementation, if needed. +public struct ModelConfiguration { + public let id: String + + /// overrides for TokenizerModel/knownTokenizers -- useful before swift-transformers is updated + public let overrideTokenizer: String? + + /// custom preparation logic for the prompt. custom tokenizers provide more capability, but this + /// allows some minor formatting changes, e.g. wrapping the user input in the expected prompt + /// format + private let preparePrompt: ((String) -> String)? + + public init(id: String, overrideTokenizer: String? = nil, preparePrompt: ((String) -> String)? = nil) { + self.id = id + self.overrideTokenizer = overrideTokenizer + self.preparePrompt = preparePrompt + } + + public func prepare(prompt: String) -> String { + preparePrompt?(prompt) ?? 
prompt + } + + public static var registry = [String:ModelConfiguration]() + + public static func register(configurations: [ModelConfiguration]) { + bootstrap() + + for c in configurations { + registry[c.id] = c + } + } + + public static func configuration(id: String) -> ModelConfiguration { + bootstrap() + + if let c = registry[id] { + return c + } else { + return ModelConfiguration(id: id) + } + } +} + +extension ModelConfiguration { + + static let mistral7B4bit = ModelConfiguration(id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx") + + static let codeLlama13b4bit = ModelConfiguration( + id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX", + overrideTokenizer: "PreTrainedTokenizer") + { prompt in + // given the prompt: func sortArray(_ array: [Int]) -> String { <FILL_ME> } + // the python code produces this (via its custom tokenizer): + //
 <PRE> func sortArray(_ array: [Int]) -> String { <SUF> } <MID> 
+        
+        "<PRE> " +
+        prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") +
+        " <MID>"
+    }
+    
+    static let phi4bit = ModelConfiguration(id: "mlx-community/phi-2-hf-4bit-mlx") { prompt in
+        "Instruct: \(prompt). Output: "
+    }
+    
+    static let gemma2bQuantized = ModelConfiguration(
+        id: "mlx-community/quantized-gemma-2b-it",
+        overrideTokenizer: "PreTrainedTokenizer") { prompt in
+            "user \(prompt)model"
+        }
+
+    private enum BootstrapState {
+        case idle
+        case bootstrapping
+        case bootstrapped
+    }
+    
+    static private var bootstrapState = BootstrapState.idle
+    
+    static func bootstrap() {
+        switch bootstrapState {
+        case .idle:
+            bootstrapState = .bootstrapping
+            register(configurations: [
+                mistral7B4bit,
+                codeLlama13b4bit,
+                phi4bit,
+                gemma2bQuantized,
+            ])
+            bootstrapState = .bootstrapped
+
+        case .bootstrapping:
+            break
+            
+        case .bootstrapped:
+            break
+        }
+    }
+}
diff --git a/Libraries/LLM/README.md b/Libraries/LLM/README.md
index 482e661..c852fee 100644
--- a/Libraries/LLM/README.md
+++ b/Libraries/LLM/README.md
@@ -4,9 +4,22 @@ This is a port of several models from:
 
 - https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/
 
-You can use this to load models from huggingface, e.g.:
+using the Hugging Face swift transformers package to provide tokenization:
 
-- https://huggingface.co/mlx-community/Mistral-7B-v0.1-hf-4bit-mlx
+https://github.com/huggingface/swift-transformers
+
+The [Models.swift](Models.swift) file provides minor overrides and customization --
+if you require overrides for the tokenizer or prompt customizations they can be
+added there.
+
+This is set up to load models from Hugging Face, e.g. https://huggingface.co/mlx-community
+
+The following models have been tried:
+
+- mlx-community/Mistral-7B-v0.1-hf-4bit-mlx
+- mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX
+- mlx-community/phi-2-hf-4bit-mlx
+- mlx-community/quantized-gemma-2b-it
 
 Currently supported model types are:
 
diff --git a/Libraries/LLM/Tokenizer.swift b/Libraries/LLM/Tokenizer.swift
index b85a990..e8c6446 100644
--- a/Libraries/LLM/Tokenizer.swift
+++ b/Libraries/LLM/Tokenizer.swift
@@ -49,9 +49,9 @@ public struct Tokenizer: Tokenizers.Tokenizer {
 
 }
 
-public func loadTokenizer(name: String) async throws -> Tokenizer {
+public func loadTokenizer(configuration: ModelConfiguration) async throws -> Tokenizer {
     // from AutoTokenizer.from() -- this lets us override parts of the configuration
-    let config = LanguageModelConfigurationFromHub(modelName: name)
+    let config = LanguageModelConfigurationFromHub(modelName: configuration.id)
     guard var tokenizerConfig = try await config.tokenizerConfig else {
         throw LLMError(message: "missing config")
     }
diff --git a/Tools/llm-tool/LLMTool.swift b/Tools/llm-tool/LLMTool.swift
index d2d6495..9a79a41 100644
--- a/Tools/llm-tool/LLMTool.swift
+++ b/Tools/llm-tool/LLMTool.swift
@@ -25,7 +25,7 @@ struct SyncGenerator: AsyncParsableCommand {
     var model: String = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"
 
     @Option(name: .shortAndLong, help: "The message to be processed by the model")
-    var prompt = "compare swift and python"
+    var prompt = "compare python and swift"
 
     @Option(name: .shortAndLong, help: "Maximum number of tokens to generate")
     var maxTokens = 100
@@ -40,22 +40,24 @@ struct SyncGenerator: AsyncParsableCommand {
     func run() async throws {
         MLXRandom.seed(seed)
 
-        let (model, tokenizer) = try await load(name: model)
+        let modelConfiguration = ModelConfiguration.configuration(id: model)
+        let (model, tokenizer) = try await load(configuration: modelConfiguration)
+        
+        let prompt = modelConfiguration.prepare(prompt: self.prompt)
+        let promptTokens = tokenizer.encode(text: prompt)
 
         print("Starting generation ...")
-        print(prompt, terminator: "")
+        print(self.prompt, terminator: "")
 
         var start = Date.timeIntervalSinceReferenceDate
         var promptTime: TimeInterval = 0
 
-        let prompt = MLXArray(tokenizer.encode(text: prompt))
-
         // collect the tokens and keep track of how much of the string
         // we have printed already
         var tokens = [Int]()
         var printed = 0
 
-        for token in TokenIterator(prompt: prompt, model: model, temp: temperature) {
+        for token in TokenIterator(prompt: MLXArray(promptTokens), model: model, temp: temperature) {
             if tokens.isEmpty {
                 eval(token)
                 let now = Date.timeIntervalSinceReferenceDate
@@ -90,7 +92,7 @@ struct SyncGenerator: AsyncParsableCommand {
 
         print(
             """
-            Prompt Tokens per second:     \((Double(prompt.size) / promptTime).formatted())
+            Prompt Tokens per second:     \((Double(promptTokens.count) / promptTime).formatted())
             Generation tokens per second: \((Double(tokens.count - 1) / generateTime).formatted())
             """)
     }
@@ -111,7 +113,7 @@ struct AsyncGenerator: AsyncParsableCommand {
     var model: String = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"
 
     @Option(name: .shortAndLong, help: "The message to be processed by the model")
-    var prompt = "compare swift and python"
+    var prompt = "compare python and swift"
 
     @Option(name: .shortAndLong, help: "Maximum number of tokens to generate")
     var maxTokens = 100
@@ -126,22 +128,24 @@ struct AsyncGenerator: AsyncParsableCommand {
     func run() async throws {
         MLXRandom.seed(seed)
 
-        let (model, tokenizer) = try await load(name: model)
+        let modelConfiguration = ModelConfiguration.configuration(id: model)
+        let (model, tokenizer) = try await load(configuration: modelConfiguration)
+        
+        let prompt = modelConfiguration.prepare(prompt: self.prompt)
+        let promptTokens = tokenizer.encode(text: prompt)
 
         print("Starting generation ...")
-        print(prompt, terminator: "")
+        print(self.prompt, terminator: "")
 
         var start = Date.timeIntervalSinceReferenceDate
         var promptTime: TimeInterval = 0
 
-        let prompt = MLXArray(tokenizer.encode(text: prompt))
-
         // collect the tokens and keep track of how much of the string
         // we have printed already
         var tokens = [Int]()
         var printed = 0
 
-        let (task, channel) = generate(prompt: prompt, model: model, temp: temperature)
+        let (task, channel) = generate(prompt: MLXArray(promptTokens), model: model, temp: temperature)
 
         for await token in channel {
             if tokens.isEmpty {
@@ -179,7 +183,7 @@ struct AsyncGenerator: AsyncParsableCommand {
 
         print(
             """
-            Prompt Tokens per second:     \((Double(prompt.size) / promptTime).formatted())
+            Prompt Tokens per second:     \((Double(promptTokens.count) / promptTime).formatted())
             Generation tokens per second: \((Double(tokens.count - 1) / generateTime).formatted())
             """)
 
diff --git a/mlx-swift-examples.xcodeproj/project.pbxproj b/mlx-swift-examples.xcodeproj/project.pbxproj
index 4f68254..ce8ac4d 100644
--- a/mlx-swift-examples.xcodeproj/project.pbxproj
+++ b/mlx-swift-examples.xcodeproj/project.pbxproj
@@ -36,6 +36,7 @@
 		C3932D572B6A060B00A81055 /* MNIST.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3932D562B6A060B00A81055 /* MNIST.swift */; };
 		C3932D592B6A0BE400A81055 /* Random.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3932D582B6A0BE400A81055 /* Random.swift */; };
 		C397C59C2B62C6D0004B084D /* ArgumentParser in Frameworks */ = {isa = PBXBuildFile; productRef = C397C59B2B62C6D0004B084D /* ArgumentParser */; };
+		C3A8B3AC2B9283150002EFB8 /* Models.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3A8B3AB2B9283150002EFB8 /* Models.swift */; };
 		C3E786AB2B8D1AEC0004D037 /* Evaluate.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3E786AA2B8D1AEC0004D037 /* Evaluate.swift */; };
 		C3E786AD2B8D4AF50004D037 /* Tokenizer.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3E786AC2B8D4AF50004D037 /* Tokenizer.swift */; };
 		C3FBCB212B8520B80007E490 /* MLX in Frameworks */ = {isa = PBXBuildFile; productRef = C3FBCB202B8520B80007E490 /* MLX */; };
@@ -152,6 +153,7 @@
 		C3932D562B6A060B00A81055 /* MNIST.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MNIST.swift; sourceTree = ""; };
 		C3932D582B6A0BE400A81055 /* Random.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Random.swift; sourceTree = ""; };
 		C397C58B2B62C6A9004B084D /* llm-tool */ = {isa = PBXFileReference; explicitFileType = "compiled.mach-o.executable"; includeInIndex = 0; path = "llm-tool"; sourceTree = BUILT_PRODUCTS_DIR; };
+		C3A8B3AB2B9283150002EFB8 /* Models.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Models.swift; sourceTree = ""; };
 		C3C3240B2B6CA689007D2D9A /* README.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = ""; };
 		C3C3240C2B6CA792007D2D9A /* README.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = ""; };
 		C3E786AA2B8D1AEC0004D037 /* Evaluate.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Evaluate.swift; sourceTree = ""; };
@@ -267,6 +269,7 @@
 			isa = PBXGroup;
 			children = (
 				C34E48EF2B696E6500FCB841 /* Configuration.swift */,
+				C3A8B3AB2B9283150002EFB8 /* Models.swift */,
 				C34E48EE2B696E6500FCB841 /* Llama.swift */,
 				C38935E22B86C0FE0037B833 /* Gemma.swift */,
 				C38935C72B869C7A0037B833 /* LLM.h */,
@@ -614,6 +617,7 @@
 				C38935DF2B869DD00037B833 /* Phi.swift in Sources */,
 				C38935CE2B869C870037B833 /* Load.swift in Sources */,
 				C3E786AD2B8D4AF50004D037 /* Tokenizer.swift in Sources */,
+				C3A8B3AC2B9283150002EFB8 /* Models.swift in Sources */,
 				C3E786AB2B8D1AEC0004D037 /* Evaluate.swift in Sources */,
 				C38935CC2B869C870037B833 /* Llama.swift in Sources */,
 			);
diff --git a/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
index 7e747b3..1386059 100644
--- a/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
+++ b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
@@ -15,7 +15,7 @@
       "location" : "https://github.com/ml-explore/mlx-swift",
       "state" : {
         "branch" : "main",
-        "revision" : "fbe215ae29ec286bdc66f4b3423e3eea1b3ef2fc"
+        "revision" : "83bd19f9fe93d77d9f89981eeead2d6d190afdba"
       }
     },
     {
diff --git a/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme b/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme
new file mode 100644
index 0000000..0200f91
--- /dev/null
+++ b/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme
@@ -0,0 +1,100 @@
+
+
+   
+      
+         
+            
+            
+         
+      
+   
+   
+   
+   
+      
+         
+         
+      
+      
+         
+         
+         
+         
+         
+         
+         
+         
+         
+         
+      
+   
+   
+      
+         
+         
+      
+   
+   
+   
+   
+   
+