handle partially quantized models (#76)
* handle partially quantized models - fix for #53 #71 #69 #74 - in order to test the models - I added a default prompt of an appropriate form - while working on the model configuration also added additional stop tokens (#74) - fixed the repetitionPenalty code (#71)
This commit is contained in:
@@ -54,9 +54,15 @@ public func load(
|
||||
}
|
||||
}
|
||||
|
||||
// per-model cleanup
|
||||
weights = model.sanitize(weights: weights)
|
||||
|
||||
// quantize if needed
|
||||
if let quantization = baseConfig.quantization {
|
||||
quantizeIfNeeded(model: model, weights: weights, quantization: quantization)
|
||||
quantize(model: model, groupSize: quantization.groupSize, bits: quantization.bits) {
|
||||
path, module in
|
||||
weights["\(path).scales"] != nil
|
||||
}
|
||||
}
|
||||
|
||||
// apply the loaded weights
|
||||
@@ -76,38 +82,3 @@ public func load(
|
||||
hub: hub, configuration: newConfiguration, progressHandler: progressHandler)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Quantization
|
||||
|
||||
private func quantizeIfNeeded(
|
||||
model: LLMModel, weights: [String: MLXArray], quantization: BaseConfiguration.Quantization
|
||||
) {
|
||||
|
||||
func linearPredicate(layer: Module) -> Bool {
|
||||
if let layer = layer as? Linear {
|
||||
// avoid quantizing gate layers, otherwise we have to re-quant and upload all the mixtral models
|
||||
return layer.weight.dim(0) != 8
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var predicate = linearPredicate(layer:)
|
||||
|
||||
// for legacy models that don't have lm_head quant due to non-32 dims
|
||||
if weights["lm_head.scales"] == nil {
|
||||
let vocabularySize = model.vocabularySize
|
||||
|
||||
func vocabularySizePredicate(layer: Module) -> Bool {
|
||||
if let layer = layer as? Linear {
|
||||
return layer.weight.dim(0) != 8 && layer.weight.dim(0) != vocabularySize
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
predicate = vocabularySizePredicate(layer:)
|
||||
}
|
||||
|
||||
QuantizedLinear.quantize(
|
||||
model: model, groupSize: quantization.groupSize, bits: quantization.bits,
|
||||
predicate: predicate)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user