Add SmolLM (#95)
This commit is contained in:
committed by
GitHub
parent
2a2931ba8d
commit
a2e8d7e469
@@ -180,19 +180,25 @@ public class LlamaModel: Module, LLMModel {
|
||||
public let vocabularySize: Int
|
||||
let model: LlamaModelInner
|
||||
|
||||
@ModuleInfo(key: "lm_head") var lmHead: Linear
|
||||
@ModuleInfo(key: "lm_head") var lmHead: Linear?
|
||||
|
||||
/// Creates the model from `args`.
///
/// Records the vocabulary size, builds the inner `LlamaModelInner`, and
/// creates a bias-free `Linear` output head (`hiddenSize` -> `vocabularySize`)
/// only when `args.tieWordEmbeddings` is `false`.
///
/// - Parameter args: Configuration supplying `vocabularySize`, `hiddenSize`
///   and `tieWordEmbeddings`.
public init(_ args: LlamaConfiguration) {
    vocabularySize = args.vocabularySize
    model = LlamaModelInner(args)

    // A dedicated output head is created only when embeddings are not tied;
    // otherwise `lmHead` is left unassigned here.
    let needsSeparateHead = !args.tieWordEmbeddings
    if needsSeparateHead {
        let projection = Linear(args.hiddenSize, args.vocabularySize, bias: false)
        _lmHead.wrappedValue = projection
    }
}
|
||||
|
||||
/// Runs the inner model on `inputs` and projects its output.
///
/// The projection uses `lmHead` when one is present; otherwise it falls back
/// to `model.embedTokens.asLinear`.
///
/// - Parameters:
///   - inputs: Input array forwarded to the inner model.
///   - cache: Optional cache forwarded to the inner model.
/// - Returns: The projected output together with the cache returned by the
///   inner model.
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
    MLXArray, [(MLXArray, MLXArray)]
) {
    let (hidden, updatedCache) = model(inputs, cache: cache)

    // No separate head: project through the embedding layer instead.
    guard let lmHead else {
        return (model.embedTokens.asLinear(hidden), updatedCache)
    }
    return (lmHead(hidden), updatedCache)
}
|
||||
|
||||
public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
|
||||
@@ -215,6 +221,7 @@ public struct LlamaConfiguration: Codable {
|
||||
var ropeTheta: Float = 10_000
|
||||
var ropeTraditional: Bool = false
|
||||
var ropeScaling: [String: StringOrNumber]? = nil
|
||||
var tieWordEmbeddings: Bool = false
|
||||
|
||||
enum CodingKeys: String, CodingKey {
|
||||
case hiddenSize = "hidden_size"
|
||||
@@ -227,6 +234,7 @@ public struct LlamaConfiguration: Codable {
|
||||
case ropeTheta = "rope_theta"
|
||||
case ropeTraditional = "rope_traditional"
|
||||
case ropeScaling = "rope_scaling"
|
||||
case tieWordEmbeddings = "tie_word_embeddings"
|
||||
}
|
||||
|
||||
public init(from decoder: Decoder) throws {
|
||||
@@ -257,6 +265,8 @@ public struct LlamaConfiguration: Codable {
|
||||
Bool.self, forKey: LlamaConfiguration.CodingKeys.ropeTraditional) ?? false
|
||||
self.ropeScaling = try container.decodeIfPresent(
|
||||
[String: StringOrNumber].self, forKey: LlamaConfiguration.CodingKeys.ropeScaling)
|
||||
self.tieWordEmbeddings =
|
||||
try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -110,6 +110,13 @@ public struct ModelConfiguration {
|
||||
}
|
||||
|
||||
extension ModelConfiguration {
|
||||
/// Configuration for `mlx-community/SmolLM-135M-Instruct-4bit`.
///
/// The trailing closure wraps the prompt in `<|im_start|>` / `<|im_end|>`
/// user and assistant markers.
public static let smolLM_135M_4bit = ModelConfiguration(
    id: "mlx-community/SmolLM-135M-Instruct-4bit",
    defaultPrompt: "Tell me about the history of Spain."
) { prompt in
    "<|im_start|>user\n\(prompt)<|im_end|>\n<|im_start|>assistant\n"
}
|
||||
|
||||
public static let mistral7B4bit = ModelConfiguration(
|
||||
id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx",
|
||||
|
||||
Reference in New Issue
Block a user