Add Gemma 2 (#88)

Author: Anthony DePasquale
Date: 2024-07-01 18:35:43 +02:00
Committed by: GitHub
parent 7957378077
commit 0c08f3a7e4
3 changed files with 125 additions and 0 deletions

@@ -262,3 +262,111 @@ extension GemmaModel: LoRAModel {
        model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
    }
}
// Gemma 2
// Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/gemma2.py

// Minimal changes from the Gemma TransformerBlock (see the comparison note after this class).
private class Gemma2TransformerBlock: Module {

    @ModuleInfo(key: "self_attn") var attention: Attention
    let mlp: MLP

    @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
    @ModuleInfo(key: "pre_feedforward_layernorm") var preFeedforwardLayerNorm: RMSNorm
    @ModuleInfo(key: "post_feedforward_layernorm") var postFeedforwardLayerNorm: RMSNorm
    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm

    public init(_ args: GemmaConfiguration) {
        self._attention.wrappedValue = Attention(args)
        self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
        self._inputLayerNorm.wrappedValue = RMSNorm(
            dimensions: args.hiddenSize, eps: args.rmsNormEps)
        self._preFeedforwardLayerNorm.wrappedValue = RMSNorm(
            dimensions: args.hiddenSize, eps: args.rmsNormEps)
        self._postFeedforwardLayerNorm.wrappedValue = RMSNorm(
            dimensions: args.hiddenSize, eps: args.rmsNormEps)
        self._postAttentionLayerNorm.wrappedValue = RMSNorm(
            dimensions: args.hiddenSize, eps: args.rmsNormEps)
    }

    public func callAsFunction(
        _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
    ) -> (MLXArray, (MLXArray, MLXArray)) {
        // Gemma 2 normalizes the attention and MLP outputs before the residual adds.
        var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
        let h = x + postAttentionLayerNorm(r)
        r = mlp(preFeedforwardLayerNorm(h))
        let out = h + postFeedforwardLayerNorm(r)
        return (out, cache)
    }
}
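// For comparison, the Gemma block is assumed here to add the raw attention and MLP outputs to
// the residual stream, with a single post-attention norm:
//
//   Gemma:   h = x + attn(inputNorm(x));                  out = h + mlp(postAttnNorm(h))
//   Gemma 2: h = x + postAttnNorm(attn(inputNorm(x)));    out = h + postFFNorm(mlp(preFFNorm(h)))
//
// i.e. Gemma 2 wraps both sub-layer outputs in extra RMSNorms ("sandwich" normalization).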
// Uses Gemma2TransformerBlock, otherwise same as GemmaModelInner
public class Gemma2ModelInner: Module {

    @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding

    fileprivate let layers: [Gemma2TransformerBlock]
    fileprivate let norm: RMSNorm

    let hiddenScale: Float

    public init(_ args: GemmaConfiguration) {
        precondition(args.vocabularySize > 0)

        self._embedTokens.wrappedValue = Embedding(
            embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)

        // Gemma scales the token embeddings by sqrt(hiddenSize).
        self.hiddenScale = pow(Float(args.hiddenSize), 0.5)

        self.layers = (0 ..< args.hiddenLayers)
            .map { _ in
                Gemma2TransformerBlock(args)
            }
        self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
    }

    public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
        MLXArray, [(MLXArray, MLXArray)]
    ) {
        var h = embedTokens(inputs)
        h = h * hiddenScale

        // The additive causal mask is only needed when more than one token is processed at
        // once (the prompt pass); cached single-token decoding passes no mask.
        var mask: MLXArray? = nil
        if h.dim(1) > 1 {
            mask = MultiHeadAttention.createAdditiveCausalMask(h.dim(1))
            mask = mask?.asType(h.dtype)
        }

        var newCache = [(MLXArray, MLXArray)]()

        for (i, layer) in layers.enumerated() {
            var cacheUpdate: (MLXArray, MLXArray)
            (h, cacheUpdate) = layer(h, mask: mask, cache: cache?[i])
            newCache.append(cacheUpdate)
        }

        return (norm(h), newCache)
    }
}
// Uses Gemma2ModelInner, otherwise same as GemmaModel (see the usage sketch after this class).
public class Gemma2Model: Module, LLMModel {

    public let vocabularySize: Int
    let model: Gemma2ModelInner

    public init(_ args: GemmaConfiguration) {
        self.vocabularySize = args.vocabularySize
        self.model = Gemma2ModelInner(args)
    }

    public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
        MLXArray, [(MLXArray, MLXArray)]
    ) {
        var (out, cache) = model(inputs, cache: cache)
        // Tied weights: project hidden states back to vocabulary logits with the embedding matrix.
        out = model.embedTokens.asLinear(out)
        return (out, cache)
    }
}
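A usage sketch, not part of this commit: only the callAsFunction signature above is assumed, and
config, promptTokens, and nextToken are hypothetical placeholders (a decoded GemmaConfiguration
and MLXArrays of shape [1, T] and [1, 1]). The per-layer (keys, values) cache returned by the
prompt pass is threaded back into the next call, so each decode step only feeds the newest token.

let model = Gemma2Model(config)                         // config: hypothetical GemmaConfiguration
var (logits, cache) = model(promptTokens, cache: nil)   // prompt pass over all tokens at once
(logits, cache) = model(nextToken, cache: cache)        // cached single-token decode step

If Gemma 2 should be LoRA-tunable like GemmaModel at the top of this hunk, an analogous
conformance could map the same attention projections; the requirement name below is assumed to
match that extension rather than taken from this commit.

extension Gemma2Model: LoRAModel {
    public func loraLinearLayers() -> LoRALinearLayers {
        model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
    }
}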