handle partially quantized models (#76)

* handle partially quantized models

- fix for #53, #71, #69, #74
- in order to test the models:
    - added a default prompt of an appropriate form to each model configuration
    - while working on the model configuration, also added additional stop tokens (#74)
- fixed the repetitionPenalty code (#71)
- the new configuration options and the partial-quantization check are sketched below
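
A minimal sketch (not part of the diff) of the configuration additions and the partial-quantization check, written against this repo's LLM library as of this commit; the helper name applyQuantization and the standalone setup are illustrative assumptions, and the mlx-swift quantize signature may have drifted since:

    import LLM
    import MLX
    import MLXNN

    // 1. A model configuration now carries a default prompt and any extra
    //    end-of-string markers (#74); this mirrors the phi34bit entry in the diff.
    let phi3 = ModelConfiguration(
        id: "mlx-community/Phi-3-mini-4k-instruct-4bit-no-q-embed",
        defaultPrompt: "what is the gravity on mars and the moon?",
        extraEOSTokens: ["<|end|>"]
    )

    // 2. Partially quantized checkpoints: quantize only the layers that actually
    //    ship quantization scales, instead of quantizing every Linear in the model.
    //    (applyQuantization is a hypothetical wrapper around the call used in load().)
    func applyQuantization(
        model: Module, weights: [String: MLXArray], groupSize: Int, bits: Int
    ) {
        quantize(model: model, groupSize: groupSize, bits: bits) { path, module in
            // a layer is treated as quantized iff "<path>.scales" exists in the weights
            weights["\(path).scales"] != nil
        }
    }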
Author: David Koski
Date: 2024-05-28 16:35:11 -07:00
Committed by: GitHub
parent 65f4968e5f
commit 9d74afd119
12 changed files with 139 additions and 67 deletions

View File

@@ -10,7 +10,7 @@ import Tokenizers
 struct ContentView: View {
-    @State var prompt = "compare python and swift"
+    @State var prompt = ""
     @State var llm = LLMEvaluator()
     @Environment(DeviceStat.self) private var deviceStat
@@ -125,6 +125,8 @@ struct ContentView: View {
         }
         .task {
+            self.prompt = llm.modelConfiguration.defaultPrompt
+
             // pre-load the weights on launch to speed up the first generation
             _ = try? await llm.load()
         }
@@ -224,7 +226,7 @@ class LLMEvaluator {
         let result = await LLM.generate(
             promptTokens: promptTokens, parameters: generateParameters, model: model,
-            tokenizer: tokenizer
+            tokenizer: tokenizer, extraEOSTokens: modelConfiguration.extraEOSTokens
         ) { tokens in
             // update the output -- this will make the view show the text as it generates
             if tokens.count % displayEveryNTokens == 0 {

View File

@@ -266,6 +266,7 @@ class LoRAEvaluator {
         let result = await LLM.generate(
             promptTokens: promptTokens, parameters: generateParameters, model: model,
             tokenizer: tokenizer,
+            extraEOSTokens: modelConfiguration.extraEOSTokens,
             didGenerate: { tokens in
                 if tokens.count % evaluateShowEvery == 0 {
                     let fullOutput = tokenizer.decode(tokens: tokens)

View File

@@ -12,7 +12,7 @@ private func topPSampling(logits: MLXArray, topP: Float, temp: Float) -> MLXArray
         logits = logits.asType(.float32)
     }
-    let probs = softMax(logits / temp, axis: -1)
+    let probs = softmax(logits / temp, axis: -1)
     let sortedIndices = argSort(probs, axis: -1)
     // probs shape is [B,V] and after take it will be [1, B, V], so we squeeze it back to [B, V]
@@ -31,7 +31,7 @@ private func applyRepetitionPenalty(
 ) -> MLXArray {
     if repetitionContext.shape[0] > 0 {
         let indices = repetitionContext
-        var selectedLogits = take(logits, indices, axis: -1).squeezed(axis: 0)
+        var selectedLogits = logits[0..., indices]
         selectedLogits = MLX.where(
             selectedLogits .< 0, selectedLogits * penalty, selectedLogits / penalty)
@@ -100,7 +100,7 @@ public struct TokenIterator: Sequence, IteratorProtocol {
         if prompt.shape[0] <= parameters.repetitionContextSize {
             self.repetitionContext = prompt
         } else {
-            self.repetitionContext = prompt[-parameters.repetitionContextSize ... -1]
+            self.repetitionContext = prompt[(-parameters.repetitionContextSize)...]
         }
     } else {
         self.repetitionContext = []
@@ -120,9 +120,8 @@ public struct TokenIterator: Sequence, IteratorProtocol {
         y = sample(logits: logits, temp: parameters.temperature, topP: parameters.topP)
         // append the current token to the context and check repetitionPenalty context see if need to remove the first token
         if parameters.repetitionContextSize > 1 {
-            repetitionContext = concatenated([repetitionContext, y], axis: 0)
             if repetitionContext.shape[0] > parameters.repetitionContextSize {
-                repetitionContext = repetitionContext[1...]
+                repetitionContext = repetitionContext[(-parameters.repetitionContextSize)...]
             }
         }
@@ -174,14 +173,31 @@ public enum GenerateDisposition {
 /// - parameters: generation parameters
 /// - model: model to evaluate
 /// - tokenizer: tokenizer to convert tokens back into strings and recognizer special tokens
+/// - configuration: the model configuration
 /// - didGenerate: visitor for the tokens as they are generated
 public func generate(
     promptTokens: [Int], parameters: GenerateParameters, model: LLMModel, tokenizer: Tokenizer,
+    extraEOSTokens: Set<String>? = nil,
     didGenerate: ([Int]) async -> GenerateDisposition
 ) async -> GenerateResult {
     var start = Date.timeIntervalSinceReferenceDate
     var promptTime: TimeInterval = 0
+
+    // build a set of additional stop tokens
+    let additionalEOSTokenIds = Set(
+        (extraEOSTokens ?? [])
+            .map {
+                tokenizer.encode(text: $0)
+            }
+            .filter {
+                // discard anything that is not a single token. sometimes
+                // the tokenizer will insert a <s> token, so accept that too
+                $0.count == 1 || ($0.count == 2 && $0[0] == 1)
+            }
+            .map {
+                $0.last!
+            })
+
     var tokens = [Int]()
     for token in TokenIterator(
@@ -196,7 +212,9 @@ public func generate(
         }
         let t = token.item(Int.self)
-        if t == tokenizer.unknownTokenId || t == tokenizer.eosTokenId {
+        if t == tokenizer.unknownTokenId || t == tokenizer.eosTokenId
+            || additionalEOSTokenIds.contains(t)
+        {
             break
         }

View File

@@ -12,4 +12,15 @@ public protocol LLMModel: Module {
     func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
         MLXArray, [(MLXArray, MLXArray)]
     )
+
+    /// Optionally preprocess the weights and modify / remove values as needed.
+    func sanitize(weights: [String: MLXArray]) -> [String: MLXArray]
+}
+
+extension LLMModel {
+    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
+        weights
+    }
 }

View File

@@ -194,6 +194,13 @@ public class LlamaModel: Module, LLMModel {
         let (out, cache) = model(inputs, cache: cache)
         return (lmHead(out), cache)
     }
+
+    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
+        // Remove unused precomputed rotary freqs
+        weights.filter {
+            !$0.key.contains("self_attn.rotary_emb.inv_freq")
+        }
+    }
 }

 public struct LlamaConfiguration: Codable {

View File

@@ -54,9 +54,15 @@ public func load(
             }
         }

+        // per-model cleanup
+        weights = model.sanitize(weights: weights)
+
         // quantize if needed
         if let quantization = baseConfig.quantization {
-            quantizeIfNeeded(model: model, weights: weights, quantization: quantization)
+            quantize(model: model, groupSize: quantization.groupSize, bits: quantization.bits) {
+                path, module in
+                weights["\(path).scales"] != nil
+            }
         }

         // apply the loaded weights
@@ -76,38 +82,3 @@ public func load(
             hub: hub, configuration: newConfiguration, progressHandler: progressHandler)
     }
 }
-
-// MARK: - Quantization
-
-private func quantizeIfNeeded(
-    model: LLMModel, weights: [String: MLXArray], quantization: BaseConfiguration.Quantization
-) {
-    func linearPredicate(layer: Module) -> Bool {
-        if let layer = layer as? Linear {
-            // avoid quantizing gate layers, otherwise we have to re-quant and upload all the mixtral models
-            return layer.weight.dim(0) != 8
-        }
-        return false
-    }
-
-    var predicate = linearPredicate(layer:)
-
-    // for legacy models that don't have lm_head quant due to non-32 dims
-    if weights["lm_head.scales"] == nil {
-        let vocabularySize = model.vocabularySize
-
-        func vocabularySizePredicate(layer: Module) -> Bool {
-            if let layer = layer as? Linear {
-                return layer.weight.dim(0) != 8 && layer.weight.dim(0) != vocabularySize
-            }
-            return false
-        }
-
-        predicate = vocabularySizePredicate(layer:)
-    }
-
-    QuantizedLinear.quantize(
-        model: model, groupSize: quantization.groupSize, bits: quantization.bits,
-        predicate: predicate)
-}

View File

@@ -377,7 +377,7 @@ public enum LoRATrain {
    /// - training with ``train(model:train:validate:optimizer:loss:tokenizer:parameters:progress:)``
    /// - loss evaluation with ``evaluate(model:dataset:loss:tokenizer:batchSize:batchCount:)``
    /// - fusing with ``fuse(model:layers:deQuantize:)``
-   /// - text generation with ``generate(promptTokens:parameters:model:tokenizer:didGenerate:)``
+   /// - text generation with ``generate(promptTokens:parameters:model:tokenizer:additionalEOSTokens:didGenerate:)``
    ///     - note that this is just using normal model text generation
    ///
    /// - Parameters:

View File

@@ -33,6 +33,12 @@ public struct ModelConfiguration {
     /// overrides for TokenizerModel/knownTokenizers -- useful before swift-transformers is updated
     public let overrideTokenizer: String?

+    /// A reasonable default prompt for the model
+    public let defaultPrompt: String
+
+    /// Additional tokens to use for end of string
+    public let extraEOSTokens: Set<String>
+
     /// custom preparation logic for the prompt. custom tokenizers provide more capability, but this
     /// allows some minor formtting changes, e.g. wrapping the user input in the expected prompt
     /// format
@@ -40,21 +46,29 @@ public struct ModelConfiguration {
     public init(
         id: String, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
+        defaultPrompt: String = "hello",
+        extraEOSTokens: Set<String> = [],
         preparePrompt: ((String) -> String)? = nil
     ) {
         self.id = .id(id)
         self.tokenizerId = tokenizerId
         self.overrideTokenizer = overrideTokenizer
+        self.defaultPrompt = defaultPrompt
+        self.extraEOSTokens = extraEOSTokens
         self.preparePrompt = preparePrompt
     }

     public init(
         directory: URL, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
+        defaultPrompt: String = "hello",
+        extraEOSTokens: Set<String> = [],
         preparePrompt: ((String) -> String)? = nil
     ) {
         self.id = .directory(directory)
         self.tokenizerId = tokenizerId
         self.overrideTokenizer = overrideTokenizer
+        self.defaultPrompt = defaultPrompt
+        self.extraEOSTokens = extraEOSTokens
         self.preparePrompt = preparePrompt
     }
@@ -98,11 +112,16 @@ public struct ModelConfiguration {
 extension ModelConfiguration {
     public static let mistral7B4bit = ModelConfiguration(
-        id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx")
+        id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx",
+
+        // https://www.promptingguide.ai/models/mistral-7b
+        defaultPrompt: "describe the swift language"
+    )

     public static let codeLlama13b4bit = ModelConfiguration(
         id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX",
-        overrideTokenizer: "PreTrainedTokenizer"
+        overrideTokenizer: "PreTrainedTokenizer",
+        defaultPrompt: "func sortArray(_ array: [Int]) -> String { <FILL_ME> }"
     ) { prompt in
         // given the prompt: func sortArray(_ array: [Int]) -> String { <FILL_ME> }
         // the python code produces this (via its custom tokenizer):
@@ -111,13 +130,17 @@ extension ModelConfiguration {
         "<PRE> " + prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") + " <MID>"
     }

-    public static let phi4bit = ModelConfiguration(id: "mlx-community/phi-2-hf-4bit-mlx") {
-        prompt in
-        "Instruct: \(prompt)\nOutput: "
-    }
+    public static let phi4bit = ModelConfiguration(
+        id: "mlx-community/phi-2-hf-4bit-mlx",
+
+        // https://www.promptingguide.ai/models/phi-2
+        defaultPrompt: "Why is the sky blue?"
+    )

     public static let phi34bit = ModelConfiguration(
-        id: "mlx-community/Phi-3-mini-4k-instruct-4bit-no-q-embed"
+        id: "mlx-community/Phi-3-mini-4k-instruct-4bit-no-q-embed",
+        defaultPrompt: "what is the gravity on mars and the moon?",
+        extraEOSTokens: ["<|end|>"]
     ) {
         prompt in
         "<s><|user|>\n\(prompt)<|end|>\n<|assistant|>\n"
@@ -125,26 +148,35 @@ extension ModelConfiguration {
     public static let gemma2bQuantized = ModelConfiguration(
         id: "mlx-community/quantized-gemma-2b-it",
-        overrideTokenizer: "PreTrainedTokenizer"
+        overrideTokenizer: "PreTrainedTokenizer",
+
+        // https://www.promptingguide.ai/models/gemma
+        defaultPrompt: "what is the difference between lettuce and cabbage?"
     ) { prompt in
         "<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
     }

     public static let qwen205b4bit = ModelConfiguration(
         id: "mlx-community/Qwen1.5-0.5B-Chat-4bit",
-        overrideTokenizer: "PreTrainedTokenizer"
+        overrideTokenizer: "PreTrainedTokenizer",
+        defaultPrompt: "why is the sky blue?"
     ) { prompt in
         "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n\(prompt)<|im_end|>\n<|im_start|>assistant"
     }

     public static let openelm270m4bit = ModelConfiguration(
-        id: "mlx-community/OpenELM-270M-Instruct"
+        id: "mlx-community/OpenELM-270M-Instruct",
+
+        // https://huggingface.co/apple/OpenELM
+        defaultPrompt: "Once upon a time there was"
     ) { prompt in
         "\(prompt)"
     }

     public static let llama38B4bit = ModelConfiguration(
-        id: "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
+        id: "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
+        defaultPrompt: "what is the difference between a fruit and a vegetable?"
     ) {
         prompt in
         "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\(prompt)<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"

View File

@@ -179,10 +179,12 @@ public class Qwen2ModelInner: Module {
 public class Qwen2Model: Module, LLMModel {
     public let vocabularySize: Int
     let model: Qwen2ModelInner
+    let configuration: Qwen2Configuration

     @ModuleInfo(key: "lm_head") var lmHead: Linear

     public init(_ args: Qwen2Configuration) {
+        self.configuration = args
         self.vocabularySize = args.vocabularySize
         self.model = Qwen2ModelInner(args)
         _lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
@@ -191,8 +193,26 @@ public class Qwen2Model: Module, LLMModel {
     public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
         MLXArray, [(MLXArray, MLXArray)]
     ) {
-        let (out, cache) = model(inputs, cache: cache)
-        return (lmHead(out), cache)
+        var (out, cache) = model(inputs, cache: cache)
+        if configuration.tieWordEmbeddings {
+            out = model.embedTokens.asLinear(out)
+        } else {
+            out = lmHead(out)
+        }
+        return (out, cache)
+    }
+
+    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
+        var weights = weights
+        if configuration.tieWordEmbeddings {
+            weights["lm_head.weight"] = nil
+        }
+
+        // Remove unused precomputed rotary freqs
+        return weights.filter {
+            !$0.key.contains("self_attn.rotary_emb.inv_freq")
+        }
     }
 }
@@ -207,6 +227,7 @@ public struct Qwen2Configuration: Codable {
     var ropeTheta: Float = 1_000_000
     var ropeTraditional: Bool = false
     var ropeScaling: [String: StringOrNumber]? = nil
+    var tieWordEmbeddings = false

     enum CodingKeys: String, CodingKey {
         case hiddenSize = "hidden_size"
@@ -219,6 +240,7 @@ public struct Qwen2Configuration: Codable {
         case ropeTheta = "rope_theta"
         case ropeTraditional = "rope_traditional"
         case ropeScaling = "rope_scaling"
+        case tieWordEmbeddings = "tie_word_embeddings"
     }

     public init(from decoder: Decoder) throws {
@@ -249,6 +271,8 @@ public struct Qwen2Configuration: Codable {
             Bool.self, forKey: Qwen2Configuration.CodingKeys.ropeTraditional) ?? false
         self.ropeScaling = try container.decodeIfPresent(
             [String: StringOrNumber].self, forKey: Qwen2Configuration.CodingKeys.ropeScaling)
+        self.tieWordEmbeddings =
+            try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false
     }
 }

View File

@@ -44,7 +44,7 @@ struct GenerateArguments: ParsableArguments {
         help:
             "The message to be processed by the model. Use @path,@path to load from files, e.g. @/tmp/prompt.txt"
     )
-    var prompt = "compare python and swift"
+    var prompt: String?

     @Option(name: .shortAndLong, help: "Maximum number of tokens to generate")
     var maxTokens = 100
@@ -73,7 +73,8 @@ struct GenerateArguments: ParsableArguments {
             repetitionContextSize: repetitionContextSize)
     }

-    func resolvePrompt() throws -> String {
+    func resolvePrompt(configuration: ModelConfiguration) throws -> String {
+        let prompt = self.prompt ?? configuration.defaultPrompt
         if prompt.hasPrefix("@") {
             let names = prompt.split(separator: ",").map { String($0.dropFirst()) }
             return try names.map { try String(contentsOfFile: $0) }.joined(separator: "\n")
@@ -87,14 +88,17 @@ struct GenerateArguments: ParsableArguments {
     ) {
         MLXRandom.seed(seed)

-        let prompt = try resolvePrompt()
+        let prompt = try resolvePrompt(configuration: configuration)
         let preparedPrompt = configuration.prepare(prompt: prompt)
         let promptTokens = tokenizer.encode(text: preparedPrompt)

         return (prompt, promptTokens)
     }

-    func generate(promptTokens: [Int], model: LLMModel, tokenizer: Tokenizer) async
+    func generate(
+        promptTokens: [Int], model: LLMModel, tokenizer: Tokenizer,
+        extraEOSTokens: Set<String>? = nil
+    ) async
         -> GenerateResult
     {
         // track how much we have printed
@@ -102,7 +106,7 @@ struct GenerateArguments: ParsableArguments {
         return await LLM.generate(
             promptTokens: promptTokens, parameters: generateParameters,
-            model: model, tokenizer: tokenizer
+            model: model, tokenizer: tokenizer, extraEOSTokens: extraEOSTokens
         ) { tokens in
             // print any new parts of the string
@@ -226,7 +230,8 @@ struct EvaluateCommand: AsyncParsableCommand {
         }

         let result = await generate.generate(
-            promptTokens: promptTokens, model: model, tokenizer: tokenizer)
+            promptTokens: promptTokens, model: model, tokenizer: tokenizer,
+            extraEOSTokens: modelConfiguration.extraEOSTokens)
         print()

         if !generate.quiet {

View File

@@ -275,7 +275,8 @@ struct LoRAEvalCommand: AsyncParsableCommand {
         // generate and print the result
         let _ = await generate.generate(
-            promptTokens: promptTokens, model: model, tokenizer: tokenizer)
+            promptTokens: promptTokens, model: model, tokenizer: tokenizer,
+            extraEOSTokens: modelConfiguration.extraEOSTokens)
         print()
     }
 }

View File

@@ -16,7 +16,7 @@
       "location" : "https://github.com/ml-explore/mlx-swift",
       "state" : {
         "branch" : "main",
-        "revision" : "3c802c808d281c191d5f26f37a4f93135d8ca119"
+        "revision" : "d6d9472da5bf7ec2654e8914bd1d15622f45b6a9"
       }
     },
     {
@@ -61,7 +61,7 @@
       "location" : "https://github.com/gonzalezreal/swift-markdown-ui",
       "state" : {
         "branch" : "main",
-        "revision" : "723249a1ba361042812cf785244de94f11f7c8fd"
+        "revision" : "c0daf6eb79d97964180f3113868c990bd1c4a007"
       }
     },
     {