// Copyright © 2024 Apple Inc.

import Foundation
import MLX
import MLXFast
import MLXNN

// port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/llama.py

/// Computes the effective RoPE base frequency. For the "llama3" rope type the base is
/// rescaled per frequency band following the Llama 3 scaling scheme; otherwise the
/// configured base is returned unchanged.
func computeBaseFrequency(
    base: Float, dims: Int, ropeType: String, ropeScaling: [String: StringOrNumber]?
) -> Float {
    if ropeType != "llama3" {
        return base
    }

    guard let ropeScaling = ropeScaling else { return base }

    guard case .float(let factor) = ropeScaling["factor"],
        case .float(let lowFreqFactor) = ropeScaling["low_freq_factor"] ?? .float(1.0),
        case .float(let highFreqFactor) = ropeScaling["high_freq_factor"] ?? .float(4.0),
        case .float(let oldContextLen) = ropeScaling["original_max_position_embeddings"]
            ?? .float(8192)
    else {
        return base
    }

    let lowFreqWavelen = oldContextLen / lowFreqFactor
    let highFreqWavelen = oldContextLen / highFreqFactor

    // Frequencies for the even dimensions, i.e. base ** (arange(0, dims, 2) / dims)
    let freqs = (0 ..< dims).compactMap { index -> Float? in
        if index % 2 == 0 {
            return pow(base, Float(index) / Float(dims))
        }
        return nil
    }

    let newBaseFreqs = freqs.map { freq -> Float in
        let wavelen = 2 * Float.pi / freq
        let smooth = max(
            0, min(1, (wavelen - highFreqWavelen) / (lowFreqWavelen - highFreqWavelen)))
        return freq * ((1 - smooth) * factor + smooth)
    }

    return newBaseFreqs.reduce(0, +) / Float(newBaseFreqs.count)
}

private class DynamicNTKScalingRoPE: Module {

    let dims: Int
    let maxPositionEmbeddings: Int?
    let traditional: Bool
    let base: Float
    var scale: Float
    let ropeType: String
    let ropeScaling: [String: StringOrNumber]?

    init(
        dims: Int, maxPositionEmbeddings: Int?, traditional: Bool = false, base: Float = 10000,
        scale: Float = 1.0, ropeType: String = "default",
        ropeScaling: [String: StringOrNumber]? = nil
    ) {
        self.dims = dims
        self.maxPositionEmbeddings = maxPositionEmbeddings
        self.traditional = traditional
        self.base = computeBaseFrequency(
            base: base, dims: dims, ropeType: ropeType, ropeScaling: ropeScaling)
        self.scale = scale
        self.ropeType = ropeType
        self.ropeScaling = ropeScaling
    }

    func callAsFunction(_ x: MLXArray, offset: Int = 0) -> MLXArray {
        let seqLen = x.dim(1) + offset
        var base = self.base

        // Dynamic NTK scaling: grow the base when the sequence exceeds the trained context length
        if let maxPositionEmbeddings, seqLen > maxPositionEmbeddings {
            let factorAdjustment = Float(seqLen) / Float(maxPositionEmbeddings) - 1
            let dimensionRatio = Float(dims) / (Float(dims) - 2)
            let adjustedScale = scale * pow(1 + factorAdjustment, dimensionRatio)
            base *= adjustedScale
        }

        return MLXFast.RoPE(
            x, dimensions: dims, traditional: traditional, base: base, scale: scale,
            offset: offset)
    }
}

private class Attention: Module {

    let args: LlamaConfiguration
    let scale: Float

    @ModuleInfo(key: "q_proj") var wq: Linear
    @ModuleInfo(key: "k_proj") var wk: Linear
    @ModuleInfo(key: "v_proj") var wv: Linear
    @ModuleInfo(key: "o_proj") var wo: Linear

    let rope: DynamicNTKScalingRoPE

    init(_ args: LlamaConfiguration) {
        self.args = args

        let dim = args.hiddenSize
        let heads = args.attentionHeads
        let kvHeads = args.kvHeads
        let headDim = args.headDimensions ?? (args.hiddenSize / heads)
        self.scale = pow(Float(headDim), -0.5)

        self._wq.wrappedValue = Linear(dim, heads * headDim, bias: args.attentionBias)
        self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: args.attentionBias)
        self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: args.attentionBias)
        self._wo.wrappedValue = Linear(heads * headDim, dim, bias: args.attentionBias)

        self.rope = DynamicNTKScalingRoPE(
            dims: headDim, maxPositionEmbeddings: args.maxPositionEmbeddings,
            traditional: args.ropeTraditional, base: args.ropeTheta, scale: 1.0,
            ropeType: {
                if case .string(let value) = args.ropeScaling?["type"] {
                    return value
                } else {
                    return "default"
                }
            }(),
            ropeScaling: args.ropeScaling)
    }

    func callAsFunction(
        _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
    ) -> (MLXArray, (MLXArray, MLXArray)) {
        let (B, L) = (x.dim(0), x.dim(1))

        var queries = wq(x)
        var keys = wk(x)
        var values = wv(x)

        // Prepare the queries, keys and values for the attention computation
        queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3)
        keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
        values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)

        if let (keyCache, valueCache) = cache {
            // Offset RoPE by the number of cached positions, then extend the cache
            queries = rope(queries, offset: keyCache.dim(2))
            keys = rope(keys, offset: keyCache.dim(2))
            keys = concatenated([keyCache, keys], axis: 2)
            values = concatenated([valueCache, values], axis: 2)
        } else {
            queries = rope(queries)
            keys = rope(keys)
        }

        let output = MLXFast.scaledDotProductAttention(
            queries: queries, keys: keys, values: values, scale: scale, mask: mask
        )
        .transposed(0, 2, 1, 3)
        .reshaped(B, L, -1)

        return (wo(output), (keys, values))
    }
}

private class MLP: Module, UnaryLayer {

    @ModuleInfo(key: "gate_proj") var gate: Linear
    @ModuleInfo(key: "down_proj") var down: Linear
    @ModuleInfo(key: "up_proj") var up: Linear

    init(_ args: LlamaConfiguration) {
        self._gate.wrappedValue = Linear(args.hiddenSize, args.intermediateSize, bias: args.mlpBias)
        self._down.wrappedValue = Linear(args.intermediateSize, args.hiddenSize, bias: args.mlpBias)
        self._up.wrappedValue = Linear(args.hiddenSize, args.intermediateSize, bias: args.mlpBias)
    }

    func callAsFunction(_ x: MLXArray) -> MLXArray {
        // SwiGLU: silu(gate(x)) * up(x), projected back to the hidden size
        let activation = silu(gate(x))
        return down(activation * up(x))
    }
}

private class TransformerBlock: Module {

    @ModuleInfo(key: "self_attn") var attention: Attention
    @ModuleInfo(key: "mlp") var mlp: MLP
    @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm

    init(_ args: LlamaConfiguration) {
        self._attention.wrappedValue = Attention(args)
        self._mlp.wrappedValue = MLP(args)
        self._inputLayerNorm.wrappedValue = RMSNorm(
            dimensions: args.hiddenSize, eps: args.rmsNormEps)
        self._postAttentionLayerNorm.wrappedValue = RMSNorm(
            dimensions: args.hiddenSize, eps: args.rmsNormEps)
    }

    func callAsFunction(
        _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
    ) -> (MLXArray, (MLXArray, MLXArray)) {
        var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
        let h = x + r
        r = mlp(postAttentionLayerNorm(h))
        let out = h + r
        return (out, cache)
    }
}

private class LlamaModelInner: Module {

    @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding

    let layers: [TransformerBlock]
    let norm: RMSNorm

    init(_ args: LlamaConfiguration) {
        precondition(args.vocabularySize > 0)

        self._embedTokens.wrappedValue = Embedding(
            embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
        self.layers = (0 ..< args.hiddenLayers).map { _ in TransformerBlock(args) }
        self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
    }

    func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
        MLXArray, [(MLXArray, MLXArray)]
    ) {
        var h = embedTokens(inputs)

        var mask: MLXArray? = nil
        if h.dim(1) > 1 {
            mask = MultiHeadAttention.createAdditiveCausalMask(h.dim(1))
            mask = mask?.asType(h.dtype)
        }

        var newCache = [(MLXArray, MLXArray)]()

        for (i, layer) in layers.enumerated() {
            var cacheUpdate: (MLXArray, MLXArray)
            (h, cacheUpdate) = layer(h, mask: mask, cache: cache?[i])
            newCache.append(cacheUpdate)
        }

        return (norm(h), newCache)
    }
}

public class LlamaModel: Module, LLMModel {

    public let vocabularySize: Int

    fileprivate let model: LlamaModelInner

    @ModuleInfo(key: "lm_head") var lmHead: Linear?

    public init(_ args: LlamaConfiguration) {
        self.vocabularySize = args.vocabularySize
        self.model = LlamaModelInner(args)

        if !args.tieWordEmbeddings {
            self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
        }
    }

    public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
        MLXArray, [(MLXArray, MLXArray)]
    ) {
        let (out, cache) = model(inputs, cache: cache)
        if let lmHead {
            return (lmHead(out), cache)
        } else {
            // Tied embeddings: reuse the token embedding matrix as the output projection
            return (model.embedTokens.asLinear(out), cache)
        }
    }

    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
        // Remove unused precomputed rotary frequencies
        weights.filter { !$0.key.contains("self_attn.rotary_emb.inv_freq") }
    }
}

public struct LlamaConfiguration: Codable {

    var hiddenSize: Int
    var hiddenLayers: Int
    var intermediateSize: Int
    var attentionHeads: Int
    var headDimensions: Int?
    var rmsNormEps: Float
    var vocabularySize: Int
    var kvHeads: Int
    var maxPositionEmbeddings: Int?
    var ropeTheta: Float = 10_000
    var ropeTraditional: Bool = false
    var ropeScaling: [String: StringOrNumber]?
    var tieWordEmbeddings: Bool = true
    var attentionBias: Bool = false
    var mlpBias: Bool = false

    enum CodingKeys: String, CodingKey {
        case hiddenSize = "hidden_size"
        case hiddenLayers = "num_hidden_layers"
        case intermediateSize = "intermediate_size"
        case attentionHeads = "num_attention_heads"
        case headDimensions = "head_dim"
        case rmsNormEps = "rms_norm_eps"
        case vocabularySize = "vocab_size"
        case kvHeads = "num_key_value_heads"
        case maxPositionEmbeddings = "max_position_embeddings"
        case ropeTheta = "rope_theta"
        case ropeTraditional = "rope_traditional"
        case ropeScaling = "rope_scaling"
        case tieWordEmbeddings = "tie_word_embeddings"
        case attentionBias = "attention_bias"
        case mlpBias = "mlp_bias"
    }

    public init(from decoder: Decoder) throws {
        let container = try decoder.container(keyedBy: CodingKeys.self)

        hiddenSize = try container.decode(Int.self, forKey: .hiddenSize)
        hiddenLayers = try container.decode(Int.self, forKey: .hiddenLayers)
        intermediateSize = try container.decode(Int.self, forKey: .intermediateSize)
        attentionHeads = try container.decode(Int.self, forKey: .attentionHeads)
        headDimensions = try container.decodeIfPresent(Int.self, forKey: .headDimensions)
        rmsNormEps = try container.decode(Float.self, forKey: .rmsNormEps)
        vocabularySize = try container.decode(Int.self, forKey: .vocabularySize)
        kvHeads = try container.decodeIfPresent(Int.self, forKey: .kvHeads) ?? attentionHeads
        maxPositionEmbeddings = try container.decodeIfPresent(
            Int.self, forKey: .maxPositionEmbeddings)

        if let ropeTheta = try container.decodeIfPresent(Float.self, forKey: .ropeTheta) {
            self.ropeTheta = ropeTheta
        }
        if let ropeTraditional = try container.decodeIfPresent(Bool.self, forKey: .ropeTraditional)
        {
            self.ropeTraditional = ropeTraditional
        }
        ropeScaling = try container.decodeIfPresent(
            [String: StringOrNumber].self, forKey: .ropeScaling)
        if let tieWordEmbeddings = try container.decodeIfPresent(
            Bool.self, forKey: .tieWordEmbeddings)
        {
            self.tieWordEmbeddings = tieWordEmbeddings
        }
        if let attentionBias = try container.decodeIfPresent(Bool.self, forKey: .attentionBias) {
            self.attentionBias = attentionBias
        }
        if let mlpBias = try container.decodeIfPresent(Bool.self, forKey: .mlpBias) {
            self.mlpBias = mlpBias
        }

        if let ropeScaling {
            if ropeScaling["factor"] == nil {
                throw DecodingError.dataCorruptedError(
                    forKey: .ropeScaling, in: container,
                    debugDescription: "rope_scaling must contain 'factor'")
            }
            if let ropeType = ropeScaling["type"] ?? ropeScaling["rope_type"] {
                if case .string = ropeType {
                    let options = [
                        StringOrNumber.string("linear"), StringOrNumber.string("dynamic"),
                        StringOrNumber.string("llama3"),
                    ]
                    if !options.contains(ropeType) {
                        throw DecodingError.dataCorruptedError(
                            forKey: .ropeScaling, in: container,
                            debugDescription:
                                "rope_scaling 'type' currently only supports 'linear', 'dynamic', or 'llama3'"
                        )
                    }
                }
            } else {
                throw DecodingError.dataCorruptedError(
                    forKey: .ropeScaling, in: container,
                    debugDescription: "rope_scaling must contain either 'type' or 'rope_type'")
            }
        }
    }
}

// MARK: - LoRA

extension LlamaModel: LoRAModel {
    public func loraLinearLayers() -> LoRALinearLayers {
        model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
    }
}
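
// MARK: - Usage sketch

// A minimal sketch of how the pieces above fit together: decode a Llama-style
// `config.json` into `LlamaConfiguration`, build the model, and run one forward
// pass. This is illustrative only, not part of the original port: `configData`
// and `tokens` are hypothetical caller-supplied values, and loading the actual
// weights is assumed to happen elsewhere.
private func exampleForwardPass(configData: Data, tokens: MLXArray) throws -> MLXArray {
    let configuration = try JSONDecoder().decode(LlamaConfiguration.self, from: configData)
    let model = LlamaModel(configuration)

    // First call with no cache: the model returns logits plus one (keys, values)
    // pair per transformer layer.
    let (logits, cache) = model(tokens, cache: nil)

    // On subsequent calls the cache would be passed back in so previously seen
    // positions are not recomputed; that generation loop is omitted here.
    _ = cache
    return logits
}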