committed by
GitHub
parent
a2e8d7e469
commit
c4fda0e036
@@ -26,7 +26,7 @@ private class Attention: Module {
|
|||||||
let heads = args.attentionHeads
|
let heads = args.attentionHeads
|
||||||
let kvHeads = args.kvHeads
|
let kvHeads = args.kvHeads
|
||||||
|
|
||||||
let headDim = args.hiddenSize / heads
|
let headDim = args.headDimensions ?? (args.hiddenSize / heads)
|
||||||
self.scale = pow(Float(headDim), -0.5)
|
self.scale = pow(Float(headDim), -0.5)
|
||||||
|
|
||||||
self._wq.wrappedValue = Linear(dim, heads * headDim, bias: false)
|
self._wq.wrappedValue = Linear(dim, heads * headDim, bias: false)
|
||||||
@@ -215,6 +215,7 @@ public struct LlamaConfiguration: Codable {
|
|||||||
var hiddenLayers: Int
|
var hiddenLayers: Int
|
||||||
var intermediateSize: Int
|
var intermediateSize: Int
|
||||||
var attentionHeads: Int
|
var attentionHeads: Int
|
||||||
|
var headDimensions: Int? = nil
|
||||||
var rmsNormEps: Float
|
var rmsNormEps: Float
|
||||||
var vocabularySize: Int
|
var vocabularySize: Int
|
||||||
var kvHeads: Int
|
var kvHeads: Int
|
||||||
@@ -228,6 +229,7 @@ public struct LlamaConfiguration: Codable {
|
|||||||
case hiddenLayers = "num_hidden_layers"
|
case hiddenLayers = "num_hidden_layers"
|
||||||
case intermediateSize = "intermediate_size"
|
case intermediateSize = "intermediate_size"
|
||||||
case attentionHeads = "num_attention_heads"
|
case attentionHeads = "num_attention_heads"
|
||||||
|
case headDimensions = "head_dim"
|
||||||
case rmsNormEps = "rms_norm_eps"
|
case rmsNormEps = "rms_norm_eps"
|
||||||
case vocabularySize = "vocab_size"
|
case vocabularySize = "vocab_size"
|
||||||
case kvHeads = "num_key_value_heads"
|
case kvHeads = "num_key_value_heads"
|
||||||
@@ -251,6 +253,8 @@ public struct LlamaConfiguration: Codable {
|
|||||||
Int.self, forKey: LlamaConfiguration.CodingKeys.intermediateSize)
|
Int.self, forKey: LlamaConfiguration.CodingKeys.intermediateSize)
|
||||||
self.attentionHeads = try container.decode(
|
self.attentionHeads = try container.decode(
|
||||||
Int.self, forKey: LlamaConfiguration.CodingKeys.attentionHeads)
|
Int.self, forKey: LlamaConfiguration.CodingKeys.attentionHeads)
|
||||||
|
self.headDimensions = try container.decodeIfPresent(
|
||||||
|
Int.self, forKey: LlamaConfiguration.CodingKeys.headDimensions)
|
||||||
self.rmsNormEps = try container.decode(
|
self.rmsNormEps = try container.decode(
|
||||||
Float.self, forKey: LlamaConfiguration.CodingKeys.rmsNormEps)
|
Float.self, forKey: LlamaConfiguration.CodingKeys.rmsNormEps)
|
||||||
self.vocabularySize = try container.decode(
|
self.vocabularySize = try container.decode(
|
||||||
|
|||||||
@@ -118,12 +118,19 @@ extension ModelConfiguration {
|
|||||||
"<|im_start|>user\n\(prompt)<|im_end|>\n<|im_start|>assistant\n"
|
"<|im_start|>user\n\(prompt)<|im_end|>\n<|im_start|>assistant\n"
|
||||||
}
|
}
|
||||||
|
|
||||||
public static let mistral7B4bit = ModelConfiguration(
|
public static let mistralNeMo4bit = ModelConfiguration(
|
||||||
id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx",
|
id: "mlx-community/Mistral-Nemo-Instruct-2407-4bit",
|
||||||
|
defaultPrompt: "Explain quaternions."
|
||||||
|
) { prompt in
|
||||||
|
"<s>[INST] \(prompt) [/INST] "
|
||||||
|
}
|
||||||
|
|
||||||
// https://www.promptingguide.ai/models/mistral-7b
|
public static let mistral7B4bit = ModelConfiguration(
|
||||||
defaultPrompt: "describe the swift language"
|
id: "mlx-community/Mistral-7B-Instruct-v0.3-4bit",
|
||||||
)
|
defaultPrompt: "Describe the Swift language."
|
||||||
|
) { prompt in
|
||||||
|
"<s>[INST] \(prompt) [/INST] "
|
||||||
|
}
|
||||||
|
|
||||||
public static let codeLlama13b4bit = ModelConfiguration(
|
public static let codeLlama13b4bit = ModelConfiguration(
|
||||||
id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX",
|
id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX",
|
||||||
@@ -213,6 +220,8 @@ extension ModelConfiguration {
|
|||||||
case .idle:
|
case .idle:
|
||||||
bootstrapState = .bootstrapping
|
bootstrapState = .bootstrapping
|
||||||
register(configurations: [
|
register(configurations: [
|
||||||
|
mistralNeMo4bit,
|
||||||
|
smolLM_135M_4bit,
|
||||||
mistral7B4bit,
|
mistral7B4bit,
|
||||||
codeLlama13b4bit,
|
codeLlama13b4bit,
|
||||||
phi4bit,
|
phi4bit,
|
||||||
|
|||||||
Reference in New Issue
Block a user