handle partially quantized models (#76)

* handle partially quantized models -- fix for #53 #71 #69 #74
* added a default prompt of an appropriate form to each model configuration, in order to test the models
* while working on the model configurations, also added additional stop tokens (#74)
* fixed the repetitionPenalty code (#71)
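The central mechanism of the fix shows up in the load hunk below: instead of guessing from layer shapes which layers to quantize, the loader now asks the checkpoint itself. A minimal sketch of the idea (the helper name is mine, not the library's):

import MLX

// A partially quantized checkpoint stores "<path>.scales" (and ".biases")
// entries only for the layers that were actually quantized, so the presence
// of the scales key tells us whether a given module should be converted.
func shouldQuantize(_ path: String, _ weights: [String: MLXArray]) -> Bool {
    weights["\(path).scales"] != nil
}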
@@ -12,7 +12,7 @@ private func topPSampling(logits: MLXArray, topP: Float, temp: Float) -> MLXArray
         logits = logits.asType(.float32)
     }
 
-    let probs = softMax(logits / temp, axis: -1)
+    let probs = softmax(logits / temp, axis: -1)
     let sortedIndices = argSort(probs, axis: -1)
 
     // probs shape is [B,V] and after take it will be [1, B, V], so we squeeze it back to [B, V]
@@ -31,7 +31,7 @@ private func applyRepetitionPenalty(
 ) -> MLXArray {
     if repetitionContext.shape[0] > 0 {
         let indices = repetitionContext
-        var selectedLogits = take(logits, indices, axis: -1).squeezed(axis: 0)
+        var selectedLogits = logits[0..., indices]
 
         selectedLogits = MLX.where(
            selectedLogits .< 0, selectedLogits * penalty, selectedLogits / penalty)
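The rewritten gather selects, for each batch row, the logits of the tokens already generated, and the CTRL-style penalty then pushes those logits away from reselection. A toy illustration with fabricated values (relying on the same MLX Swift array-index subscripting the change itself uses):

import MLX

let penalty: Float = 1.2
let logits = MLXArray([2.0, -1.0, 0.5, 3.0] as [Float], [1, 4])
let seen = MLXArray([0, 3] as [Int32])  // tokens 0 and 3 were generated before

var selected = logits[0..., seen]
selected = MLX.where(selected .< 0, selected * penalty, selected / penalty)
print(selected)  // expected roughly [[1.67, 2.5]]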
@@ -100,7 +100,7 @@ public struct TokenIterator: Sequence, IteratorProtocol {
             if prompt.shape[0] <= parameters.repetitionContextSize {
                 self.repetitionContext = prompt
             } else {
-                self.repetitionContext = prompt[-parameters.repetitionContextSize ... -1]
+                self.repetitionContext = prompt[(-parameters.repetitionContextSize)...]
             }
         } else {
             self.repetitionContext = []
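The new slice reads directly as "keep the last N tokens" and avoids whatever the closed `-n ... -1` range resolved to under negative indexing (part of the #71 fix). A small sketch with made-up sizes:

import MLX

let contextSize = 4
let prompt = MLXArray(0 ..< 10)

// keep only the trailing `contextSize` tokens, as the new code does
let context = prompt.shape[0] <= contextSize ? prompt : prompt[(-contextSize)...]
print(context)  // expected: the last four values, [6, 7, 8, 9]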
@@ -120,9 +120,8 @@ public struct TokenIterator: Sequence, IteratorProtocol {
         y = sample(logits: logits, temp: parameters.temperature, topP: parameters.topP)
         // append the current token to the context and trim the repetition context if it has grown past its size
         if parameters.repetitionContextSize > 1 {
             repetitionContext = concatenated([repetitionContext, y], axis: 0)
             if repetitionContext.shape[0] > parameters.repetitionContextSize {
-                repetitionContext = repetitionContext[1...]
+                repetitionContext = repetitionContext[(-parameters.repetitionContextSize)...]
             }
         }
 
@@ -174,14 +173,31 @@ public enum GenerateDisposition {
 /// - parameters: generation parameters
 /// - model: model to evaluate
 /// - tokenizer: tokenizer to convert tokens back into strings and recognize special tokens
 /// - configuration: the model configuration
 /// - didGenerate: visitor for the tokens as they are generated
 public func generate(
     promptTokens: [Int], parameters: GenerateParameters, model: LLMModel, tokenizer: Tokenizer,
+    extraEOSTokens: Set<String>? = nil,
     didGenerate: ([Int]) async -> GenerateDisposition
 ) async -> GenerateResult {
     var start = Date.timeIntervalSinceReferenceDate
     var promptTime: TimeInterval = 0
 
+    // build a set of additional stop tokens
+    let additionalEOSTokenIds = Set(
+        (extraEOSTokens ?? [])
+            .map {
+                tokenizer.encode(text: $0)
+            }
+            .filter {
+                // discard anything that is not a single token. sometimes
+                // the tokenizer will insert a <s> token, so accept that too
+                $0.count == 1 || ($0.count == 2 && $0[0] == 1)
+            }
+            .map {
+                $0.last!
+            })
+
     var tokens = [Int]()
 
     for token in TokenIterator(
@@ -196,7 +212,9 @@ public func generate(
         }
 
         let t = token.item(Int.self)
-        if t == tokenizer.unknownTokenId || t == tokenizer.eosTokenId {
+        if t == tokenizer.unknownTokenId || t == tokenizer.eosTokenId
+            || additionalEOSTokenIds.contains(t)
+        {
             break
         }
 
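To make the stop-token filter in `generate` concrete: encoding a candidate stop string can yield a clean single id, a pair led by the `<s>`/BOS id (1 in Llama-style vocabularies), or a multi-token split that cannot serve as a stop id. A toy walk-through with fabricated ids:

// hypothetical encodings of three candidate stop strings
let encoded: [[Int]] = [
    [32007],       // single token -> kept
    [1, 32007],    // tokenizer prepended <s> -> last id kept
    [27, 28, 29],  // splits into several tokens -> discarded
]

let additionalEOSTokenIds = Set(
    encoded
        .filter { $0.count == 1 || ($0.count == 2 && $0[0] == 1) }
        .map { $0.last! })

print(additionalEOSTokenIds)  // [32007]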
@@ -12,4 +12,15 @@ public protocol LLMModel: Module {
     func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
         MLXArray, [(MLXArray, MLXArray)]
     )
+
+    /// Optionally preprocess the weights and modify / remove values as needed.
+    func sanitize(weights: [String: MLXArray]) -> [String: MLXArray]
 }
+
+extension LLMModel {
+
+    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
+        weights
+    }
+
+}
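Because the extension supplies an identity `sanitize`, existing conformers compile unchanged and only override it when their checkpoints carry extra tensors. A hypothetical conformer (assuming, as this diff suggests, that `vocabularySize` and `callAsFunction` are the protocol's other requirements):

import MLX
import MLXNN

public class ToyModel: Module, LLMModel {
    public let vocabularySize = 32_000

    public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
        MLXArray, [(MLXArray, MLXArray)]
    ) {
        // a real model would run its transformer stack here
        (inputs, cache ?? [])
    }

    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
        // drop checkpoint entries that have no corresponding module parameter
        weights.filter { !$0.key.hasSuffix("rotary_emb.inv_freq") }
    }
}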
@@ -194,6 +194,13 @@ public class LlamaModel: Module, LLMModel {
         let (out, cache) = model(inputs, cache: cache)
         return (lmHead(out), cache)
     }
+
+    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
+        // Remove unused precomputed rotary freqs
+        weights.filter {
+            !$0.key.contains("self_attn.rotary_emb.inv_freq")
+        }
+    }
 }
 
 public struct LlamaConfiguration: Codable {
@@ -54,9 +54,15 @@ public func load(
             }
         }
 
+        // per-model cleanup
+        weights = model.sanitize(weights: weights)
+
         // quantize if needed
         if let quantization = baseConfig.quantization {
-            quantizeIfNeeded(model: model, weights: weights, quantization: quantization)
+            quantize(model: model, groupSize: quantization.groupSize, bits: quantization.bits) {
+                path, module in
+                weights["\(path).scales"] != nil
+            }
        }
 
        // apply the loaded weights
@@ -76,38 +82,3 @@ public func load(
         hub: hub, configuration: newConfiguration, progressHandler: progressHandler)
     }
 }
-
-// MARK: - Quantization
-
-private func quantizeIfNeeded(
-    model: LLMModel, weights: [String: MLXArray], quantization: BaseConfiguration.Quantization
-) {
-
-    func linearPredicate(layer: Module) -> Bool {
-        if let layer = layer as? Linear {
-            // avoid quantizing gate layers, otherwise we have to re-quant and upload all the mixtral models
-            return layer.weight.dim(0) != 8
-        }
-        return false
-    }
-
-    var predicate = linearPredicate(layer:)
-
-    // for legacy models that don't have lm_head quant due to non-32 dims
-    if weights["lm_head.scales"] == nil {
-        let vocabularySize = model.vocabularySize
-
-        func vocabularySizePredicate(layer: Module) -> Bool {
-            if let layer = layer as? Linear {
-                return layer.weight.dim(0) != 8 && layer.weight.dim(0) != vocabularySize
-            }
-            return false
-        }
-
-        predicate = vocabularySizePredicate(layer:)
-    }
-
-    QuantizedLinear.quantize(
-        model: model, groupSize: quantization.groupSize, bits: quantization.bits,
-        predicate: predicate)
-}
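Worth noting what was deleted here: both predicates inferred "leave this layer alone" indirectly from tensor shapes (the 8-row MoE gate layers, an lm_head whose row count matched the vocabulary). Those heuristics broke whenever a model shipped a different mix of quantized and float layers; the checkpoint-driven closure in the load hunk above replaces them with the ground truth recorded when the model was converted.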
@@ -377,7 +377,7 @@ public enum LoRATrain {
     /// - training with ``train(model:train:validate:optimizer:loss:tokenizer:parameters:progress:)``
     /// - loss evaluation with ``evaluate(model:dataset:loss:tokenizer:batchSize:batchCount:)``
     /// - fusing with ``fuse(model:layers:deQuantize:)``
-    /// - text generation with ``generate(promptTokens:parameters:model:tokenizer:didGenerate:)``
+    /// - text generation with ``generate(promptTokens:parameters:model:tokenizer:additionalEOSTokens:didGenerate:)``
     ///     - note that this is just using normal model text generation
     ///
     /// - Parameters:
@@ -33,6 +33,12 @@ public struct ModelConfiguration {
     /// overrides for TokenizerModel/knownTokenizers -- useful before swift-transformers is updated
     public let overrideTokenizer: String?
 
+    /// A reasonable default prompt for the model
+    public let defaultPrompt: String
+
+    /// Additional tokens to use for end of string
+    public let extraEOSTokens: Set<String>
+
     /// custom preparation logic for the prompt. custom tokenizers provide more capability, but this
     /// allows some minor formatting changes, e.g. wrapping the user input in the expected prompt
     /// format
@@ -40,21 +46,29 @@ public struct ModelConfiguration {
 
     public init(
         id: String, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
+        defaultPrompt: String = "hello",
+        extraEOSTokens: Set<String> = [],
         preparePrompt: ((String) -> String)? = nil
     ) {
         self.id = .id(id)
         self.tokenizerId = tokenizerId
         self.overrideTokenizer = overrideTokenizer
+        self.defaultPrompt = defaultPrompt
+        self.extraEOSTokens = extraEOSTokens
         self.preparePrompt = preparePrompt
     }
 
     public init(
         directory: URL, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
+        defaultPrompt: String = "hello",
+        extraEOSTokens: Set<String> = [],
         preparePrompt: ((String) -> String)? = nil
     ) {
         self.id = .directory(directory)
         self.tokenizerId = tokenizerId
         self.overrideTokenizer = overrideTokenizer
+        self.defaultPrompt = defaultPrompt
+        self.extraEOSTokens = extraEOSTokens
         self.preparePrompt = preparePrompt
     }
 
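Together the new fields let a configuration both smoke-test itself and stop at model-specific terminators. A hypothetical registration (repository id, prompt, and tokens are placeholders):

let myModel = ModelConfiguration(
    id: "my-org/my-model-4bit",             // hypothetical repository
    defaultPrompt: "why is the sky blue?",  // used when no prompt is given
    extraEOSTokens: ["<|end|>"]             // stops generation at this token
) { prompt in
    "<|user|>\n\(prompt)<|end|>\n<|assistant|>\n"
}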
@@ -98,11 +112,16 @@ public struct ModelConfiguration {
 extension ModelConfiguration {
 
     public static let mistral7B4bit = ModelConfiguration(
-        id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx")
+        id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx",
+
+        // https://www.promptingguide.ai/models/mistral-7b
+        defaultPrompt: "describe the swift language"
+    )
 
     public static let codeLlama13b4bit = ModelConfiguration(
         id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX",
-        overrideTokenizer: "PreTrainedTokenizer"
+        overrideTokenizer: "PreTrainedTokenizer",
+        defaultPrompt: "func sortArray(_ array: [Int]) -> String { <FILL_ME> }"
     ) { prompt in
         // given the prompt: func sortArray(_ array: [Int]) -> String { <FILL_ME> }
         // the python code produces this (via its custom tokenizer):
@@ -111,13 +130,17 @@ extension ModelConfiguration {
         "<PRE> " + prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") + " <MID>"
     }
 
-    public static let phi4bit = ModelConfiguration(id: "mlx-community/phi-2-hf-4bit-mlx") {
-        prompt in
-        "Instruct: \(prompt)\nOutput: "
-    }
+    public static let phi4bit = ModelConfiguration(
+        id: "mlx-community/phi-2-hf-4bit-mlx",
+
+        // https://www.promptingguide.ai/models/phi-2
+        defaultPrompt: "Why is the sky blue?"
+    )
 
     public static let phi34bit = ModelConfiguration(
-        id: "mlx-community/Phi-3-mini-4k-instruct-4bit-no-q-embed"
+        id: "mlx-community/Phi-3-mini-4k-instruct-4bit-no-q-embed",
+        defaultPrompt: "what is the gravity on mars and the moon?",
+        extraEOSTokens: ["<|end|>"]
     ) {
         prompt in
         "<s><|user|>\n\(prompt)<|end|>\n<|assistant|>\n"
@@ -125,26 +148,35 @@ extension ModelConfiguration {
 
     public static let gemma2bQuantized = ModelConfiguration(
         id: "mlx-community/quantized-gemma-2b-it",
-        overrideTokenizer: "PreTrainedTokenizer"
+        overrideTokenizer: "PreTrainedTokenizer",
+
+        // https://www.promptingguide.ai/models/gemma
+        defaultPrompt: "what is the difference between lettuce and cabbage?"
+
     ) { prompt in
         "<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
     }
 
     public static let qwen205b4bit = ModelConfiguration(
         id: "mlx-community/Qwen1.5-0.5B-Chat-4bit",
-        overrideTokenizer: "PreTrainedTokenizer"
+        overrideTokenizer: "PreTrainedTokenizer",
+        defaultPrompt: "why is the sky blue?"
     ) { prompt in
         "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n\(prompt)<|im_end|>\n<|im_start|>assistant"
     }
 
     public static let openelm270m4bit = ModelConfiguration(
-        id: "mlx-community/OpenELM-270M-Instruct"
+        id: "mlx-community/OpenELM-270M-Instruct",
+
+        // https://huggingface.co/apple/OpenELM
+        defaultPrompt: "Once upon a time there was"
     ) { prompt in
         "\(prompt)"
     }
 
     public static let llama38B4bit = ModelConfiguration(
-        id: "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
+        id: "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
+        defaultPrompt: "what is the difference between a fruit and a vegetable?"
     ) {
         prompt in
         "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\(prompt)<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"
@@ -179,10 +179,12 @@ public class Qwen2ModelInner: Module {
 public class Qwen2Model: Module, LLMModel {
     public let vocabularySize: Int
     let model: Qwen2ModelInner
+    let configuration: Qwen2Configuration
 
     @ModuleInfo(key: "lm_head") var lmHead: Linear
 
     public init(_ args: Qwen2Configuration) {
+        self.configuration = args
         self.vocabularySize = args.vocabularySize
         self.model = Qwen2ModelInner(args)
         _lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
@@ -191,8 +193,26 @@ public class Qwen2Model: Module, LLMModel {
     public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
         MLXArray, [(MLXArray, MLXArray)]
     ) {
-        let (out, cache) = model(inputs, cache: cache)
-        return (lmHead(out), cache)
+        var (out, cache) = model(inputs, cache: cache)
+        if configuration.tieWordEmbeddings {
+            out = model.embedTokens.asLinear(out)
+        } else {
+            out = lmHead(out)
+        }
+        return (out, cache)
     }
 
+    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
+        var weights = weights
+
+        if configuration.tieWordEmbeddings {
+            weights["lm_head.weight"] = nil
+        }
+
+        // Remove unused precomputed rotary freqs
+        return weights.filter {
+            !$0.key.contains("self_attn.rotary_emb.inv_freq")
+        }
+    }
 }
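Weight tying reuses the input embedding matrix as the output projection, which is why `sanitize` can drop `lm_head.weight`: with `tie_word_embeddings` set, the checkpoint ships no independent head. Conceptually (a sketch of the math, not the MLX implementation):

import MLX

// logits = hidden @ E^T, where E is the [vocabulary, hidden] embedding table;
// model.embedTokens.asLinear(out) applies this same reuse of E.
func tiedHead(_ hidden: MLXArray, _ embeddingWeight: MLXArray) -> MLXArray {
    matmul(hidden, embeddingWeight.T)
}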
@@ -207,6 +227,7 @@ public struct Qwen2Configuration: Codable {
     var ropeTheta: Float = 1_000_000
     var ropeTraditional: Bool = false
     var ropeScaling: [String: StringOrNumber]? = nil
+    var tieWordEmbeddings = false
 
     enum CodingKeys: String, CodingKey {
         case hiddenSize = "hidden_size"
@@ -219,6 +240,7 @@ public struct Qwen2Configuration: Codable {
         case ropeTheta = "rope_theta"
         case ropeTraditional = "rope_traditional"
         case ropeScaling = "rope_scaling"
+        case tieWordEmbeddings = "tie_word_embeddings"
     }
 
     public init(from decoder: Decoder) throws {
@@ -249,6 +271,8 @@ public struct Qwen2Configuration: Codable {
             Bool.self, forKey: Qwen2Configuration.CodingKeys.ropeTraditional) ?? false
         self.ropeScaling = try container.decodeIfPresent(
             [String: StringOrNumber].self, forKey: Qwen2Configuration.CodingKeys.ropeScaling)
+        self.tieWordEmbeddings =
+            try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false
     }
 }