handle partially quantized models (#76)
* handle partially quantized models - fixes #53, #69, #71, #74
* added a default prompt of an appropriate form so the models can be tested
* while working on the model configuration, also added additional stop tokens (#74)
* fixed the repetitionPenalty code (#71)
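The repetitionPenalty fix itself is not shown in the diff below. As a minimal sketch of the usual rule, in plain Swift over raw logits (the function name and formulation here are illustrative, not the repo's code): previously generated tokens are pushed toward lower probability, dividing positive logits and multiplying negative ones so that a penalty greater than 1 never increases a token's likelihood.

    // Illustrative sketch only; applies a CTRL-style repetition penalty.
    func applyRepetitionPenalty(_ logits: inout [Float], previousTokens: Set<Int>, penalty: Float) {
        for token in previousTokens {
            let value = logits[token]
            // Divide positive logits, multiply negative ones, so the adjusted
            // logit always moves toward "less likely" when penalty > 1.
            logits[token] = value < 0 ? value * penalty : value / penalty
        }
    }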
@@ -179,10 +179,12 @@ public class Qwen2ModelInner: Module {
 public class Qwen2Model: Module, LLMModel {
     public let vocabularySize: Int
     let model: Qwen2ModelInner
+    let configuration: Qwen2Configuration

     @ModuleInfo(key: "lm_head") var lmHead: Linear

     public init(_ args: Qwen2Configuration) {
+        self.configuration = args
         self.vocabularySize = args.vocabularySize
         self.model = Qwen2ModelInner(args)
         _lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
@@ -191,8 +193,26 @@ public class Qwen2Model: Module, LLMModel {
     public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
         MLXArray, [(MLXArray, MLXArray)]
     ) {
-        let (out, cache) = model(inputs, cache: cache)
-        return (lmHead(out), cache)
+        var (out, cache) = model(inputs, cache: cache)
+        if configuration.tieWordEmbeddings {
+            out = model.embedTokens.asLinear(out)
+        } else {
+            out = lmHead(out)
+        }
+        return (out, cache)
     }

+    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
+        var weights = weights
+
+        if configuration.tieWordEmbeddings {
+            weights["lm_head.weight"] = nil
+        }
+
+        // Remove unused precomputed rotary freqs
+        return weights.filter {
+            !$0.key.contains("self_attn.rotary_emb.inv_freq")
+        }
+    }
+
 }
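For context on the tieWordEmbeddings branch above: model.embedTokens.asLinear(out) reuses the input embedding matrix as the output projection, which is why sanitize can drop lm_head.weight when embeddings are tied. A minimal sketch of the equivalent computation, assuming the usual [vocabSize, hiddenSize] embedding layout and MLX's matmul (illustrative, not the MLXNN implementation):

    import MLX

    // hidden: [batch, seq, hiddenSize]; embeddingWeight: [vocabSize, hiddenSize].
    // Multiplying by the transposed embedding matrix yields [batch, seq, vocabSize]
    // logits, so no separate lm_head weight is needed.
    func tiedLogits(_ hidden: MLXArray, _ embeddingWeight: MLXArray) -> MLXArray {
        matmul(hidden, embeddingWeight.T)
    }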
@@ -207,6 +227,7 @@ public struct Qwen2Configuration: Codable {
     var ropeTheta: Float = 1_000_000
     var ropeTraditional: Bool = false
     var ropeScaling: [String: StringOrNumber]? = nil
+    var tieWordEmbeddings = false

     enum CodingKeys: String, CodingKey {
         case hiddenSize = "hidden_size"
@@ -219,6 +240,7 @@ public struct Qwen2Configuration: Codable {
         case ropeTheta = "rope_theta"
         case ropeTraditional = "rope_traditional"
         case ropeScaling = "rope_scaling"
+        case tieWordEmbeddings = "tie_word_embeddings"
     }

     public init(from decoder: Decoder) throws {
@@ -249,6 +271,8 @@ public struct Qwen2Configuration: Codable {
             Bool.self, forKey: Qwen2Configuration.CodingKeys.ropeTraditional) ?? false
         self.ropeScaling = try container.decodeIfPresent(
             [String: StringOrNumber].self, forKey: Qwen2Configuration.CodingKeys.ropeScaling)
+        self.tieWordEmbeddings =
+            try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false
     }
 }
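The decodeIfPresent(...) ?? false pattern in the last hunk is what lets older config.json files, which omit the key entirely, keep loading with tying disabled. A self-contained sketch using a cut-down stand-in struct (TieFlag is illustrative, not part of the repo):

    import Foundation

    struct TieFlag: Codable {
        var tieWordEmbeddings: Bool

        enum CodingKeys: String, CodingKey {
            case tieWordEmbeddings = "tie_word_embeddings"
        }

        init(from decoder: Decoder) throws {
            let container = try decoder.container(keyedBy: CodingKeys.self)
            // A missing key (older checkpoints) decodes to false instead of throwing.
            self.tieWordEmbeddings =
                try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false
        }
    }

    let tied = try JSONDecoder().decode(
        TieFlag.self, from: #"{ "tie_word_embeddings": true }"#.data(using: .utf8)!)
    // tied.tieWordEmbeddings == true; decoding "{}" would yield false.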