Add SmolLM (#95)

Author: Anthony DePasquale
Date: 2024-07-24 00:42:52 +02:00
Committed by: GitHub
Parent: 2a2931ba8d
Commit: a2e8d7e469

2 changed files with 20 additions and 3 deletions

@@ -180,19 +180,25 @@ public class LlamaModel: Module, LLMModel {
     public let vocabularySize: Int
     let model: LlamaModelInner

-    @ModuleInfo(key: "lm_head") var lmHead: Linear
+    @ModuleInfo(key: "lm_head") var lmHead: Linear?

     public init(_ args: LlamaConfiguration) {
         self.vocabularySize = args.vocabularySize
         self.model = LlamaModelInner(args)
-        self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
+        if !args.tieWordEmbeddings {
+            self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
+        }
     }

     public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
         MLXArray, [(MLXArray, MLXArray)]
     ) {
         let (out, cache) = model(inputs, cache: cache)
-        return (lmHead(out), cache)
+        if let lmHead {
+            return (lmHead(out), cache)
+        } else {
+            return (model.embedTokens.asLinear(out), cache)
+        }
     }

     public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
@@ -215,6 +221,7 @@ public struct LlamaConfiguration: Codable {
     var ropeTheta: Float = 10_000
     var ropeTraditional: Bool = false
     var ropeScaling: [String: StringOrNumber]? = nil
+    var tieWordEmbeddings: Bool = false

     enum CodingKeys: String, CodingKey {
         case hiddenSize = "hidden_size"
@@ -227,6 +234,7 @@ public struct LlamaConfiguration: Codable {
         case ropeTheta = "rope_theta"
         case ropeTraditional = "rope_traditional"
         case ropeScaling = "rope_scaling"
+        case tieWordEmbeddings = "tie_word_embeddings"
     }

     public init(from decoder: Decoder) throws {
@@ -257,6 +265,8 @@ public struct LlamaConfiguration: Codable {
Bool.self, forKey: LlamaConfiguration.CodingKeys.ropeTraditional) ?? false Bool.self, forKey: LlamaConfiguration.CodingKeys.ropeTraditional) ?? false
self.ropeScaling = try container.decodeIfPresent( self.ropeScaling = try container.decodeIfPresent(
[String: StringOrNumber].self, forKey: LlamaConfiguration.CodingKeys.ropeScaling) [String: StringOrNumber].self, forKey: LlamaConfiguration.CodingKeys.ropeScaling)
self.tieWordEmbeddings =
try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false
} }
} }
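SmolLM checkpoints set tie_word_embeddings = true, so they ship no separate lm_head.weight: the output projection reuses the input embedding matrix, which is why lmHead becomes optional and the fallback path calls model.embedTokens.asLinear(out). Below is a minimal standalone sketch of that tied path using MLXNN's Embedding; the sizes are illustrative assumptions, not values read from a real config.json.

    import MLX
    import MLXNN

    // Illustrative sizes only (in practice they come from the model's config).
    let hiddenSize = 576
    let vocabularySize = 49152

    let embedTokens = Embedding(embeddingCount: vocabularySize, dimensions: hiddenSize)

    // Fake hidden states: batch of 1, 8 tokens.
    let hidden = MLX.zeros([1, 8, hiddenSize])

    // Tied path: Embedding.asLinear applies the transposed embedding matrix,
    // the same projection the diff performs with `model.embedTokens.asLinear(out)`
    // when no separate lm_head was created.
    let logits = embedTokens.asLinear(hidden)
    print(logits.shape)  // [1, 8, 49152]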

@@ -110,6 +110,13 @@ public struct ModelConfiguration {
 }

 extension ModelConfiguration {

+    public static let smolLM_135M_4bit = ModelConfiguration(
+        id: "mlx-community/SmolLM-135M-Instruct-4bit",
+        defaultPrompt: "Tell me about the history of Spain."
+    ) {
+        prompt in
+        "<|im_start|>user\n\(prompt)<|im_end|>\n<|im_start|>assistant\n"
+    }
     public static let mistral7B4bit = ModelConfiguration(
         id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx",