diff --git a/Libraries/LLM/Llama.swift b/Libraries/LLM/Llama.swift
index 9e07cc4..c9ef4cc 100644
--- a/Libraries/LLM/Llama.swift
+++ b/Libraries/LLM/Llama.swift
@@ -180,19 +180,25 @@ public class LlamaModel: Module, LLMModel {
     public let vocabularySize: Int
     let model: LlamaModelInner
 
-    @ModuleInfo(key: "lm_head") var lmHead: Linear
+    @ModuleInfo(key: "lm_head") var lmHead: Linear?
 
     public init(_ args: LlamaConfiguration) {
         self.vocabularySize = args.vocabularySize
         self.model = LlamaModelInner(args)
-        self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
+        if !args.tieWordEmbeddings {
+            self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
+        }
     }
 
     public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
         MLXArray, [(MLXArray, MLXArray)]
     ) {
         let (out, cache) = model(inputs, cache: cache)
-        return (lmHead(out), cache)
+        if let lmHead {
+            return (lmHead(out), cache)
+        } else {
+            return (model.embedTokens.asLinear(out), cache)
+        }
     }
 
     public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
@@ -215,6 +221,7 @@ public struct LlamaConfiguration: Codable {
     var ropeTheta: Float = 10_000
     var ropeTraditional: Bool = false
     var ropeScaling: [String: StringOrNumber]? = nil
+    var tieWordEmbeddings: Bool = false
 
     enum CodingKeys: String, CodingKey {
         case hiddenSize = "hidden_size"
@@ -227,6 +234,7 @@
         case ropeTheta = "rope_theta"
         case ropeTraditional = "rope_traditional"
         case ropeScaling = "rope_scaling"
+        case tieWordEmbeddings = "tie_word_embeddings"
     }
 
     public init(from decoder: Decoder) throws {
@@ -257,6 +265,8 @@
             Bool.self, forKey: LlamaConfiguration.CodingKeys.ropeTraditional) ?? false
         self.ropeScaling = try container.decodeIfPresent(
             [String: StringOrNumber].self, forKey: LlamaConfiguration.CodingKeys.ropeScaling)
+        self.tieWordEmbeddings =
+            try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false
     }
 }
 
diff --git a/Libraries/LLM/Models.swift b/Libraries/LLM/Models.swift
index 937ee5b..229fb94 100644
--- a/Libraries/LLM/Models.swift
+++ b/Libraries/LLM/Models.swift
@@ -110,6 +110,13 @@ public struct ModelConfiguration {
 }
 
 extension ModelConfiguration {
+    public static let smolLM_135M_4bit = ModelConfiguration(
+        id: "mlx-community/SmolLM-135M-Instruct-4bit",
+        defaultPrompt: "Tell me about the history of Spain."
+    ) {
+        prompt in
+        "<|im_start|>user\n\(prompt)<|im_end|>\n<|im_start|>assistant\n"
+    }
     public static let mistral7B4bit = ModelConfiguration(
         id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx",
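
Note on the tied-embeddings path in the Llama.swift hunk: when tie_word_embeddings is true the checkpoint carries no separate lm_head weight, so the logits are produced by reusing the token-embedding matrix through Embedding.asLinear, as the diff does with model.embedTokens.asLinear(out). The snippet below is a minimal sketch of that idea, not part of the change; the sizes and shapes are illustrative assumptions only.

    import MLX
    import MLXNN

    // Illustrative sizes only; not read from any model config.
    let hiddenSize = 576
    let vocabularySize = 49_152

    // One Embedding maps token ids to hidden vectors on the way in...
    let embedTokens = Embedding(embeddingCount: vocabularySize, dimensions: hiddenSize)

    // ...and asLinear reuses the same weight to project hidden states back onto the
    // vocabulary, i.e. it multiplies by the transposed embedding matrix instead of
    // applying a separate lm_head Linear layer.
    let hidden = MLXArray.zeros([1, 8, hiddenSize])   // [batch, sequence, hiddenSize]
    let logits = embedTokens.asLinear(hidden)         // [batch, sequence, vocabularySize]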