diff --git a/Applications/LLMEval/ContentView.swift b/Applications/LLMEval/ContentView.swift
index 52821d5..b050d99 100644
--- a/Applications/LLMEval/ContentView.swift
+++ b/Applications/LLMEval/ContentView.swift
@@ -159,7 +159,7 @@ class LLMEvaluator {
 
     /// this controls which model loads -- phi4bit is one of the smaller ones so this will fit on
     /// more devices
-    let modelConfiguration = ModelConfiguration.phi34bit
+    let modelConfiguration = ModelConfiguration.phi3_4bit
 
     /// parameters controlling the output
     let generateParameters = GenerateParameters(temperature: 0.6)
diff --git a/Libraries/LLM/Llama.swift b/Libraries/LLM/Llama.swift
index 80b3387..c948b60 100644
--- a/Libraries/LLM/Llama.swift
+++ b/Libraries/LLM/Llama.swift
@@ -7,6 +7,86 @@ import MLXNN
 
 // port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/llama.py
 
+func computeBaseFrequency(
+    base: Float, dims: Int, ropeType: String, ropeScaling: [String: StringOrNumber]?
+)
+    -> Float
+{
+    if ropeType != "llama3" {
+        return base
+    }
+
+    guard let ropeScaling = ropeScaling else {
+        return base
+    }
+
+    guard case .float(let factor) = ropeScaling["factor"],
+        case .float(let lowFreqFactor) = ropeScaling["low_freq_factor"] ?? .float(1.0),
+        case .float(let highFreqFactor) = ropeScaling["high_freq_factor"] ?? .float(4.0),
+        case .float(let oldContextLen) = ropeScaling["original_max_position_embeddings"]
+            ?? .float(8192)
+    else {
+        return base
+    }
+
+    let lowFreqWavelen = oldContextLen / lowFreqFactor
+    let highFreqWavelen = oldContextLen / highFreqFactor
+
+    let freqs = (0 ..< dims).compactMap { index -> Float? in
+        if index % 2 == 0 {
+            return pow(base, Float(index) / Float(dims))
+        }
+        return nil
+    }
+
+    let newBaseFreqs = freqs.map { freq -> Float in
+        let wavelen = 2 * .pi / freq
+        let smooth = max(
+            0, min(1, (wavelen - highFreqWavelen) / (lowFreqWavelen - highFreqWavelen)))
+        return freq * ((1 - smooth) * factor + smooth)
+    }
+
+    return newBaseFreqs.reduce(0, +) / Float(newBaseFreqs.count)
+}
+
+private class DynamicNTKScalingRoPE: Module {
+    let dims: Int
+    let maxPositionEmbeddings: Int?
+    let traditional: Bool
+    let base: Float
+    var scale: Float
+    let ropeType: String
+    let ropeScaling: [String: StringOrNumber]?
+
+    init(
+        dims: Int, maxPositionEmbeddings: Int?, traditional: Bool = false,
+        base: Float = 10000, scale: Float = 1.0, ropeType: String = "default",
+        ropeScaling: [String: StringOrNumber]? = nil
+    ) {
+        self.dims = dims
+        self.maxPositionEmbeddings = maxPositionEmbeddings
+        self.traditional = traditional
+        self.base = computeBaseFrequency(
+            base: base, dims: dims, ropeType: ropeType, ropeScaling: ropeScaling)
+        self.scale = scale
+        self.ropeType = ropeType
+        self.ropeScaling = ropeScaling
+    }
+
+    func callAsFunction(_ x: MLXArray, offset: Int = 0) -> MLXArray {
+        let seqLen = x.dim(1) + offset
+        var base = self.base
+        if let maxPositionEmbeddings, seqLen > maxPositionEmbeddings {
+            let factorAdjustment = Float(seqLen) / Float(maxPositionEmbeddings) - 1
+            let dimensionRatio = Float(dims) / Float(Float(dims) - 2)
+            let adjustedScale = scale * pow(1 + factorAdjustment, dimensionRatio)
+            base *= adjustedScale
+        }
+        return MLXFast.RoPE(
+            x, dimensions: dims, traditional: traditional, base: base, scale: scale, offset: offset)
+    }
+}
+
 private class Attention: Module {
 
     let args: LlamaConfiguration
@@ -17,9 +97,9 @@ private class Attention: Module {
     @ModuleInfo(key: "v_proj") var wv: Linear
     @ModuleInfo(key: "o_proj") var wo: Linear
 
-    let rope: RoPE
+    let rope: DynamicNTKScalingRoPE
 
-    public init(_ args: LlamaConfiguration) {
+    init(_ args: LlamaConfiguration) {
         self.args = args
 
         let dim = args.hiddenSize
@@ -29,31 +109,28 @@ private class Attention: Module {
         let headDim = args.headDimensions ?? (args.hiddenSize / heads)
         self.scale = pow(Float(headDim), -0.5)
 
-        self._wq.wrappedValue = Linear(dim, heads * headDim, bias: false)
-        self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
-        self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
-        self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false)
+        self._wq.wrappedValue = Linear(dim, heads * headDim, bias: args.attentionBias)
+        self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: args.attentionBias)
+        self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: args.attentionBias)
+        self._wo.wrappedValue = Linear(heads * headDim, dim, bias: args.attentionBias)
 
-        let ropeScale: Float
-        if let ropeScaling = args.ropeScaling, ropeScaling["type"] == .string("linear"),
-            let factor = ropeScaling["factor"]
-        {
-            switch factor {
-            case .string:
-                fatalError("ropeScaling.factor must be a float")
-            case .float(let v):
-                ropeScale = 1 / v
-            }
-        } else {
-            ropeScale = 1
-        }
-
-        self.rope = RoPE(
-            dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta,
-            scale: ropeScale)
+        self.rope = DynamicNTKScalingRoPE(
+            dims: headDim,
+            maxPositionEmbeddings: args.maxPositionEmbeddings,
+            traditional: args.ropeTraditional,
+            base: args.ropeTheta,
+            scale: 1.0,
+            ropeType: {
+                if case .string(let value) = args.ropeScaling?["type"] {
+                    return value
+                } else {
+                    return "default"
+                }
+            }(),
+            ropeScaling: args.ropeScaling)
     }
 
-    public func callAsFunction(
+    func callAsFunction(
         _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
     ) -> (MLXArray, (MLXArray, MLXArray)) {
         let (B, L) = (x.dim(0), x.dim(1))
@@ -62,7 +139,7 @@ private class Attention: Module {
         var keys = wk(x)
         var values = wv(x)
 
-        // prepare the queries, keys and values for the attention computation
+        // Prepare the queries, keys and values for the attention computation
         queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3)
         keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
         values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
@@ -93,35 +170,35 @@ private class MLP: Module, UnaryLayer {
     @ModuleInfo(key: "down_proj") var down: Linear
     @ModuleInfo(key: "up_proj") var up: Linear
 
-    public init(dimensions: Int, hiddenDimensions: Int) {
-        self._gate.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
-        self._down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false)
-        self._up.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
+    init(_ args: LlamaConfiguration) {
+        self._gate.wrappedValue = Linear(args.hiddenSize, args.intermediateSize, bias: args.mlpBias)
+        self._down.wrappedValue = Linear(args.intermediateSize, args.hiddenSize, bias: args.mlpBias)
+        self._up.wrappedValue = Linear(args.hiddenSize, args.intermediateSize, bias: args.mlpBias)
     }
 
-    public func callAsFunction(_ x: MLXArray) -> MLXArray {
-        down(silu(gate(x)) * up(x))
+    func callAsFunction(_ x: MLXArray) -> MLXArray {
+        let activation = silu(gate(x))
+        return down(activation * up(x))
     }
 }
 
 private class TransformerBlock: Module {
 
     @ModuleInfo(key: "self_attn") var attention: Attention
-    let mlp: MLP
+    @ModuleInfo(key: "mlp") var mlp: MLP
 
     @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
     @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
 
-    public init(_ args: LlamaConfiguration) {
+    init(_ args: LlamaConfiguration) {
         self._attention.wrappedValue = Attention(args)
-        self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
+        self._mlp.wrappedValue = MLP(args)
         self._inputLayerNorm.wrappedValue = RMSNorm(
             dimensions: args.hiddenSize, eps: args.rmsNormEps)
         self._postAttentionLayerNorm.wrappedValue = RMSNorm(
             dimensions: args.hiddenSize, eps: args.rmsNormEps)
     }
 
-    public func callAsFunction(
+    func callAsFunction(
         _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
     ) -> (MLXArray, (MLXArray, MLXArray)) {
         var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
@@ -132,27 +209,24 @@ private class TransformerBlock: Module {
     }
 }
 
-public class LlamaModelInner: Module {
+private class LlamaModelInner: Module {
 
     @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
 
-    fileprivate let layers: [TransformerBlock]
+    let layers: [TransformerBlock]
     let norm: RMSNorm
 
-    public init(_ args: LlamaConfiguration) {
+    init(_ args: LlamaConfiguration) {
         precondition(args.vocabularySize > 0)
 
         self._embedTokens.wrappedValue = Embedding(
             embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
 
-        self.layers = (0 ..< args.hiddenLayers)
-            .map { _ in
-                TransformerBlock(args)
-            }
+        self.layers = (0 ..< args.hiddenLayers).map { _ in TransformerBlock(args) }
         self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
     }
 
-    public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
+    func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
         MLXArray, [(MLXArray, MLXArray)]
     ) {
         var h = embedTokens(inputs)
@@ -178,7 +252,7 @@ public class LlamaModelInner: Module {
 public class LlamaModel: Module, LLMModel {
 
     public let vocabularySize: Int
-    let model: LlamaModelInner
+    fileprivate let model: LlamaModelInner
 
     @ModuleInfo(key: "lm_head") var lmHead: Linear?
 
@@ -202,7 +276,7 @@ public class LlamaModel: Module, LLMModel {
     }
 
     public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
-        // Remove unused precomputed rotary freqs
+        // Remove unused precomputed rotary frequencies
         weights.filter {
             !$0.key.contains("self_attn.rotary_emb.inv_freq")
         }
@@ -215,14 +289,17 @@ public struct LlamaConfiguration: Codable {
     var hiddenLayers: Int
     var intermediateSize: Int
     var attentionHeads: Int
-    var headDimensions: Int? = nil
+    var headDimensions: Int?
     var rmsNormEps: Float
    var vocabularySize: Int
     var kvHeads: Int
+    var maxPositionEmbeddings: Int?
     var ropeTheta: Float = 10_000
     var ropeTraditional: Bool = false
-    var ropeScaling: [String: StringOrNumber]? = nil
-    var tieWordEmbeddings: Bool = false
+    var ropeScaling: [String: StringOrNumber]?
+    var tieWordEmbeddings: Bool = true
+    var attentionBias: Bool = false
+    var mlpBias: Bool = false
 
     enum CodingKeys: String, CodingKey {
         case hiddenSize = "hidden_size"
@@ -233,45 +310,75 @@ public struct LlamaConfiguration: Codable {
         case rmsNormEps = "rms_norm_eps"
         case vocabularySize = "vocab_size"
         case kvHeads = "num_key_value_heads"
+        case maxPositionEmbeddings = "max_position_embeddings"
         case ropeTheta = "rope_theta"
         case ropeTraditional = "rope_traditional"
         case ropeScaling = "rope_scaling"
         case tieWordEmbeddings = "tie_word_embeddings"
+        case attentionBias = "attention_bias"
+        case mlpBias = "mlp_bias"
     }
 
     public init(from decoder: Decoder) throws {
-        // custom implementation to handle optional keys with required values
-        let container: KeyedDecodingContainer =
-            try decoder.container(
-                keyedBy: LlamaConfiguration.CodingKeys.self)
+        let container = try decoder.container(keyedBy: CodingKeys.self)
 
-        self.hiddenSize = try container.decode(
-            Int.self, forKey: LlamaConfiguration.CodingKeys.hiddenSize)
-        self.hiddenLayers = try container.decode(
-            Int.self, forKey: LlamaConfiguration.CodingKeys.hiddenLayers)
-        self.intermediateSize = try container.decode(
-            Int.self, forKey: LlamaConfiguration.CodingKeys.intermediateSize)
-        self.attentionHeads = try container.decode(
-            Int.self, forKey: LlamaConfiguration.CodingKeys.attentionHeads)
-        self.headDimensions = try container.decodeIfPresent(
-            Int.self, forKey: LlamaConfiguration.CodingKeys.headDimensions)
-        self.rmsNormEps = try container.decode(
-            Float.self, forKey: LlamaConfiguration.CodingKeys.rmsNormEps)
-        self.vocabularySize = try container.decode(
-            Int.self, forKey: LlamaConfiguration.CodingKeys.vocabularySize)
-        self.kvHeads = try container.decode(Int.self, forKey: LlamaConfiguration.CodingKeys.kvHeads)
-        self.ropeTheta =
-            try container.decodeIfPresent(
-                Float.self, forKey: LlamaConfiguration.CodingKeys.ropeTheta)
-            ?? 10_000
-        self.ropeTraditional =
-            try container.decodeIfPresent(
-                Bool.self, forKey: LlamaConfiguration.CodingKeys.ropeTraditional) ?? false
-        self.ropeScaling = try container.decodeIfPresent(
-            [String: StringOrNumber].self, forKey: LlamaConfiguration.CodingKeys.ropeScaling)
-        self.tieWordEmbeddings =
-            try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false
+        hiddenSize = try container.decode(Int.self, forKey: .hiddenSize)
+        hiddenLayers = try container.decode(Int.self, forKey: .hiddenLayers)
+        intermediateSize = try container.decode(Int.self, forKey: .intermediateSize)
+        attentionHeads = try container.decode(Int.self, forKey: .attentionHeads)
+        headDimensions = try container.decodeIfPresent(Int.self, forKey: .headDimensions)
+        rmsNormEps = try container.decode(Float.self, forKey: .rmsNormEps)
+        vocabularySize = try container.decode(Int.self, forKey: .vocabularySize)
+        kvHeads = try container.decodeIfPresent(Int.self, forKey: .kvHeads) ?? attentionHeads
+        maxPositionEmbeddings = try container.decodeIfPresent(
+            Int.self, forKey: .maxPositionEmbeddings)
+        if let ropeTheta = try container.decodeIfPresent(Float.self, forKey: .ropeTheta) {
+            self.ropeTheta = ropeTheta
+        }
+        if let ropeTraditional = try container.decodeIfPresent(Bool.self, forKey: .ropeTraditional)
+        {
+            self.ropeTraditional = ropeTraditional
+        }
+        ropeScaling = try container.decodeIfPresent(
+            [String: StringOrNumber].self, forKey: .ropeScaling)
+        if let tieWordEmbeddings = try container.decodeIfPresent(
+            Bool.self, forKey: .tieWordEmbeddings)
+        {
+            self.tieWordEmbeddings = tieWordEmbeddings
+        }
+        if let attentionBias = try container.decodeIfPresent(Bool.self, forKey: .attentionBias) {
+            self.attentionBias = attentionBias
+        }
+        if let mlpBias = try container.decodeIfPresent(Bool.self, forKey: .mlpBias) {
+            self.mlpBias = mlpBias
+        }
+        if let ropeScaling {
+            if ropeScaling["factor"] == nil {
+                throw DecodingError.dataCorruptedError(
+                    forKey: .ropeScaling, in: container,
+                    debugDescription: "rope_scaling must contain 'factor'")
+            }
+            if let ropeType = ropeScaling["type"] ?? ropeScaling["rope_type"] {
+                if case .string = ropeType {
+                    let options = [
+                        StringOrNumber.string("linear"), StringOrNumber.string("dynamic"),
+                        StringOrNumber.string("llama3"),
+                    ]
+                    if !options.contains(ropeType) {
+                        throw DecodingError.dataCorruptedError(
+                            forKey: .ropeScaling, in: container,
+                            debugDescription:
+                                "rope_scaling 'type' currently only supports 'linear', 'dynamic', or 'llama3'"
+                        )
+                    }
+                }
+            } else {
+                throw DecodingError.dataCorruptedError(
+                    forKey: .ropeScaling, in: container,
+                    debugDescription: "rope_scaling must contain either 'type' or 'rope_type'")
+            }
+        }
     }
 }
diff --git a/Libraries/LLM/Models.swift b/Libraries/LLM/Models.swift
index 73ab0ab..9a9b7ba 100644
--- a/Libraries/LLM/Models.swift
+++ b/Libraries/LLM/Models.swift
@@ -151,7 +151,7 @@ extension ModelConfiguration {
         defaultPrompt: "Why is the sky blue?"
     )
 
-    public static let phi34bit = ModelConfiguration(
+    public static let phi3_4bit = ModelConfiguration(
         id: "mlx-community/Phi-3-mini-4k-instruct-4bit-no-q-embed",
         defaultPrompt: "what is the gravity on mars and the moon?",
         extraEOSTokens: ["<|end|>"]
@@ -199,9 +199,17 @@ extension ModelConfiguration {
         "\(prompt)"
     }
 
-    public static let llama38B4bit = ModelConfiguration(
+    public static let llama3_1_8B_4bit = ModelConfiguration(
+        id: "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
+        defaultPrompt: "What is the difference between a fruit and a vegetable?"
+    ) {
+        prompt in
+        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\(prompt)<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"
+    }
+
+    public static let llama3_8B_4bit = ModelConfiguration(
         id: "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
-        defaultPrompt: "what is the difference between a fruit and a vegetable?"
+        defaultPrompt: "What is the difference between a fruit and a vegetable?"
     ) {
         prompt in
         "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\(prompt)<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"
@@ -220,12 +228,13 @@ extension ModelConfiguration {
             case .idle:
                 bootstrapState = .bootstrapping
                 register(configurations: [
+                    llama3_1_8B_4bit,
                     mistralNeMo4bit,
                     smolLM_135M_4bit,
                     mistral7B4bit,
                     codeLlama13b4bit,
                     phi4bit,
-                    phi3_4bit,
+                    phi3_4bit,
                     gemma2bQuantized,
                     gemma_2_9b_it_4bit,
                     qwen205b4bit,
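Not part of the patch above: a minimal, standalone Swift sketch of the llama3-style rope_scaling arithmetic that computeBaseFrequency performs, so it can be run in isolation. The helper name llama3AdjustedBase is hypothetical, and the default values (base 500_000, factor 8, low_freq_factor 1, high_freq_factor 4, original_max_position_embeddings 8192) are illustrative Llama 3.1-style settings, not values read from this repository.

import Foundation

// Hypothetical standalone sketch (not part of the patch): reproduces the
// per-dimension frequency smoothing from computeBaseFrequency and averages
// the result back into a single base, as the patch does before handing it
// to MLXFast.RoPE. All default parameter values are illustrative assumptions.
func llama3AdjustedBase(
    base: Float = 500_000, dims: Int = 128,
    factor: Float = 8, lowFreqFactor: Float = 1, highFreqFactor: Float = 4,
    oldContextLen: Float = 8192
) -> Float {
    let lowFreqWavelen = oldContextLen / lowFreqFactor
    let highFreqWavelen = oldContextLen / highFreqFactor

    // Even-index frequencies, as in the patch: base^(i / dims) for i = 0, 2, 4, ...
    let freqs = stride(from: 0, to: dims, by: 2).map { pow(base, Float($0) / Float(dims)) }

    let newBases = freqs.map { freq -> Float in
        let wavelen = 2 * Float.pi / freq
        // 0 for wavelengths below the high-frequency cutoff, 1 above the
        // low-frequency cutoff, linearly blended in between.
        let smooth = max(0, min(1, (wavelen - highFreqWavelen) / (lowFreqWavelen - highFreqWavelen)))
        return freq * ((1 - smooth) * factor + smooth)
    }
    return newBases.reduce(0, +) / Float(newBases.count)
}

print(llama3AdjustedBase())  // single averaged base, analogous to what the patch passes to RoPE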