handle partially quantized models (#76)

* handle partially quantized models

- fix for #53 #71 #69 #74
- in order to test the models
	- I added a default prompt of an appropriate form
	- while working on the model configuration also added additional stop tokens (#74)
- fixed the repetitionPenalty code (#71)
This commit is contained in:
David Koski
2024-05-28 16:35:11 -07:00
committed by GitHub
parent 65f4968e5f
commit 9d74afd119
12 changed files with 139 additions and 67 deletions

View File

@@ -12,7 +12,7 @@ private func topPSampling(logits: MLXArray, topP: Float, temp: Float) -> MLXArra
logits = logits.asType(.float32)
}
let probs = softMax(logits / temp, axis: -1)
let probs = softmax(logits / temp, axis: -1)
let sortedIndices = argSort(probs, axis: -1)
// probs shape is [B,V] and after take it will be [1, B, V], so we squeeze it back to [B, V]
@@ -31,7 +31,7 @@ private func applyRepetitionPenalty(
) -> MLXArray {
if repetitionContext.shape[0] > 0 {
let indices = repetitionContext
var selectedLogits = take(logits, indices, axis: -1).squeezed(axis: 0)
var selectedLogits = logits[0..., indices]
selectedLogits = MLX.where(
selectedLogits .< 0, selectedLogits * penalty, selectedLogits / penalty)
@@ -100,7 +100,7 @@ public struct TokenIterator: Sequence, IteratorProtocol {
if prompt.shape[0] <= parameters.repetitionContextSize {
self.repetitionContext = prompt
} else {
self.repetitionContext = prompt[-parameters.repetitionContextSize ... -1]
self.repetitionContext = prompt[(-parameters.repetitionContextSize)...]
}
} else {
self.repetitionContext = []
@@ -120,9 +120,8 @@ public struct TokenIterator: Sequence, IteratorProtocol {
y = sample(logits: logits, temp: parameters.temperature, topP: parameters.topP)
// append the current token to the context and check repetitionPenalty context see if need to remove the first token
if parameters.repetitionContextSize > 1 {
repetitionContext = concatenated([repetitionContext, y], axis: 0)
if repetitionContext.shape[0] > parameters.repetitionContextSize {
repetitionContext = repetitionContext[1...]
repetitionContext = repetitionContext[(-parameters.repetitionContextSize)...]
}
}
@@ -174,14 +173,31 @@ public enum GenerateDisposition {
/// - parameters: generation parameters
/// - model: model to evaluate
/// - tokenizer: tokenizer to convert tokens back into strings and recognizer special tokens
/// - configuration: the model configuration
/// - didGenerate: visitor for the tokens as they are generated
public func generate(
promptTokens: [Int], parameters: GenerateParameters, model: LLMModel, tokenizer: Tokenizer,
extraEOSTokens: Set<String>? = nil,
didGenerate: ([Int]) async -> GenerateDisposition
) async -> GenerateResult {
var start = Date.timeIntervalSinceReferenceDate
var promptTime: TimeInterval = 0
// build a set of additional stop tokens
let additionalEOSTokenIds = Set(
(extraEOSTokens ?? [])
.map {
tokenizer.encode(text: $0)
}
.filter {
// discard anything that is not a single token. sometimes
// the tokenizer will insert a <s> token, so accept that too
$0.count == 1 || ($0.count == 2 && $0[0] == 1)
}
.map {
$0.last!
})
var tokens = [Int]()
for token in TokenIterator(
@@ -196,7 +212,9 @@ public func generate(
}
let t = token.item(Int.self)
if t == tokenizer.unknownTokenId || t == tokenizer.eosTokenId {
if t == tokenizer.unknownTokenId || t == tokenizer.eosTokenId
|| additionalEOSTokenIds.contains(t)
{
break
}

View File

@@ -12,4 +12,15 @@ public protocol LLMModel: Module {
func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
MLXArray, [(MLXArray, MLXArray)]
)
/// Optionally preprocess the weights and modify / remove values as needed.
func sanitize(weights: [String: MLXArray]) -> [String: MLXArray]
}
extension LLMModel {
public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
weights
}
}

View File

@@ -194,6 +194,13 @@ public class LlamaModel: Module, LLMModel {
let (out, cache) = model(inputs, cache: cache)
return (lmHead(out), cache)
}
public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
// Remove unused precomputed rotary freqs
weights.filter {
!$0.key.contains("self_attn.rotary_emb.inv_freq")
}
}
}
public struct LlamaConfiguration: Codable {

View File

@@ -54,9 +54,15 @@ public func load(
}
}
// per-model cleanup
weights = model.sanitize(weights: weights)
// quantize if needed
if let quantization = baseConfig.quantization {
quantizeIfNeeded(model: model, weights: weights, quantization: quantization)
quantize(model: model, groupSize: quantization.groupSize, bits: quantization.bits) {
path, module in
weights["\(path).scales"] != nil
}
}
// apply the loaded weights
@@ -76,38 +82,3 @@ public func load(
hub: hub, configuration: newConfiguration, progressHandler: progressHandler)
}
}
// MARK: - Quantization
private func quantizeIfNeeded(
model: LLMModel, weights: [String: MLXArray], quantization: BaseConfiguration.Quantization
) {
func linearPredicate(layer: Module) -> Bool {
if let layer = layer as? Linear {
// avoid quantizing gate layers, otherwise we have to re-quant and upload all the mixtral models
return layer.weight.dim(0) != 8
}
return false
}
var predicate = linearPredicate(layer:)
// for legacy models that don't have lm_head quant due to non-32 dims
if weights["lm_head.scales"] == nil {
let vocabularySize = model.vocabularySize
func vocabularySizePredicate(layer: Module) -> Bool {
if let layer = layer as? Linear {
return layer.weight.dim(0) != 8 && layer.weight.dim(0) != vocabularySize
}
return false
}
predicate = vocabularySizePredicate(layer:)
}
QuantizedLinear.quantize(
model: model, groupSize: quantization.groupSize, bits: quantization.bits,
predicate: predicate)
}

View File

@@ -377,7 +377,7 @@ public enum LoRATrain {
/// - training with ``train(model:train:validate:optimizer:loss:tokenizer:parameters:progress:)``
/// - loss evaluation with ``evaluate(model:dataset:loss:tokenizer:batchSize:batchCount:)``
/// - fusing with ``fuse(model:layers:deQuantize:)``
/// - text generation with ``generate(promptTokens:parameters:model:tokenizer:didGenerate:)``
/// - text generation with ``generate(promptTokens:parameters:model:tokenizer:additionalEOSTokens:didGenerate:)``
/// - note that this is just using normal model text generation
///
/// - Parameters:

View File

@@ -33,6 +33,12 @@ public struct ModelConfiguration {
/// overrides for TokenizerModel/knownTokenizers -- useful before swift-transformers is updated
public let overrideTokenizer: String?
/// A reasonable default prompt for the model
public let defaultPrompt: String
/// Additional tokens to use for end of string
public let extraEOSTokens: Set<String>
/// custom preparation logic for the prompt. custom tokenizers provide more capability, but this
/// allows some minor formtting changes, e.g. wrapping the user input in the expected prompt
/// format
@@ -40,21 +46,29 @@ public struct ModelConfiguration {
public init(
id: String, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
defaultPrompt: String = "hello",
extraEOSTokens: Set<String> = [],
preparePrompt: ((String) -> String)? = nil
) {
self.id = .id(id)
self.tokenizerId = tokenizerId
self.overrideTokenizer = overrideTokenizer
self.defaultPrompt = defaultPrompt
self.extraEOSTokens = extraEOSTokens
self.preparePrompt = preparePrompt
}
public init(
directory: URL, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
defaultPrompt: String = "hello",
extraEOSTokens: Set<String> = [],
preparePrompt: ((String) -> String)? = nil
) {
self.id = .directory(directory)
self.tokenizerId = tokenizerId
self.overrideTokenizer = overrideTokenizer
self.defaultPrompt = defaultPrompt
self.extraEOSTokens = extraEOSTokens
self.preparePrompt = preparePrompt
}
@@ -98,11 +112,16 @@ public struct ModelConfiguration {
extension ModelConfiguration {
public static let mistral7B4bit = ModelConfiguration(
id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx")
id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx",
// https://www.promptingguide.ai/models/mistral-7b
defaultPrompt: "describe the swift language"
)
public static let codeLlama13b4bit = ModelConfiguration(
id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX",
overrideTokenizer: "PreTrainedTokenizer"
overrideTokenizer: "PreTrainedTokenizer",
defaultPrompt: "func sortArray(_ array: [Int]) -> String { <FILL_ME> }"
) { prompt in
// given the prompt: func sortArray(_ array: [Int]) -> String { <FILL_ME> }
// the python code produces this (via its custom tokenizer):
@@ -111,13 +130,17 @@ extension ModelConfiguration {
"<PRE> " + prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") + " <MID>"
}
public static let phi4bit = ModelConfiguration(id: "mlx-community/phi-2-hf-4bit-mlx") {
prompt in
"Instruct: \(prompt)\nOutput: "
}
public static let phi4bit = ModelConfiguration(
id: "mlx-community/phi-2-hf-4bit-mlx",
// https://www.promptingguide.ai/models/phi-2
defaultPrompt: "Why is the sky blue?"
)
public static let phi34bit = ModelConfiguration(
id: "mlx-community/Phi-3-mini-4k-instruct-4bit-no-q-embed"
id: "mlx-community/Phi-3-mini-4k-instruct-4bit-no-q-embed",
defaultPrompt: "what is the gravity on mars and the moon?",
extraEOSTokens: ["<|end|>"]
) {
prompt in
"<s><|user|>\n\(prompt)<|end|>\n<|assistant|>\n"
@@ -125,26 +148,35 @@ extension ModelConfiguration {
public static let gemma2bQuantized = ModelConfiguration(
id: "mlx-community/quantized-gemma-2b-it",
overrideTokenizer: "PreTrainedTokenizer"
overrideTokenizer: "PreTrainedTokenizer",
// https://www.promptingguide.ai/models/gemma
defaultPrompt: "what is the difference between lettuce and cabbage?"
) { prompt in
"<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
}
public static let qwen205b4bit = ModelConfiguration(
id: "mlx-community/Qwen1.5-0.5B-Chat-4bit",
overrideTokenizer: "PreTrainedTokenizer"
overrideTokenizer: "PreTrainedTokenizer",
defaultPrompt: "why is the sky blue?"
) { prompt in
"<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n\(prompt)<|im_end|>\n<|im_start|>assistant"
}
public static let openelm270m4bit = ModelConfiguration(
id: "mlx-community/OpenELM-270M-Instruct"
id: "mlx-community/OpenELM-270M-Instruct",
// https://huggingface.co/apple/OpenELM
defaultPrompt: "Once upon a time there was"
) { prompt in
"\(prompt)"
}
public static let llama38B4bit = ModelConfiguration(
id: "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
id: "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
defaultPrompt: "what is the difference between a fruit and a vegetable?"
) {
prompt in
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\(prompt)<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"

View File

@@ -179,10 +179,12 @@ public class Qwen2ModelInner: Module {
public class Qwen2Model: Module, LLMModel {
public let vocabularySize: Int
let model: Qwen2ModelInner
let configuration: Qwen2Configuration
@ModuleInfo(key: "lm_head") var lmHead: Linear
public init(_ args: Qwen2Configuration) {
self.configuration = args
self.vocabularySize = args.vocabularySize
self.model = Qwen2ModelInner(args)
_lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
@@ -191,8 +193,26 @@ public class Qwen2Model: Module, LLMModel {
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
MLXArray, [(MLXArray, MLXArray)]
) {
let (out, cache) = model(inputs, cache: cache)
return (lmHead(out), cache)
var (out, cache) = model(inputs, cache: cache)
if configuration.tieWordEmbeddings {
out = model.embedTokens.asLinear(out)
} else {
out = lmHead(out)
}
return (out, cache)
}
public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
var weights = weights
if configuration.tieWordEmbeddings {
weights["lm_head.weight"] = nil
}
// Remove unused precomputed rotary freqs
return weights.filter {
!$0.key.contains("self_attn.rotary_emb.inv_freq")
}
}
}
@@ -207,6 +227,7 @@ public struct Qwen2Configuration: Codable {
var ropeTheta: Float = 1_000_000
var ropeTraditional: Bool = false
var ropeScaling: [String: StringOrNumber]? = nil
var tieWordEmbeddings = false
enum CodingKeys: String, CodingKey {
case hiddenSize = "hidden_size"
@@ -219,6 +240,7 @@ public struct Qwen2Configuration: Codable {
case ropeTheta = "rope_theta"
case ropeTraditional = "rope_traditional"
case ropeScaling = "rope_scaling"
case tieWordEmbeddings = "tie_word_embeddings"
}
public init(from decoder: Decoder) throws {
@@ -249,6 +271,8 @@ public struct Qwen2Configuration: Codable {
Bool.self, forKey: Qwen2Configuration.CodingKeys.ropeTraditional) ?? false
self.ropeScaling = try container.decodeIfPresent(
[String: StringOrNumber].self, forKey: Qwen2Configuration.CodingKeys.ropeScaling)
self.tieWordEmbeddings =
try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false
}
}