diff --git a/Libraries/LLM/Gemma.swift b/Libraries/LLM/Gemma.swift index b934629..ae3a7fa 100644 --- a/Libraries/LLM/Gemma.swift +++ b/Libraries/LLM/Gemma.swift @@ -202,9 +202,11 @@ public class GemmaModelInner: Module { public class GemmaModel: Module, LLMModel { + public let vocabularySize: Int let model: GemmaModelInner public init(_ args: GemmaConfiguration) { + self.vocabularySize = args.vocabularySize self.model = GemmaModelInner(args) } diff --git a/Libraries/LLM/LLMModel.swift b/Libraries/LLM/LLMModel.swift index 7bb6f8e..885dce3 100644 --- a/Libraries/LLM/LLMModel.swift +++ b/Libraries/LLM/LLMModel.swift @@ -6,6 +6,9 @@ import MLXNN // Interface for all LLM Models public protocol LLMModel: Module { + + var vocabularySize: Int { get } + func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> ( MLXArray, [(MLXArray, MLXArray)] ) diff --git a/Libraries/LLM/Llama.swift b/Libraries/LLM/Llama.swift index ac64551..90d2296 100644 --- a/Libraries/LLM/Llama.swift +++ b/Libraries/LLM/Llama.swift @@ -187,11 +187,13 @@ public class LlamaModelInner: Module { public class LlamaModel: Module, LLMModel { + public let vocabularySize: Int let model: LlamaModelInner @ModuleInfo(key: "lm_head") var lmHead: Linear public init(_ args: LlamaConfiguration) { + self.vocabularySize = args.vocabularySize self.model = LlamaModelInner(args) self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false) } diff --git a/Libraries/LLM/Phi.swift b/Libraries/LLM/Phi.swift index 1f7d6b5..e4a55eb 100644 --- a/Libraries/LLM/Phi.swift +++ b/Libraries/LLM/Phi.swift @@ -6,54 +6,7 @@ import MLXNN // https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/phi.py -// TODO: remove once open classes are in - -public class MLXLayerNorm: Module, UnaryLayer { - - let dimensions: Int - let eps: Float - - let weight: MLXArray? - let bias: MLXArray? - - /// Applies layer normalization [1] on the inputs. - /// - /// See [LayerNorm python docs](https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary/mlx.nn.LayerNorm.html) for more information. - /// - /// ### References - /// 1. [https://arxiv.org/abs/1607.06450](https://arxiv.org/abs/1607.06450) - /// - /// - Parameters: - /// - dimensions: number of features in the input - /// - eps: value added to the denominator for numerical stability - /// - affine: if `true` adds a trainable `weight` and `bias` - public init(dimensions: Int, eps: Float = 1e-5, affine: Bool = true) { - self.dimensions = dimensions - self.eps = eps - - if affine { - self.weight = MLXArray.ones([dimensions]) - self.bias = MLXArray.zeros([dimensions]) - } else { - self.weight = nil - self.bias = nil - } - } - - public func callAsFunction(_ x: MLXArray) -> MLXArray { - let means = mean(x, axis: -1, keepDims: true) - let variance = variance(x, axis: -1, keepDims: true) - let x = (x - means) * rsqrt(variance + eps) - - if let weight, let bias { - return weight * x + bias - } else { - return x - } - } -} - -private class LayerNorm: MLXLayerNorm { +private class LayerNorm: MLXNN.LayerNorm { override func callAsFunction(_ x: MLXArray) -> MLXArray { super.callAsFunction(x.asType(Float.self)).asType(x.dtype) } @@ -223,11 +176,14 @@ private class PhiModelInner: Module { public class PhiModel: Module, LLMModel { + public let vocabularySize: Int + fileprivate let model: PhiModelInner @ModuleInfo(key: "lm_head") var lmHead: Linear public init(_ args: PhiConfiguration) { + self.vocabularySize = args.vocabularySize self.model = PhiModelInner(args) self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: true) } diff --git a/Libraries/LLM/Util.swift b/Libraries/LLM/Util.swift index c579179..233fd57 100644 --- a/Libraries/LLM/Util.swift +++ b/Libraries/LLM/Util.swift @@ -8,12 +8,16 @@ import MLXNN import MLXRandom import Tokenizers +struct LLMError: Error { + let message: String +} + /// Load and return the model and tokenizer public func load( hub: HubApi = HubApi(), name: String, progressHandler: @escaping (Progress) -> Void = { _ in } ) async throws -> (LLMModel, Tokenizer) { // note: this doesn't have a way to pass the HubApi - let tokenizer = try await AutoTokenizer.from(pretrained: name) + let tokenizer = try await loadTokenizer(name: name) // download the model weights and config let repo = Hub.Repo(id: name) @@ -28,21 +32,80 @@ public func load( let model = try baseConfig.modelType.createModel(configuration: configurationURL) - // set up the model + // load the weights + let weights = try loadArrays(url: modelDirectory.appending(component: "weights.00.safetensors")) + + // quantize if needed if let quantization = baseConfig.quantization { - QuantizedLinear.quantize( - model: model, groupSize: quantization.groupSize, bits: quantization.bits) + quantizeIfNeeded(model: model, weights: weights, quantization: quantization) } // apply the loaded weights - let weights = try loadArrays(url: modelDirectory.appending(component: "weights.00.safetensors")) let parameters = ModuleParameters.unflattened(weights) try model.update(parameters: parameters, verify: [.all]) - eval(model.parameters()) + + eval(model) return (model, tokenizer) } +public func loadTokenizer(name: String) async throws -> Tokenizer { + // from AutoTokenizer.from() -- this lets us override parts of the configuration + let config = LanguageModelConfigurationFromHub(modelName: name) + guard var tokenizerConfig = try await config.tokenizerConfig else { + throw LLMError(message: "missing config") + } + let tokenizerData = try await config.tokenizerData + + if let tokenizerClass = tokenizerConfig.tokenizerClass?.stringValue, + let replacement = replacementTokenizers[tokenizerClass] + { + var dictionary = tokenizerConfig.dictionary + dictionary["tokenizer_class"] = replacement + tokenizerConfig = Config(dictionary) + } + + return try PreTrainedTokenizer(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData) +} + +/// overrides for TokenizerModel/knownTokenizers +let replacementTokenizers = [ + "CodeLlamaTokenizer": "LlamaTokenizer" +] + +private func quantizeIfNeeded( + model: LLMModel, weights: [String: MLXArray], quantization: BaseConfiguration.Quantization +) { + + func linearPredicate(layer: Module) -> Bool { + if let layer = layer as? Linear { + // avoid quantizing gate layers, otherwise we have to re-quant and upload all the mixtral models + return layer.weight.dim(0) != 8 + } + return false + } + + var predicate = linearPredicate(layer:) + + // for legacy models that don't have lm_head quant due to non-32 dims + if weights["lm_head.scales"] == nil { + let vocabularySize = model.vocabularySize + + func vocabularySizePredicate(layer: Module) -> Bool { + if let layer = layer as? Linear { + return layer.weight.dim(0) != 8 && layer.weight.dim(0) != vocabularySize + } + return false + } + + predicate = vocabularySizePredicate(layer:) + } + + QuantizedLinear.quantize( + model: model, groupSize: quantization.groupSize, bits: quantization.bits, + predicate: predicate) +} + private func sample(logits: MLXArray, temp: Float) -> MLXArray { if temp == 0 { return argMax(logits, axis: -1)