diff --git a/Applications/LLMEval/ContentView.swift b/Applications/LLMEval/ContentView.swift
index 02e4478..52821d5 100644
--- a/Applications/LLMEval/ContentView.swift
+++ b/Applications/LLMEval/ContentView.swift
@@ -10,7 +10,7 @@ import Tokenizers
 
 struct ContentView: View {
 
-    @State var prompt = "compare python and swift"
+    @State var prompt = ""
     @State var llm = LLMEvaluator()
 
     @Environment(DeviceStat.self) private var deviceStat
@@ -125,6 +125,8 @@ struct ContentView: View {
        }
        .task {
+            self.prompt = llm.modelConfiguration.defaultPrompt
+
            // pre-load the weights on launch to speed up the first generation
            _ = try? await llm.load()
        }
@@ -224,7 +226,7 @@ class LLMEvaluator {
 
        let result = await LLM.generate(
            promptTokens: promptTokens, parameters: generateParameters, model: model,
-            tokenizer: tokenizer
+            tokenizer: tokenizer, extraEOSTokens: modelConfiguration.extraEOSTokens
        ) { tokens in
            // update the output -- this will make the view show the text as it generates
            if tokens.count % displayEveryNTokens == 0 {
diff --git a/Applications/LoRATrainingExample/ContentView.swift b/Applications/LoRATrainingExample/ContentView.swift
index e72d8e0..03eecc3 100644
--- a/Applications/LoRATrainingExample/ContentView.swift
+++ b/Applications/LoRATrainingExample/ContentView.swift
@@ -266,6 +266,7 @@ class LoRAEvaluator {
        let result = await LLM.generate(
            promptTokens: promptTokens, parameters: generateParameters, model: model,
            tokenizer: tokenizer,
+            extraEOSTokens: modelConfiguration.extraEOSTokens,
            didGenerate: { tokens in
                if tokens.count % evaluateShowEvery == 0 {
                    let fullOutput = tokenizer.decode(tokens: tokens)
diff --git a/Libraries/LLM/Evaluate.swift b/Libraries/LLM/Evaluate.swift
index 5d870a1..f25d5b0 100644
--- a/Libraries/LLM/Evaluate.swift
+++ b/Libraries/LLM/Evaluate.swift
@@ -12,7 +12,7 @@ private func topPSampling(logits: MLXArray, topP: Float, temp: Float) -> MLXArra
        logits = logits.asType(.float32)
    }
 
-    let probs = softMax(logits / temp, axis: -1)
+    let probs = softmax(logits / temp, axis: -1)
 
    let sortedIndices = argSort(probs, axis: -1)
    // probs shape is [B,V] and after take it will be [1, B, V], so we squeeze it back to [B, V]
@@ -31,7 +31,7 @@ private func applyRepetitionPenalty(
) -> MLXArray {
    if repetitionContext.shape[0] > 0 {
        let indices = repetitionContext
-        var selectedLogits = take(logits, indices, axis: -1).squeezed(axis: 0)
+        var selectedLogits = logits[0..., indices]
 
        selectedLogits = MLX.where(
            selectedLogits .< 0, selectedLogits * penalty, selectedLogits / penalty)
@@ -100,7 +100,7 @@ public struct TokenIterator: Sequence, IteratorProtocol {
            if prompt.shape[0] <= parameters.repetitionContextSize {
                self.repetitionContext = prompt
            } else {
-                self.repetitionContext = prompt[-parameters.repetitionContextSize ... -1]
+                self.repetitionContext = prompt[(-parameters.repetitionContextSize)...]
            }
        } else {
            self.repetitionContext = []
        }
@@ -120,9 +120,8 @@ public struct TokenIterator: Sequence, IteratorProtocol {
        y = sample(logits: logits, temp: parameters.temperature, topP: parameters.topP)
 
        // append the current token to the context and check repetitionPenalty context see if need to remove the first token
        if parameters.repetitionContextSize > 1 {
-            repetitionContext = concatenated([repetitionContext, y], axis: 0)
            if repetitionContext.shape[0] > parameters.repetitionContextSize {
-                repetitionContext = repetitionContext[1...]
+                repetitionContext = repetitionContext[(-parameters.repetitionContextSize)...]
            }
        }
@@ -174,14 +173,31 @@ public enum GenerateDisposition {
///   - parameters: generation parameters
///   - model: model to evaluate
///   - tokenizer: tokenizer to convert tokens back into strings and recognizer special tokens
+///   - extraEOSTokens: additional tokens that indicate the end of generation
///   - didGenerate: visitor for the tokens as they are generated
public func generate(
    promptTokens: [Int], parameters: GenerateParameters, model: LLMModel,
    tokenizer: Tokenizer,
+    extraEOSTokens: Set<String>? = nil,
    didGenerate: ([Int]) async -> GenerateDisposition
) async -> GenerateResult {
    var start = Date.timeIntervalSinceReferenceDate
    var promptTime: TimeInterval = 0
 
+    // build a set of additional stop tokens
+    let additionalEOSTokenIds = Set(
+        (extraEOSTokens ?? [])
+            .map {
+                tokenizer.encode(text: $0)
+            }
+            .filter {
+                // discard anything that is not a single token. sometimes
+                // the tokenizer will insert a <s> token, so accept that too
+                $0.count == 1 || ($0.count == 2 && $0[0] == 1)
+            }
+            .map {
+                $0.last!
+            })
+
    var tokens = [Int]()
 
    for token in TokenIterator(
@@ -196,7 +212,9 @@ public func generate(
        }
 
        let t = token.item(Int.self)
-        if t == tokenizer.unknownTokenId || t == tokenizer.eosTokenId {
+        if t == tokenizer.unknownTokenId || t == tokenizer.eosTokenId
+            || additionalEOSTokenIds.contains(t)
+        {
            break
        }
 
diff --git a/Libraries/LLM/LLMModel.swift b/Libraries/LLM/LLMModel.swift
index 885dce3..241cbba 100644
--- a/Libraries/LLM/LLMModel.swift
+++ b/Libraries/LLM/LLMModel.swift
@@ -12,4 +12,15 @@ public protocol LLMModel: Module {
    func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
        MLXArray, [(MLXArray, MLXArray)]
    )
+
+    /// Optionally preprocess the weights and modify / remove values as needed.
+    func sanitize(weights: [String: MLXArray]) -> [String: MLXArray]
+}
+
+extension LLMModel {
+
+    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
+        weights
+    }
+
}
diff --git a/Libraries/LLM/Llama.swift b/Libraries/LLM/Llama.swift
index c74d636..9e07cc4 100644
--- a/Libraries/LLM/Llama.swift
+++ b/Libraries/LLM/Llama.swift
@@ -194,6 +194,13 @@ public class LlamaModel: Module, LLMModel {
        let (out, cache) = model(inputs, cache: cache)
        return (lmHead(out), cache)
    }
+
+    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
+        // Remove unused precomputed rotary freqs
+        weights.filter {
+            !$0.key.contains("self_attn.rotary_emb.inv_freq")
+        }
+    }
}
 
public struct LlamaConfiguration: Codable {
diff --git a/Libraries/LLM/Load.swift b/Libraries/LLM/Load.swift
index 4d9849a..8d25b24 100644
--- a/Libraries/LLM/Load.swift
+++ b/Libraries/LLM/Load.swift
@@ -54,9 +54,15 @@ public func load(
        }
    }
 
+    // per-model cleanup
+    weights = model.sanitize(weights: weights)
+
    // quantize if needed
    if let quantization = baseConfig.quantization {
-        quantizeIfNeeded(model: model, weights: weights, quantization: quantization)
+        quantize(model: model, groupSize: quantization.groupSize, bits: quantization.bits) {
+            path, module in
+            weights["\(path).scales"] != nil
+        }
    }
 
    // apply the loaded weights
@@ -76,38 +82,3 @@ public func load(
        hub: hub, configuration: newConfiguration, progressHandler: progressHandler)
    }
}
-
-// MARK: - Quantization
-
-private func quantizeIfNeeded(
-    model: LLMModel, weights: [String: MLXArray], quantization: BaseConfiguration.Quantization
-) {
-
-    func linearPredicate(layer: Module) -> Bool {
-        if let layer = layer as? Linear {
-            // avoid quantizing gate layers, otherwise we have to re-quant and upload all the mixtral models
-            return layer.weight.dim(0) != 8
-        }
-        return false
-    }
-
-    var predicate = linearPredicate(layer:)
-
-    // for legacy models that don't have lm_head quant due to non-32 dims
-    if weights["lm_head.scales"] == nil {
-        let vocabularySize = model.vocabularySize
-
-        func vocabularySizePredicate(layer: Module) -> Bool {
-            if let layer = layer as? Linear {
-                return layer.weight.dim(0) != 8 && layer.weight.dim(0) != vocabularySize
-            }
-            return false
-        }
-
-        predicate = vocabularySizePredicate(layer:)
-    }
-
-    QuantizedLinear.quantize(
-        model: model, groupSize: quantization.groupSize, bits: quantization.bits,
-        predicate: predicate)
-}
diff --git a/Libraries/LLM/Lora.swift b/Libraries/LLM/Lora.swift
index 5798108..81438c9 100644
--- a/Libraries/LLM/Lora.swift
+++ b/Libraries/LLM/Lora.swift
@@ -377,7 +377,7 @@ public enum LoRATrain {
    /// - training with ``train(model:train:validate:optimizer:loss:tokenizer:parameters:progress:)``
    /// - loss evaluation with ``evaluate(model:dataset:loss:tokenizer:batchSize:batchCount:)``
    /// - fusing with ``fuse(model:layers:deQuantize:)``
-    /// - text generation with ``generate(promptTokens:parameters:model:tokenizer:didGenerate:)``
+    /// - text generation with ``generate(promptTokens:parameters:model:tokenizer:extraEOSTokens:didGenerate:)``
    ///     - note that this is just using normal model text generation
    ///
    /// - Parameters:
diff --git a/Libraries/LLM/Models.swift b/Libraries/LLM/Models.swift
index 57573b0..01ab42a 100644
--- a/Libraries/LLM/Models.swift
+++ b/Libraries/LLM/Models.swift
@@ -33,6 +33,12 @@ public struct ModelConfiguration {
    /// overrides for TokenizerModel/knownTokenizers -- useful before swift-transformers is updated
    public let overrideTokenizer: String?
 
+    /// A reasonable default prompt for the model
+    public let defaultPrompt: String
+
+    /// Additional tokens to use for end of string
+    public let extraEOSTokens: Set<String>
+
    /// custom preparation logic for the prompt. custom tokenizers provide more capability, but this
    /// allows some minor formtting changes, e.g. wrapping the user input in the expected prompt
    /// format
@@ -40,21 +46,29 @@ public struct ModelConfiguration {
    public init(
        id: String, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
+        defaultPrompt: String = "hello",
+        extraEOSTokens: Set<String> = [],
        preparePrompt: ((String) -> String)? = nil
    ) {
        self.id = .id(id)
        self.tokenizerId = tokenizerId
        self.overrideTokenizer = overrideTokenizer
+        self.defaultPrompt = defaultPrompt
+        self.extraEOSTokens = extraEOSTokens
        self.preparePrompt = preparePrompt
    }
 
    public init(
        directory: URL, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
+        defaultPrompt: String = "hello",
+        extraEOSTokens: Set<String> = [],
        preparePrompt: ((String) -> String)? = nil
    ) {
        self.id = .directory(directory)
        self.tokenizerId = tokenizerId
        self.overrideTokenizer = overrideTokenizer
+        self.defaultPrompt = defaultPrompt
+        self.extraEOSTokens = extraEOSTokens
        self.preparePrompt = preparePrompt
    }
 
@@ -98,11 +112,16 @@ public struct ModelConfiguration {
extension ModelConfiguration {
 
    public static let mistral7B4bit = ModelConfiguration(
-        id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx")
+        id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx",
+
+        // https://www.promptingguide.ai/models/mistral-7b
+        defaultPrompt: "describe the swift language"
+    )
 
    public static let codeLlama13b4bit = ModelConfiguration(
        id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX",
-        overrideTokenizer: "PreTrainedTokenizer"
+        overrideTokenizer: "PreTrainedTokenizer",
+        defaultPrompt: "func sortArray(_ array: [Int]) -> String { <FILL_ME> }"
    ) { prompt in
        // given the prompt: func sortArray(_ array: [Int]) -> String { <FILL_ME> }
        // the python code produces this (via its custom tokenizer):
@@ -111,13 +130,17 @@ extension ModelConfiguration {
        "<PRE> " + prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") + " <MID>"
    }
 
-    public static let phi4bit = ModelConfiguration(id: "mlx-community/phi-2-hf-4bit-mlx") {
-        prompt in
-        "Instruct: \(prompt)\nOutput: "
-    }
+    public static let phi4bit = ModelConfiguration(
+        id: "mlx-community/phi-2-hf-4bit-mlx",
+
+        // https://www.promptingguide.ai/models/phi-2
+        defaultPrompt: "Why is the sky blue?"
+    )
 
     public static let phi34bit = ModelConfiguration(
-        id: "mlx-community/Phi-3-mini-4k-instruct-4bit-no-q-embed"
+        id: "mlx-community/Phi-3-mini-4k-instruct-4bit-no-q-embed",
+        defaultPrompt: "what is the gravity on mars and the moon?",
+        extraEOSTokens: ["<|end|>"]
     ) {
         prompt in
         "<|user|>\n\(prompt)<|end|>\n<|assistant|>\n"
@@ -125,26 +148,35 @@ extension ModelConfiguration {
 
     public static let gemma2bQuantized = ModelConfiguration(
         id: "mlx-community/quantized-gemma-2b-it",
-        overrideTokenizer: "PreTrainedTokenizer"
+        overrideTokenizer: "PreTrainedTokenizer",
+
+        // https://www.promptingguide.ai/models/gemma
+        defaultPrompt: "what is the difference between lettuce and cabbage?"
+
     ) { prompt in
         "user \(prompt)model"
     }
 
     public static let qwen205b4bit = ModelConfiguration(
         id: "mlx-community/Qwen1.5-0.5B-Chat-4bit",
-        overrideTokenizer: "PreTrainedTokenizer"
+        overrideTokenizer: "PreTrainedTokenizer",
+        defaultPrompt: "why is the sky blue?"
     ) { prompt in
         "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n\(prompt)<|im_end|>\n<|im_start|>assistant"
     }
 
     public static let openelm270m4bit = ModelConfiguration(
-        id: "mlx-community/OpenELM-270M-Instruct"
+        id: "mlx-community/OpenELM-270M-Instruct",
+
+        // https://huggingface.co/apple/OpenELM
+        defaultPrompt: "Once upon a time there was"
     ) { prompt in
         "\(prompt)"
     }
 
     public static let llama38B4bit = ModelConfiguration(
-        id: "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
+        id: "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
+        defaultPrompt: "what is the difference between a fruit and a vegetable?"
     ) {
         prompt in
         "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\(prompt)<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"
diff --git a/Libraries/LLM/Qwen2.swift b/Libraries/LLM/Qwen2.swift
index b867aad..9a478de 100644
--- a/Libraries/LLM/Qwen2.swift
+++ b/Libraries/LLM/Qwen2.swift
@@ -179,10 +179,12 @@ public class Qwen2ModelInner: Module {
 public class Qwen2Model: Module, LLMModel {
     public let vocabularySize: Int
     let model: Qwen2ModelInner
+    let configuration: Qwen2Configuration
 
     @ModuleInfo(key: "lm_head") var lmHead: Linear
 
     public init(_ args: Qwen2Configuration) {
+        self.configuration = args
         self.vocabularySize = args.vocabularySize
         self.model = Qwen2ModelInner(args)
         _lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
@@ -191,8 +193,26 @@ public class Qwen2Model: Module, LLMModel {
     public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
         MLXArray, [(MLXArray, MLXArray)]
     ) {
-        let (out, cache) = model(inputs, cache: cache)
-        return (lmHead(out), cache)
+        var (out, cache) = model(inputs, cache: cache)
+        if configuration.tieWordEmbeddings {
+            out = model.embedTokens.asLinear(out)
+        } else {
+            out = lmHead(out)
+        }
+        return (out, cache)
+    }
+
+    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
+        var weights = weights
+
+        if configuration.tieWordEmbeddings {
+            weights["lm_head.weight"] = nil
+        }
+
+        // Remove unused precomputed rotary freqs
+        return weights.filter {
+            !$0.key.contains("self_attn.rotary_emb.inv_freq")
+        }
     }
 }
 
@@ -207,6 +227,7 @@ public struct Qwen2Configuration: Codable {
     var ropeTheta: Float = 1_000_000
     var ropeTraditional: Bool = false
     var ropeScaling: [String: StringOrNumber]? = nil
+    var tieWordEmbeddings = false
 
     enum CodingKeys: String, CodingKey {
         case hiddenSize = "hidden_size"
@@ -219,6 +240,7 @@ public struct Qwen2Configuration: Codable {
         case ropeTheta = "rope_theta"
         case ropeTraditional = "rope_traditional"
         case ropeScaling = "rope_scaling"
+        case tieWordEmbeddings = "tie_word_embeddings"
     }
 
     public init(from decoder: Decoder) throws {
@@ -249,6 +271,8 @@ public struct Qwen2Configuration: Codable {
                 Bool.self, forKey: Qwen2Configuration.CodingKeys.ropeTraditional) ?? false
         self.ropeScaling = try container.decodeIfPresent(
             [String: StringOrNumber].self, forKey: Qwen2Configuration.CodingKeys.ropeScaling)
+        self.tieWordEmbeddings =
+            try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false
     }
 }
 
diff --git a/Tools/llm-tool/LLMTool.swift b/Tools/llm-tool/LLMTool.swift
index 7116f32..7b9d47f 100644
--- a/Tools/llm-tool/LLMTool.swift
+++ b/Tools/llm-tool/LLMTool.swift
@@ -44,7 +44,7 @@ struct GenerateArguments: ParsableArguments {
         help:
             "The message to be processed by the model.  Use @path,@path to load from files, e.g. @/tmp/prompt.txt"
     )
-    var prompt = "compare python and swift"
+    var prompt: String?
 
     @Option(name: .shortAndLong, help: "Maximum number of tokens to generate")
     var maxTokens = 100
@@ -73,7 +73,8 @@ struct GenerateArguments: ParsableArguments {
             repetitionContextSize: repetitionContextSize)
     }
 
-    func resolvePrompt() throws -> String {
+    func resolvePrompt(configuration: ModelConfiguration) throws -> String {
+        let prompt = self.prompt ?? configuration.defaultPrompt
         if prompt.hasPrefix("@") {
             let names = prompt.split(separator: ",").map { String($0.dropFirst()) }
             return try names.map { try String(contentsOfFile: $0) }.joined(separator: "\n")
@@ -87,14 +88,17 @@ struct GenerateArguments: ParsableArguments {
     ) {
         MLXRandom.seed(seed)
 
-        let prompt = try resolvePrompt()
+        let prompt = try resolvePrompt(configuration: configuration)
         let preparedPrompt = configuration.prepare(prompt: prompt)
         let promptTokens = tokenizer.encode(text: preparedPrompt)
 
         return (prompt, promptTokens)
     }
 
-    func generate(promptTokens: [Int], model: LLMModel, tokenizer: Tokenizer) async
+    func generate(
+        promptTokens: [Int], model: LLMModel, tokenizer: Tokenizer,
+        extraEOSTokens: Set<String>? = nil
+    ) async
         -> GenerateResult
     {
         // track how much we have printed
@@ -102,7 +106,7 @@ struct GenerateArguments: ParsableArguments {
 
         return await LLM.generate(
             promptTokens: promptTokens, parameters: generateParameters,
-            model: model, tokenizer: tokenizer
+            model: model, tokenizer: tokenizer, extraEOSTokens: extraEOSTokens
         ) { tokens in
 
             // print any new parts of the string
@@ -226,7 +230,8 @@ struct EvaluateCommand: AsyncParsableCommand {
         }
 
         let result = await generate.generate(
-            promptTokens: promptTokens, model: model, tokenizer: tokenizer)
+            promptTokens: promptTokens, model: model, tokenizer: tokenizer,
+            extraEOSTokens: modelConfiguration.extraEOSTokens)
         print()
 
         if !generate.quiet {
diff --git a/Tools/llm-tool/LoraCommands.swift b/Tools/llm-tool/LoraCommands.swift
index f75a701..4c422e5 100644
--- a/Tools/llm-tool/LoraCommands.swift
+++ b/Tools/llm-tool/LoraCommands.swift
@@ -275,7 +275,8 @@ struct LoRAEvalCommand: AsyncParsableCommand {
 
         // generate and print the result
         let _ = await generate.generate(
-            promptTokens: promptTokens, model: model, tokenizer: tokenizer)
+            promptTokens: promptTokens, model: model, tokenizer: tokenizer,
+            extraEOSTokens: modelConfiguration.extraEOSTokens)
         print()
     }
 }
diff --git a/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
index 9b02ae0..ef218f0 100644
--- a/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
+++ b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
@@ -16,7 +16,7 @@
       "location" : "https://github.com/ml-explore/mlx-swift",
       "state" : {
         "branch" : "main",
-        "revision" : "3c802c808d281c191d5f26f37a4f93135d8ca119"
+        "revision" : "d6d9472da5bf7ec2654e8914bd1d15622f45b6a9"
       }
     },
     {
@@ -61,7 +61,7 @@
       "location" : "https://github.com/gonzalezreal/swift-markdown-ui",
       "state" : {
         "branch" : "main",
-        "revision" : "723249a1ba361042812cf785244de94f11f7c8fd"
+        "revision" : "c0daf6eb79d97964180f3113868c990bd1c4a007"
       }
     },
     {