Add SmolLM (#95)
This commit is contained in:
committed by
GitHub
parent
2a2931ba8d
commit
a2e8d7e469
@@ -180,19 +180,25 @@ public class LlamaModel: Module, LLMModel {
|
|||||||
public let vocabularySize: Int
|
public let vocabularySize: Int
|
||||||
let model: LlamaModelInner
|
let model: LlamaModelInner
|
||||||
|
|
||||||
@ModuleInfo(key: "lm_head") var lmHead: Linear
|
@ModuleInfo(key: "lm_head") var lmHead: Linear?
|
||||||
|
|
||||||
public init(_ args: LlamaConfiguration) {
|
public init(_ args: LlamaConfiguration) {
|
||||||
self.vocabularySize = args.vocabularySize
|
self.vocabularySize = args.vocabularySize
|
||||||
self.model = LlamaModelInner(args)
|
self.model = LlamaModelInner(args)
|
||||||
self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
|
if !args.tieWordEmbeddings {
|
||||||
|
self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
|
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
|
||||||
MLXArray, [(MLXArray, MLXArray)]
|
MLXArray, [(MLXArray, MLXArray)]
|
||||||
) {
|
) {
|
||||||
let (out, cache) = model(inputs, cache: cache)
|
let (out, cache) = model(inputs, cache: cache)
|
||||||
return (lmHead(out), cache)
|
if let lmHead {
|
||||||
|
return (lmHead(out), cache)
|
||||||
|
} else {
|
||||||
|
return (model.embedTokens.asLinear(out), cache)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
|
public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
|
||||||
@@ -215,6 +221,7 @@ public struct LlamaConfiguration: Codable {
|
|||||||
var ropeTheta: Float = 10_000
|
var ropeTheta: Float = 10_000
|
||||||
var ropeTraditional: Bool = false
|
var ropeTraditional: Bool = false
|
||||||
var ropeScaling: [String: StringOrNumber]? = nil
|
var ropeScaling: [String: StringOrNumber]? = nil
|
||||||
|
var tieWordEmbeddings: Bool = false
|
||||||
|
|
||||||
enum CodingKeys: String, CodingKey {
|
enum CodingKeys: String, CodingKey {
|
||||||
case hiddenSize = "hidden_size"
|
case hiddenSize = "hidden_size"
|
||||||
@@ -227,6 +234,7 @@ public struct LlamaConfiguration: Codable {
|
|||||||
case ropeTheta = "rope_theta"
|
case ropeTheta = "rope_theta"
|
||||||
case ropeTraditional = "rope_traditional"
|
case ropeTraditional = "rope_traditional"
|
||||||
case ropeScaling = "rope_scaling"
|
case ropeScaling = "rope_scaling"
|
||||||
|
case tieWordEmbeddings = "tie_word_embeddings"
|
||||||
}
|
}
|
||||||
|
|
||||||
public init(from decoder: Decoder) throws {
|
public init(from decoder: Decoder) throws {
|
||||||
@@ -257,6 +265,8 @@ public struct LlamaConfiguration: Codable {
|
|||||||
Bool.self, forKey: LlamaConfiguration.CodingKeys.ropeTraditional) ?? false
|
Bool.self, forKey: LlamaConfiguration.CodingKeys.ropeTraditional) ?? false
|
||||||
self.ropeScaling = try container.decodeIfPresent(
|
self.ropeScaling = try container.decodeIfPresent(
|
||||||
[String: StringOrNumber].self, forKey: LlamaConfiguration.CodingKeys.ropeScaling)
|
[String: StringOrNumber].self, forKey: LlamaConfiguration.CodingKeys.ropeScaling)
|
||||||
|
self.tieWordEmbeddings =
|
||||||
|
try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -110,6 +110,13 @@ public struct ModelConfiguration {
|
|||||||
}
|
}
|
||||||
|
|
||||||
extension ModelConfiguration {
|
extension ModelConfiguration {
|
||||||
|
public static let smolLM_135M_4bit = ModelConfiguration(
|
||||||
|
id: "mlx-community/SmolLM-135M-Instruct-4bit",
|
||||||
|
defaultPrompt: "Tell me about the history of Spain."
|
||||||
|
) {
|
||||||
|
prompt in
|
||||||
|
"<|im_start|>user\n\(prompt)<|im_end|>\n<|im_start|>assistant\n"
|
||||||
|
}
|
||||||
|
|
||||||
public static let mistral7B4bit = ModelConfiguration(
|
public static let mistral7B4bit = ModelConfiguration(
|
||||||
id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx",
|
id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx",
|
||||||
|
|||||||
Reference in New Issue
Block a user