Add Gemma 2 (#88)

commit 0c08f3a7e4 (parent 7957378077)
@@ -32,6 +32,7 @@ public enum ModelType: String, Codable {
     case phi
     case phi3
     case gemma
+    case gemma2
     case qwen2
     case starcoder2
     case cohere
@@ -55,6 +56,10 @@ public enum ModelType: String, Codable {
             let configuration = try JSONDecoder().decode(
                 GemmaConfiguration.self, from: Data(contentsOf: configuration))
             return GemmaModel(configuration)
+        case .gemma2:
+            let configuration = try JSONDecoder().decode(
+                GemmaConfiguration.self, from: Data(contentsOf: configuration))
+            return Gemma2Model(configuration)
         case .qwen2:
             let configuration = try JSONDecoder().decode(
                 Qwen2Configuration.self, from: Data(contentsOf: configuration))
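Because ModelType is declared as String, Codable, the new case means a config.json whose "model_type" is "gemma2" now decodes and is routed to Gemma2Model by the switch above. A minimal sketch of that decoding; ProbeConfig below is an illustrative stand-in, not a type from the library:

import Foundation

// Illustrative probe, not library code: confirms the raw value "gemma2" decodes into ModelType.
struct ProbeConfig: Codable {
    let modelType: ModelType
    enum CodingKeys: String, CodingKey { case modelType = "model_type" }
}

let probe = try? JSONDecoder().decode(
    ProbeConfig.self, from: Data(#"{"model_type": "gemma2"}"#.utf8))
// probe?.modelType == .gemma2, so the factory switch above reaches the new case .gemma2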
@@ -262,3 +262,111 @@ extension GemmaModel: LoRAModel {
         model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
     }
 }
+
+// Gemma 2
+
+// Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/gemma2.py
+
+// Minimal changes from Gemma TransformerBlock
+private class Gemma2TransformerBlock: Module {
+
+    @ModuleInfo(key: "self_attn") var attention: Attention
+    let mlp: MLP
+
+    @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
+    @ModuleInfo(key: "pre_feedforward_layernorm") var preFeedforwardLayerNorm: RMSNorm
+    @ModuleInfo(key: "post_feedforward_layernorm") var postFeedforwardLayerNorm: RMSNorm
+    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
+
+    public init(_ args: GemmaConfiguration) {
+        self._attention.wrappedValue = Attention(args)
+        self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
+        self._inputLayerNorm.wrappedValue = RMSNorm(
+            dimensions: args.hiddenSize, eps: args.rmsNormEps)
+        self._preFeedforwardLayerNorm.wrappedValue = RMSNorm(
+            dimensions: args.hiddenSize, eps: args.rmsNormEps)
+        self._postFeedforwardLayerNorm.wrappedValue = RMSNorm(
+            dimensions: args.hiddenSize, eps: args.rmsNormEps)
+        self._postAttentionLayerNorm.wrappedValue = RMSNorm(
+            dimensions: args.hiddenSize, eps: args.rmsNormEps)
+    }
+
+    public func callAsFunction(
+        _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
+    ) -> (MLXArray, (MLXArray, MLXArray)) {
+        var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
+        let h = x + postAttentionLayerNorm(r)
+        r = mlp(preFeedforwardLayerNorm(h))
+        let out = h + postFeedforwardLayerNorm(r)
+        return (out, cache)
+    }
+}
+
+// Uses Gemma2TransformerBlock, otherwise same as GemmaModelInner
+public class Gemma2ModelInner: Module {
+
+    @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
+
+    fileprivate let layers: [Gemma2TransformerBlock]
+    fileprivate let norm: RMSNorm
+
+    let hiddenScale: Float
+
+    public init(_ args: GemmaConfiguration) {
+        precondition(args.vocabularySize > 0)
+
+        self._embedTokens.wrappedValue = Embedding(
+            embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
+
+        self.hiddenScale = pow(Float(args.hiddenSize), 0.5)
+
+        self.layers = (0 ..< args.hiddenLayers)
+            .map { _ in
+                Gemma2TransformerBlock(args)
+            }
+        self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
+    }
+
+    public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
+        MLXArray, [(MLXArray, MLXArray)]
+    ) {
+        var h = embedTokens(inputs)
+        h = h * hiddenScale
+
+        var mask: MLXArray? = nil
+        if h.dim(1) > 1 {
+            mask = MultiHeadAttention.createAdditiveCausalMask(h.dim(1))
+            mask = mask?.asType(h.dtype)
+        }
+
+        var newCache = [(MLXArray, MLXArray)]()
+
+        for (i, layer) in layers.enumerated() {
+            var cacheUpdate: (MLXArray, MLXArray)
+            (h, cacheUpdate) = layer(h, mask: mask, cache: cache?[i])
+            newCache.append(cacheUpdate)
+        }
+
+        return (norm(h), newCache)
+    }
+}
+
+// Uses Gemma2ModelInner, otherwise same as GemmaModel
+public class Gemma2Model: Module, LLMModel {
+
+    public let vocabularySize: Int
+    let model: Gemma2ModelInner
+
+    public init(_ args: GemmaConfiguration) {
+        self.vocabularySize = args.vocabularySize
+        self.model = Gemma2ModelInner(args)
+    }
+
+    public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
+        MLXArray, [(MLXArray, MLXArray)]
+    ) {
+        var (out, cache) = model(inputs, cache: cache)
+        out = model.embedTokens.asLinear(out)
+        return (out, cache)
+    }
+}
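Relative to the original Gemma block, the Gemma 2 block adds post_attention_layernorm and post_feedforward_layernorm, so each sublayer output is normalized before it is added back to the residual stream; the rest of the model mirrors GemmaModelInner and GemmaModel. For reference, a minimal sketch of driving the new Gemma2Model directly, outside the ModelType factory above; the config URL, the token ids, and the MLXArray initializer used here are illustrative assumptions, and weight loading is omitted:

import Foundation
import MLX

// Sketch only, assuming it sits alongside GemmaConfiguration / Gemma2Model in the LLM library.
func gemma2ForwardSketch(configURL: URL) throws {
    // Gemma 2 reuses GemmaConfiguration, exactly as the factory switch above does.
    let config = try JSONDecoder().decode(
        GemmaConfiguration.self, from: Data(contentsOf: configURL))
    let model = Gemma2Model(config)  // parameters are randomly initialized here; real use loads weights first

    // A [batch, tokens] array of token ids; real ids come from the tokenizer.
    let tokens = MLXArray([2, 106, 1645, 108] as [Int32], [1, 4])

    // First call with no cache; the returned per-layer (key, value) cache can be
    // passed back in for incremental decoding.
    let (logits, cache) = model(tokens, cache: nil)
    print(logits.shape, cache.count)
}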
@@ -157,6 +157,17 @@ extension ModelConfiguration {
         "<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
     }
+
+    public static let gemma_2_9b_it_4bit = ModelConfiguration(
+        id: "mlx-community/gemma-2-9b-it-4bit",
+        overrideTokenizer: "PreTrainedTokenizer",
+
+        // https://www.promptingguide.ai/models/gemma
+        defaultPrompt: "What is the difference between lettuce and cabbage?"
+
+    ) { prompt in
+        "<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
+    }
 
     public static let qwen205b4bit = ModelConfiguration(
         id: "mlx-community/Qwen1.5-0.5B-Chat-4bit",
         overrideTokenizer: "PreTrainedTokenizer",
@@ -200,6 +211,7 @@ extension ModelConfiguration {
         phi4bit,
         phi34bit,
         gemma2bQuantized,
+        gemma_2_9b_it_4bit,
         qwen205b4bit,
         openelm270m4bit,
     ])
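The new gemma_2_9b_it_4bit entry is registered alongside the other built-in configurations, so the model is selectable by its Hub id mlx-community/gemma-2-9b-it-4bit and uses Gemma's chat-turn prompt format. A sketch of what its prompt closure produces; the helper name below is hypothetical and simply mirrors the trailing closure in the configuration above:

// Hypothetical stand-alone mirror of the gemma_2_9b_it_4bit prompt closure; not part of the commit.
func gemmaChatPrompt(_ prompt: String) -> String {
    "<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
}

// gemmaChatPrompt("What is the difference between lettuce and cabbage?")
// -> "<start_of_turn>user What is the difference between lettuce and cabbage?<end_of_turn><start_of_turn>model"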