handle partially quantized models (#76)

* handle partially quantized models

- fix for #53 #71 #69 #74
- in order to test the models:
	- I added a default prompt of an appropriate form
	- while working on the model configuration, also added additional stop tokens (#74)
- fixed the repetitionPenalty code (#71); see the sketch below
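As context for the repetitionPenalty fix (#71): the usual scheme scales down the logits of tokens that were already generated (positive logits are divided by the penalty, negative ones multiplied). The sketch below is plain Swift for illustration only; the function name and signature are invented here and it is not the repository's MLX-based implementation.

    func applyRepetitionPenalty(
        logits: [Float], previousTokens: [Int], penalty: Float
    ) -> [Float] {
        var logits = logits
        for token in Set(previousTokens) where token >= 0 && token < logits.count {
            // Tokens already generated become less likely to be sampled again.
            logits[token] = logits[token] > 0
                ? logits[token] / penalty
                : logits[token] * penalty
        }
        return logits
    }

    // Hypothetical usage: penalty > 1 discourages repetition, 1.0 is a no-op.
    let adjusted = applyRepetitionPenalty(
        logits: [2.0, -1.0, 0.5], previousTokens: [0, 2], penalty: 1.3)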
Author: David Koski
Date: 2024-05-28 16:35:11 -07:00
Committed by: GitHub
Parent: 65f4968e5f
Commit: 9d74afd119
12 changed files with 139 additions and 67 deletions


@@ -179,10 +179,12 @@ public class Qwen2ModelInner: Module {
 public class Qwen2Model: Module, LLMModel {
     public let vocabularySize: Int
     let model: Qwen2ModelInner
+    let configuration: Qwen2Configuration

     @ModuleInfo(key: "lm_head") var lmHead: Linear

     public init(_ args: Qwen2Configuration) {
+        self.configuration = args
         self.vocabularySize = args.vocabularySize
         self.model = Qwen2ModelInner(args)
         _lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
@@ -191,8 +193,26 @@ public class Qwen2Model: Module, LLMModel {
     public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
         MLXArray, [(MLXArray, MLXArray)]
     ) {
-        let (out, cache) = model(inputs, cache: cache)
-        return (lmHead(out), cache)
+        var (out, cache) = model(inputs, cache: cache)
+        if configuration.tieWordEmbeddings {
+            out = model.embedTokens.asLinear(out)
+        } else {
+            out = lmHead(out)
+        }
+        return (out, cache)
     }
+
+    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
+        var weights = weights
+
+        if configuration.tieWordEmbeddings {
+            weights["lm_head.weight"] = nil
+        }
+
+        // Remove unused precomputed rotary freqs
+        return weights.filter {
+            !$0.key.contains("self_attn.rotary_emb.inv_freq")
+        }
+    }
 }
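For readers new to tied word embeddings: when tie_word_embeddings is set, the checkpoint ships no separate lm_head.weight, so sanitize drops any such key and callAsFunction reuses the token-embedding matrix as the output projection (model.embedTokens.asLinear(out)) instead of lmHead. Below is a minimal plain-Swift sketch of that projection, logits = hidden x embedding^T; it is illustrative only, the TiedHead type is invented for the example, and it is not the repository's MLX code.

    struct TiedHead {
        // One row per vocabulary item; each row is that token's embedding vector.
        let embedding: [[Float]]   // [vocabularySize][hiddenSize]

        // The same matrix is used for input lookup and for the output projection,
        // so no separate lm_head weight needs to be stored in the checkpoint.
        func logits(hidden: [Float]) -> [Float] {
            embedding.map { row in
                zip(row, hidden).reduce(0) { $0 + $1.0 * $1.1 }
            }
        }
    }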
@@ -207,6 +227,7 @@ public struct Qwen2Configuration: Codable {
     var ropeTheta: Float = 1_000_000
     var ropeTraditional: Bool = false
     var ropeScaling: [String: StringOrNumber]? = nil
+    var tieWordEmbeddings = false

     enum CodingKeys: String, CodingKey {
         case hiddenSize = "hidden_size"
@@ -219,6 +240,7 @@ public struct Qwen2Configuration: Codable {
         case ropeTheta = "rope_theta"
         case ropeTraditional = "rope_traditional"
         case ropeScaling = "rope_scaling"
+        case tieWordEmbeddings = "tie_word_embeddings"
     }

     public init(from decoder: Decoder) throws {
@@ -249,6 +271,8 @@ public struct Qwen2Configuration: Codable {
             Bool.self, forKey: Qwen2Configuration.CodingKeys.ropeTraditional) ?? false
         self.ropeScaling = try container.decodeIfPresent(
             [String: StringOrNumber].self, forKey: Qwen2Configuration.CodingKeys.ropeScaling)
+        self.tieWordEmbeddings =
+            try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false
     }
 }
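The decodeIfPresent(...) ?? false pattern above means older config.json files that lack tie_word_embeddings keep the previous, untied behavior. A self-contained sketch of that default; the TinyConfig struct and the JSON strings are invented for this example, not the repository's types:

    import Foundation

    struct TinyConfig: Codable {
        var tieWordEmbeddings: Bool

        enum CodingKeys: String, CodingKey {
            case tieWordEmbeddings = "tie_word_embeddings"
        }

        init(from decoder: Decoder) throws {
            let container = try decoder.container(keyedBy: CodingKeys.self)
            // Same pattern as above: a missing key falls back to false (untied).
            tieWordEmbeddings =
                try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false
        }
    }

    let tied = #"{"tie_word_embeddings": true}"#.data(using: .utf8)!
    let legacy = #"{}"#.data(using: .utf8)!
    print(try! JSONDecoder().decode(TinyConfig.self, from: tied).tieWordEmbeddings)   // true
    print(try! JSONDecoder().decode(TinyConfig.self, from: legacy).tieWordEmbeddings) // false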