fix for #2 -- CodeLlama crashes
- add replacement tokenizer class for unknown tokenizers - fix quantization for models that don't have lm_head quantized Requires https://github.com/ml-explore/mlx-swift/pull/28
This commit is contained in:
@@ -202,9 +202,11 @@ public class GemmaModelInner: Module {
|
|||||||
|
|
||||||
public class GemmaModel: Module, LLMModel {
|
public class GemmaModel: Module, LLMModel {
|
||||||
|
|
||||||
|
public let vocabularySize: Int
|
||||||
let model: GemmaModelInner
|
let model: GemmaModelInner
|
||||||
|
|
||||||
public init(_ args: GemmaConfiguration) {
|
public init(_ args: GemmaConfiguration) {
|
||||||
|
self.vocabularySize = args.vocabularySize
|
||||||
self.model = GemmaModelInner(args)
|
self.model = GemmaModelInner(args)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,9 @@ import MLXNN
|
|||||||
|
|
||||||
// Interface for all LLM Models
|
// Interface for all LLM Models
|
||||||
public protocol LLMModel: Module {
|
public protocol LLMModel: Module {
|
||||||
|
|
||||||
|
var vocabularySize: Int { get }
|
||||||
|
|
||||||
func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
|
func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
|
||||||
MLXArray, [(MLXArray, MLXArray)]
|
MLXArray, [(MLXArray, MLXArray)]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -187,11 +187,13 @@ public class LlamaModelInner: Module {
|
|||||||
|
|
||||||
public class LlamaModel: Module, LLMModel {
|
public class LlamaModel: Module, LLMModel {
|
||||||
|
|
||||||
|
public let vocabularySize: Int
|
||||||
let model: LlamaModelInner
|
let model: LlamaModelInner
|
||||||
|
|
||||||
@ModuleInfo(key: "lm_head") var lmHead: Linear
|
@ModuleInfo(key: "lm_head") var lmHead: Linear
|
||||||
|
|
||||||
public init(_ args: LlamaConfiguration) {
|
public init(_ args: LlamaConfiguration) {
|
||||||
|
self.vocabularySize = args.vocabularySize
|
||||||
self.model = LlamaModelInner(args)
|
self.model = LlamaModelInner(args)
|
||||||
self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
|
self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,54 +6,7 @@ import MLXNN
|
|||||||
|
|
||||||
// https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/phi.py
|
// https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/phi.py
|
||||||
|
|
||||||
// TODO: remove once open classes are in
|
private class LayerNorm: MLXNN.LayerNorm {
|
||||||
|
|
||||||
public class MLXLayerNorm: Module, UnaryLayer {
|
|
||||||
|
|
||||||
let dimensions: Int
|
|
||||||
let eps: Float
|
|
||||||
|
|
||||||
let weight: MLXArray?
|
|
||||||
let bias: MLXArray?
|
|
||||||
|
|
||||||
/// Applies layer normalization [1] on the inputs.
|
|
||||||
///
|
|
||||||
/// See [LayerNorm python docs](https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary/mlx.nn.LayerNorm.html) for more information.
|
|
||||||
///
|
|
||||||
/// ### References
|
|
||||||
/// 1. [https://arxiv.org/abs/1607.06450](https://arxiv.org/abs/1607.06450)
|
|
||||||
///
|
|
||||||
/// - Parameters:
|
|
||||||
/// - dimensions: number of features in the input
|
|
||||||
/// - eps: value added to the denominator for numerical stability
|
|
||||||
/// - affine: if `true` adds a trainable `weight` and `bias`
|
|
||||||
public init(dimensions: Int, eps: Float = 1e-5, affine: Bool = true) {
|
|
||||||
self.dimensions = dimensions
|
|
||||||
self.eps = eps
|
|
||||||
|
|
||||||
if affine {
|
|
||||||
self.weight = MLXArray.ones([dimensions])
|
|
||||||
self.bias = MLXArray.zeros([dimensions])
|
|
||||||
} else {
|
|
||||||
self.weight = nil
|
|
||||||
self.bias = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public func callAsFunction(_ x: MLXArray) -> MLXArray {
|
|
||||||
let means = mean(x, axis: -1, keepDims: true)
|
|
||||||
let variance = variance(x, axis: -1, keepDims: true)
|
|
||||||
let x = (x - means) * rsqrt(variance + eps)
|
|
||||||
|
|
||||||
if let weight, let bias {
|
|
||||||
return weight * x + bias
|
|
||||||
} else {
|
|
||||||
return x
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private class LayerNorm: MLXLayerNorm {
|
|
||||||
override func callAsFunction(_ x: MLXArray) -> MLXArray {
|
override func callAsFunction(_ x: MLXArray) -> MLXArray {
|
||||||
super.callAsFunction(x.asType(Float.self)).asType(x.dtype)
|
super.callAsFunction(x.asType(Float.self)).asType(x.dtype)
|
||||||
}
|
}
|
||||||
@@ -223,11 +176,14 @@ private class PhiModelInner: Module {
|
|||||||
|
|
||||||
public class PhiModel: Module, LLMModel {
|
public class PhiModel: Module, LLMModel {
|
||||||
|
|
||||||
|
public let vocabularySize: Int
|
||||||
|
|
||||||
fileprivate let model: PhiModelInner
|
fileprivate let model: PhiModelInner
|
||||||
|
|
||||||
@ModuleInfo(key: "lm_head") var lmHead: Linear
|
@ModuleInfo(key: "lm_head") var lmHead: Linear
|
||||||
|
|
||||||
public init(_ args: PhiConfiguration) {
|
public init(_ args: PhiConfiguration) {
|
||||||
|
self.vocabularySize = args.vocabularySize
|
||||||
self.model = PhiModelInner(args)
|
self.model = PhiModelInner(args)
|
||||||
self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: true)
|
self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: true)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,12 +8,16 @@ import MLXNN
|
|||||||
import MLXRandom
|
import MLXRandom
|
||||||
import Tokenizers
|
import Tokenizers
|
||||||
|
|
||||||
|
struct LLMError: Error {
|
||||||
|
let message: String
|
||||||
|
}
|
||||||
|
|
||||||
/// Load and return the model and tokenizer
|
/// Load and return the model and tokenizer
|
||||||
public func load(
|
public func load(
|
||||||
hub: HubApi = HubApi(), name: String, progressHandler: @escaping (Progress) -> Void = { _ in }
|
hub: HubApi = HubApi(), name: String, progressHandler: @escaping (Progress) -> Void = { _ in }
|
||||||
) async throws -> (LLMModel, Tokenizer) {
|
) async throws -> (LLMModel, Tokenizer) {
|
||||||
// note: this doesn't have a way to pass the HubApi
|
// note: this doesn't have a way to pass the HubApi
|
||||||
let tokenizer = try await AutoTokenizer.from(pretrained: name)
|
let tokenizer = try await loadTokenizer(name: name)
|
||||||
|
|
||||||
// download the model weights and config
|
// download the model weights and config
|
||||||
let repo = Hub.Repo(id: name)
|
let repo = Hub.Repo(id: name)
|
||||||
@@ -28,21 +32,80 @@ public func load(
|
|||||||
|
|
||||||
let model = try baseConfig.modelType.createModel(configuration: configurationURL)
|
let model = try baseConfig.modelType.createModel(configuration: configurationURL)
|
||||||
|
|
||||||
// set up the model
|
// load the weights
|
||||||
|
let weights = try loadArrays(url: modelDirectory.appending(component: "weights.00.safetensors"))
|
||||||
|
|
||||||
|
// quantize if needed
|
||||||
if let quantization = baseConfig.quantization {
|
if let quantization = baseConfig.quantization {
|
||||||
QuantizedLinear.quantize(
|
quantizeIfNeeded(model: model, weights: weights, quantization: quantization)
|
||||||
model: model, groupSize: quantization.groupSize, bits: quantization.bits)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// apply the loaded weights
|
// apply the loaded weights
|
||||||
let weights = try loadArrays(url: modelDirectory.appending(component: "weights.00.safetensors"))
|
|
||||||
let parameters = ModuleParameters.unflattened(weights)
|
let parameters = ModuleParameters.unflattened(weights)
|
||||||
try model.update(parameters: parameters, verify: [.all])
|
try model.update(parameters: parameters, verify: [.all])
|
||||||
eval(model.parameters())
|
|
||||||
|
eval(model)
|
||||||
|
|
||||||
return (model, tokenizer)
|
return (model, tokenizer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public func loadTokenizer(name: String) async throws -> Tokenizer {
|
||||||
|
// from AutoTokenizer.from() -- this lets us override parts of the configuration
|
||||||
|
let config = LanguageModelConfigurationFromHub(modelName: name)
|
||||||
|
guard var tokenizerConfig = try await config.tokenizerConfig else {
|
||||||
|
throw LLMError(message: "missing config")
|
||||||
|
}
|
||||||
|
let tokenizerData = try await config.tokenizerData
|
||||||
|
|
||||||
|
if let tokenizerClass = tokenizerConfig.tokenizerClass?.stringValue,
|
||||||
|
let replacement = replacementTokenizers[tokenizerClass]
|
||||||
|
{
|
||||||
|
var dictionary = tokenizerConfig.dictionary
|
||||||
|
dictionary["tokenizer_class"] = replacement
|
||||||
|
tokenizerConfig = Config(dictionary)
|
||||||
|
}
|
||||||
|
|
||||||
|
return try PreTrainedTokenizer(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// overrides for TokenizerModel/knownTokenizers
|
||||||
|
let replacementTokenizers = [
|
||||||
|
"CodeLlamaTokenizer": "LlamaTokenizer"
|
||||||
|
]
|
||||||
|
|
||||||
|
private func quantizeIfNeeded(
|
||||||
|
model: LLMModel, weights: [String: MLXArray], quantization: BaseConfiguration.Quantization
|
||||||
|
) {
|
||||||
|
|
||||||
|
func linearPredicate(layer: Module) -> Bool {
|
||||||
|
if let layer = layer as? Linear {
|
||||||
|
// avoid quantizing gate layers, otherwise we have to re-quant and upload all the mixtral models
|
||||||
|
return layer.weight.dim(0) != 8
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
var predicate = linearPredicate(layer:)
|
||||||
|
|
||||||
|
// for legacy models that don't have lm_head quant due to non-32 dims
|
||||||
|
if weights["lm_head.scales"] == nil {
|
||||||
|
let vocabularySize = model.vocabularySize
|
||||||
|
|
||||||
|
func vocabularySizePredicate(layer: Module) -> Bool {
|
||||||
|
if let layer = layer as? Linear {
|
||||||
|
return layer.weight.dim(0) != 8 && layer.weight.dim(0) != vocabularySize
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
predicate = vocabularySizePredicate(layer:)
|
||||||
|
}
|
||||||
|
|
||||||
|
QuantizedLinear.quantize(
|
||||||
|
model: model, groupSize: quantization.groupSize, bits: quantization.bits,
|
||||||
|
predicate: predicate)
|
||||||
|
}
|
||||||
|
|
||||||
private func sample(logits: MLXArray, temp: Float) -> MLXArray {
|
private func sample(logits: MLXArray, temp: Float) -> MLXArray {
|
||||||
if temp == 0 {
|
if temp == 0 {
|
||||||
return argMax(logits, axis: -1)
|
return argMax(logits, axis: -1)
|
||||||
|
|||||||
Reference in New Issue
Block a user