Add Llama 3.1 (#98)
* Update Mistral 7B config * Add Mistral NeMo * Update for Llama 3.1 * Align LlamaConfiguration with Python implementation * Fix model configuration names * Refine DynamicNTKScalingRoPE * compute base only once --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -7,6 +7,86 @@ import MLXNN
|
||||
|
||||
// port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/llama.py
|
||||
|
||||
func computeBaseFrequency(
|
||||
base: Float, dims: Int, ropeType: String, ropeScaling: [String: StringOrNumber]?
|
||||
)
|
||||
-> Float
|
||||
{
|
||||
if ropeType != "llama3" {
|
||||
return base
|
||||
}
|
||||
|
||||
guard let ropeScaling = ropeScaling else {
|
||||
return base
|
||||
}
|
||||
|
||||
guard case .float(let factor) = ropeScaling["factor"],
|
||||
case .float(let lowFreqFactor) = ropeScaling["low_freq_factor"] ?? .float(1.0),
|
||||
case .float(let highFreqFactor) = ropeScaling["high_freq_factor"] ?? .float(4.0),
|
||||
case .float(let oldContextLen) = ropeScaling["original_max_position_embeddings"]
|
||||
?? .float(8192)
|
||||
else {
|
||||
return base
|
||||
}
|
||||
|
||||
let lowFreqWavelen = oldContextLen / lowFreqFactor
|
||||
let highFreqWavelen = oldContextLen / highFreqFactor
|
||||
|
||||
let freqs = (0 ..< dims).compactMap { index -> Float? in
|
||||
if index % 2 == 0 {
|
||||
return pow(base, Float(index) / Float(dims))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
let newBaseFreqs = freqs.map { freq -> Float in
|
||||
let wavelen = 2 * .pi / freq
|
||||
let smooth = max(
|
||||
0, min(1, (wavelen - highFreqWavelen) / (lowFreqWavelen - highFreqWavelen)))
|
||||
return freq * ((1 - smooth) * factor + smooth)
|
||||
}
|
||||
|
||||
return newBaseFreqs.reduce(0, +) / Float(newBaseFreqs.count)
|
||||
}
|
||||
|
||||
private class DynamicNTKScalingRoPE: Module {
|
||||
let dims: Int
|
||||
let maxPositionEmbeddings: Int?
|
||||
let traditional: Bool
|
||||
let base: Float
|
||||
var scale: Float
|
||||
let ropeType: String
|
||||
let ropeScaling: [String: StringOrNumber]?
|
||||
|
||||
init(
|
||||
dims: Int, maxPositionEmbeddings: Int?, traditional: Bool = false,
|
||||
base: Float = 10000, scale: Float = 1.0, ropeType: String = "default",
|
||||
ropeScaling: [String: StringOrNumber]? = nil
|
||||
) {
|
||||
self.dims = dims
|
||||
self.maxPositionEmbeddings = maxPositionEmbeddings
|
||||
self.traditional = traditional
|
||||
self.base = computeBaseFrequency(
|
||||
base: base, dims: dims, ropeType: ropeType, ropeScaling: ropeScaling)
|
||||
self.scale = scale
|
||||
self.ropeType = ropeType
|
||||
self.ropeScaling = ropeScaling
|
||||
}
|
||||
|
||||
func callAsFunction(_ x: MLXArray, offset: Int = 0) -> MLXArray {
|
||||
let seqLen = x.dim(1) + offset
|
||||
var base = self.base
|
||||
if let maxPositionEmbeddings, seqLen > maxPositionEmbeddings {
|
||||
let factorAdjustment = Float(seqLen) / Float(maxPositionEmbeddings) - 1
|
||||
let dimensionRatio = Float(dims) / Float(Float(dims) - 2)
|
||||
let adjustedScale = scale * pow(1 + factorAdjustment, dimensionRatio)
|
||||
base *= adjustedScale
|
||||
}
|
||||
return MLXFast.RoPE(
|
||||
x, dimensions: dims, traditional: traditional, base: base, scale: scale, offset: offset)
|
||||
}
|
||||
}
|
||||
|
||||
private class Attention: Module {
|
||||
|
||||
let args: LlamaConfiguration
|
||||
@@ -17,9 +97,9 @@ private class Attention: Module {
|
||||
@ModuleInfo(key: "v_proj") var wv: Linear
|
||||
@ModuleInfo(key: "o_proj") var wo: Linear
|
||||
|
||||
let rope: RoPE
|
||||
let rope: DynamicNTKScalingRoPE
|
||||
|
||||
public init(_ args: LlamaConfiguration) {
|
||||
init(_ args: LlamaConfiguration) {
|
||||
self.args = args
|
||||
|
||||
let dim = args.hiddenSize
|
||||
@@ -29,31 +109,28 @@ private class Attention: Module {
|
||||
let headDim = args.headDimensions ?? (args.hiddenSize / heads)
|
||||
self.scale = pow(Float(headDim), -0.5)
|
||||
|
||||
self._wq.wrappedValue = Linear(dim, heads * headDim, bias: false)
|
||||
self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
|
||||
self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
|
||||
self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false)
|
||||
self._wq.wrappedValue = Linear(dim, heads * headDim, bias: args.attentionBias)
|
||||
self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: args.attentionBias)
|
||||
self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: args.attentionBias)
|
||||
self._wo.wrappedValue = Linear(heads * headDim, dim, bias: args.attentionBias)
|
||||
|
||||
let ropeScale: Float
|
||||
if let ropeScaling = args.ropeScaling, ropeScaling["type"] == .string("linear"),
|
||||
let factor = ropeScaling["factor"]
|
||||
{
|
||||
switch factor {
|
||||
case .string:
|
||||
fatalError("ropeScaling.factor must be a float")
|
||||
case .float(let v):
|
||||
ropeScale = 1 / v
|
||||
}
|
||||
} else {
|
||||
ropeScale = 1
|
||||
}
|
||||
|
||||
self.rope = RoPE(
|
||||
dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta,
|
||||
scale: ropeScale)
|
||||
self.rope = DynamicNTKScalingRoPE(
|
||||
dims: headDim,
|
||||
maxPositionEmbeddings: args.maxPositionEmbeddings,
|
||||
traditional: args.ropeTraditional,
|
||||
base: args.ropeTheta,
|
||||
scale: 1.0,
|
||||
ropeType: {
|
||||
if case .string(let value) = args.ropeScaling?["type"] {
|
||||
return value
|
||||
} else {
|
||||
return "default"
|
||||
}
|
||||
}(),
|
||||
ropeScaling: args.ropeScaling)
|
||||
}
|
||||
|
||||
public func callAsFunction(
|
||||
func callAsFunction(
|
||||
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
|
||||
) -> (MLXArray, (MLXArray, MLXArray)) {
|
||||
let (B, L) = (x.dim(0), x.dim(1))
|
||||
@@ -62,7 +139,7 @@ private class Attention: Module {
|
||||
var keys = wk(x)
|
||||
var values = wv(x)
|
||||
|
||||
// prepare the queries, keys and values for the attention computation
|
||||
// Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3)
|
||||
keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
|
||||
values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
|
||||
@@ -93,35 +170,35 @@ private class MLP: Module, UnaryLayer {
|
||||
@ModuleInfo(key: "down_proj") var down: Linear
|
||||
@ModuleInfo(key: "up_proj") var up: Linear
|
||||
|
||||
public init(dimensions: Int, hiddenDimensions: Int) {
|
||||
self._gate.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
|
||||
self._down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false)
|
||||
self._up.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
|
||||
init(_ args: LlamaConfiguration) {
|
||||
self._gate.wrappedValue = Linear(args.hiddenSize, args.intermediateSize, bias: args.mlpBias)
|
||||
self._down.wrappedValue = Linear(args.intermediateSize, args.hiddenSize, bias: args.mlpBias)
|
||||
self._up.wrappedValue = Linear(args.hiddenSize, args.intermediateSize, bias: args.mlpBias)
|
||||
}
|
||||
|
||||
public func callAsFunction(_ x: MLXArray) -> MLXArray {
|
||||
down(silu(gate(x)) * up(x))
|
||||
func callAsFunction(_ x: MLXArray) -> MLXArray {
|
||||
let activation = silu(gate(x))
|
||||
return down(activation * up(x))
|
||||
}
|
||||
}
|
||||
|
||||
private class TransformerBlock: Module {
|
||||
|
||||
@ModuleInfo(key: "self_attn") var attention: Attention
|
||||
let mlp: MLP
|
||||
@ModuleInfo(key: "mlp") var mlp: MLP
|
||||
|
||||
@ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
|
||||
@ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
|
||||
|
||||
public init(_ args: LlamaConfiguration) {
|
||||
init(_ args: LlamaConfiguration) {
|
||||
self._attention.wrappedValue = Attention(args)
|
||||
self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
|
||||
self._mlp.wrappedValue = MLP(args)
|
||||
self._inputLayerNorm.wrappedValue = RMSNorm(
|
||||
dimensions: args.hiddenSize, eps: args.rmsNormEps)
|
||||
self._postAttentionLayerNorm.wrappedValue = RMSNorm(
|
||||
dimensions: args.hiddenSize, eps: args.rmsNormEps)
|
||||
}
|
||||
|
||||
public func callAsFunction(
|
||||
func callAsFunction(
|
||||
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
|
||||
) -> (MLXArray, (MLXArray, MLXArray)) {
|
||||
var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
|
||||
@@ -132,27 +209,24 @@ private class TransformerBlock: Module {
|
||||
}
|
||||
}
|
||||
|
||||
public class LlamaModelInner: Module {
|
||||
private class LlamaModelInner: Module {
|
||||
|
||||
@ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
|
||||
|
||||
fileprivate let layers: [TransformerBlock]
|
||||
let layers: [TransformerBlock]
|
||||
let norm: RMSNorm
|
||||
|
||||
public init(_ args: LlamaConfiguration) {
|
||||
init(_ args: LlamaConfiguration) {
|
||||
precondition(args.vocabularySize > 0)
|
||||
|
||||
self._embedTokens.wrappedValue = Embedding(
|
||||
embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
|
||||
|
||||
self.layers = (0 ..< args.hiddenLayers)
|
||||
.map { _ in
|
||||
TransformerBlock(args)
|
||||
}
|
||||
self.layers = (0 ..< args.hiddenLayers).map { _ in TransformerBlock(args) }
|
||||
self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
|
||||
}
|
||||
|
||||
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
|
||||
func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
|
||||
MLXArray, [(MLXArray, MLXArray)]
|
||||
) {
|
||||
var h = embedTokens(inputs)
|
||||
@@ -178,7 +252,7 @@ public class LlamaModelInner: Module {
|
||||
public class LlamaModel: Module, LLMModel {
|
||||
|
||||
public let vocabularySize: Int
|
||||
let model: LlamaModelInner
|
||||
fileprivate let model: LlamaModelInner
|
||||
|
||||
@ModuleInfo(key: "lm_head") var lmHead: Linear?
|
||||
|
||||
@@ -202,7 +276,7 @@ public class LlamaModel: Module, LLMModel {
|
||||
}
|
||||
|
||||
public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
|
||||
// Remove unused precomputed rotary freqs
|
||||
// Remove unused precomputed rotary frequencies
|
||||
weights.filter {
|
||||
!$0.key.contains("self_attn.rotary_emb.inv_freq")
|
||||
}
|
||||
@@ -215,14 +289,17 @@ public struct LlamaConfiguration: Codable {
|
||||
var hiddenLayers: Int
|
||||
var intermediateSize: Int
|
||||
var attentionHeads: Int
|
||||
var headDimensions: Int? = nil
|
||||
var headDimensions: Int?
|
||||
var rmsNormEps: Float
|
||||
var vocabularySize: Int
|
||||
var kvHeads: Int
|
||||
var maxPositionEmbeddings: Int?
|
||||
var ropeTheta: Float = 10_000
|
||||
var ropeTraditional: Bool = false
|
||||
var ropeScaling: [String: StringOrNumber]? = nil
|
||||
var tieWordEmbeddings: Bool = false
|
||||
var ropeScaling: [String: StringOrNumber]?
|
||||
var tieWordEmbeddings: Bool = true
|
||||
var attentionBias: Bool = false
|
||||
var mlpBias: Bool = false
|
||||
|
||||
enum CodingKeys: String, CodingKey {
|
||||
case hiddenSize = "hidden_size"
|
||||
@@ -233,45 +310,75 @@ public struct LlamaConfiguration: Codable {
|
||||
case rmsNormEps = "rms_norm_eps"
|
||||
case vocabularySize = "vocab_size"
|
||||
case kvHeads = "num_key_value_heads"
|
||||
case maxPositionEmbeddings = "max_position_embeddings"
|
||||
case ropeTheta = "rope_theta"
|
||||
case ropeTraditional = "rope_traditional"
|
||||
case ropeScaling = "rope_scaling"
|
||||
case tieWordEmbeddings = "tie_word_embeddings"
|
||||
case attentionBias = "attention_bias"
|
||||
case mlpBias = "mlp_bias"
|
||||
}
|
||||
|
||||
public init(from decoder: Decoder) throws {
|
||||
// custom implementation to handle optional keys with required values
|
||||
let container: KeyedDecodingContainer<LlamaConfiguration.CodingKeys> =
|
||||
try decoder.container(
|
||||
keyedBy: LlamaConfiguration.CodingKeys.self)
|
||||
let container = try decoder.container(keyedBy: CodingKeys.self)
|
||||
|
||||
self.hiddenSize = try container.decode(
|
||||
Int.self, forKey: LlamaConfiguration.CodingKeys.hiddenSize)
|
||||
self.hiddenLayers = try container.decode(
|
||||
Int.self, forKey: LlamaConfiguration.CodingKeys.hiddenLayers)
|
||||
self.intermediateSize = try container.decode(
|
||||
Int.self, forKey: LlamaConfiguration.CodingKeys.intermediateSize)
|
||||
self.attentionHeads = try container.decode(
|
||||
Int.self, forKey: LlamaConfiguration.CodingKeys.attentionHeads)
|
||||
self.headDimensions = try container.decodeIfPresent(
|
||||
Int.self, forKey: LlamaConfiguration.CodingKeys.headDimensions)
|
||||
self.rmsNormEps = try container.decode(
|
||||
Float.self, forKey: LlamaConfiguration.CodingKeys.rmsNormEps)
|
||||
self.vocabularySize = try container.decode(
|
||||
Int.self, forKey: LlamaConfiguration.CodingKeys.vocabularySize)
|
||||
self.kvHeads = try container.decode(Int.self, forKey: LlamaConfiguration.CodingKeys.kvHeads)
|
||||
self.ropeTheta =
|
||||
try container.decodeIfPresent(
|
||||
Float.self, forKey: LlamaConfiguration.CodingKeys.ropeTheta)
|
||||
?? 10_000
|
||||
self.ropeTraditional =
|
||||
try container.decodeIfPresent(
|
||||
Bool.self, forKey: LlamaConfiguration.CodingKeys.ropeTraditional) ?? false
|
||||
self.ropeScaling = try container.decodeIfPresent(
|
||||
[String: StringOrNumber].self, forKey: LlamaConfiguration.CodingKeys.ropeScaling)
|
||||
self.tieWordEmbeddings =
|
||||
try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false
|
||||
hiddenSize = try container.decode(Int.self, forKey: .hiddenSize)
|
||||
hiddenLayers = try container.decode(Int.self, forKey: .hiddenLayers)
|
||||
intermediateSize = try container.decode(Int.self, forKey: .intermediateSize)
|
||||
attentionHeads = try container.decode(Int.self, forKey: .attentionHeads)
|
||||
headDimensions = try container.decodeIfPresent(Int.self, forKey: .headDimensions)
|
||||
rmsNormEps = try container.decode(Float.self, forKey: .rmsNormEps)
|
||||
vocabularySize = try container.decode(Int.self, forKey: .vocabularySize)
|
||||
kvHeads = try container.decodeIfPresent(Int.self, forKey: .kvHeads) ?? attentionHeads
|
||||
maxPositionEmbeddings = try container.decodeIfPresent(
|
||||
Int.self, forKey: .maxPositionEmbeddings)
|
||||
if let ropeTheta = try container.decodeIfPresent(Float.self, forKey: .ropeTheta) {
|
||||
self.ropeTheta = ropeTheta
|
||||
}
|
||||
if let ropeTraditional = try container.decodeIfPresent(Bool.self, forKey: .ropeTraditional)
|
||||
{
|
||||
self.ropeTraditional = ropeTraditional
|
||||
}
|
||||
ropeScaling = try container.decodeIfPresent(
|
||||
[String: StringOrNumber].self, forKey: .ropeScaling)
|
||||
if let tieWordEmbeddings = try container.decodeIfPresent(
|
||||
Bool.self, forKey: .tieWordEmbeddings)
|
||||
{
|
||||
self.tieWordEmbeddings = tieWordEmbeddings
|
||||
}
|
||||
if let attentionBias = try container.decodeIfPresent(Bool.self, forKey: .attentionBias) {
|
||||
self.attentionBias = attentionBias
|
||||
}
|
||||
if let mlpBias = try container.decodeIfPresent(Bool.self, forKey: .mlpBias) {
|
||||
self.mlpBias = mlpBias
|
||||
}
|
||||
|
||||
if let ropeScaling {
|
||||
if ropeScaling["factor"] == nil {
|
||||
throw DecodingError.dataCorruptedError(
|
||||
forKey: .ropeScaling, in: container,
|
||||
debugDescription: "rope_scaling must contain 'factor'")
|
||||
}
|
||||
if let ropeType = ropeScaling["type"] ?? ropeScaling["rope_type"] {
|
||||
if case .string = ropeType {
|
||||
let options = [
|
||||
StringOrNumber.string("linear"), StringOrNumber.string("dynamic"),
|
||||
StringOrNumber.string("llama3"),
|
||||
]
|
||||
if !options.contains(ropeType) {
|
||||
throw DecodingError.dataCorruptedError(
|
||||
forKey: .ropeScaling, in: container,
|
||||
debugDescription:
|
||||
"rope_scaling 'type' currently only supports 'linear', 'dynamic', or 'llama3'"
|
||||
)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
throw DecodingError.dataCorruptedError(
|
||||
forKey: .ropeScaling, in: container,
|
||||
debugDescription: "rope_scaling must contain either 'type' or 'rope_type'")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -151,7 +151,7 @@ extension ModelConfiguration {
|
||||
defaultPrompt: "Why is the sky blue?"
|
||||
)
|
||||
|
||||
public static let phi34bit = ModelConfiguration(
|
||||
public static let phi3_4bit = ModelConfiguration(
|
||||
id: "mlx-community/Phi-3-mini-4k-instruct-4bit-no-q-embed",
|
||||
defaultPrompt: "what is the gravity on mars and the moon?",
|
||||
extraEOSTokens: ["<|end|>"]
|
||||
@@ -199,9 +199,17 @@ extension ModelConfiguration {
|
||||
"\(prompt)"
|
||||
}
|
||||
|
||||
public static let llama38B4bit = ModelConfiguration(
|
||||
public static let llama3_1_8B_4bit = ModelConfiguration(
|
||||
id: "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
|
||||
defaultPrompt: "What is the difference between a fruit and a vegetable?"
|
||||
) {
|
||||
prompt in
|
||||
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\(prompt)<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"
|
||||
}
|
||||
|
||||
public static let llama3_8B_4bit = ModelConfiguration(
|
||||
id: "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
|
||||
defaultPrompt: "what is the difference between a fruit and a vegetable?"
|
||||
defaultPrompt: "What is the difference between a fruit and a vegetable?"
|
||||
) {
|
||||
prompt in
|
||||
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\(prompt)<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"
|
||||
@@ -220,12 +228,13 @@ extension ModelConfiguration {
|
||||
case .idle:
|
||||
bootstrapState = .bootstrapping
|
||||
register(configurations: [
|
||||
llama3_1_8B_4bit,
|
||||
mistralNeMo4bit,
|
||||
smolLM_135M_4bit,
|
||||
mistral7B4bit,
|
||||
codeLlama13b4bit,
|
||||
phi4bit,
|
||||
phi34bit,
|
||||
phi3_4bit,
|
||||
gemma2bQuantized,
|
||||
gemma_2_9b_it_4bit,
|
||||
qwen205b4bit,
|
||||
|
||||
Reference in New Issue
Block a user