adopt MLXFast.scaledDotProductAttention (#23)

Author: David Koski
Date: 2024-03-12 14:04:43 -07:00
Committed by: GitHub
Parent: a94bf79d7e
Commit: 0fb74cbfdc
7 changed files with 39 additions and 82 deletions

@@ -2,6 +2,7 @@
 import Foundation
 import MLX
+import MLXFast
 import MLXNN
 // port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/gemma.py
@@ -31,7 +32,6 @@ private class RMSNorm: Module, UnaryLayer {
 private class Attention: Module {
     let args: GemmaConfiguration
-    let repeats: Int
     let scale: Float
     @ModuleInfo(key: "q_proj") var wq: Linear
@@ -48,8 +48,6 @@ private class Attention: Module {
         let heads = args.attentionHeads
         let kvHeads = args.kvHeads
-        self.repeats = heads / kvHeads
         let headDim = args.headDimensions
         self.scale = pow(Float(headDim), -0.5)
@@ -76,11 +74,6 @@ private class Attention: Module {
         keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
         values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
-        if repeats > 1 {
-            keys = MLXArray.repeat(keys, count: repeats, axis: 1)
-            values = MLXArray.repeat(values, count: repeats, axis: 1)
-        }
         if let (keyCache, valueCache) = cache {
             queries = rope(queries, offset: keyCache.dim(2))
             keys = rope(keys, offset: keyCache.dim(2))
@@ -91,14 +84,11 @@ private class Attention: Module {
             keys = rope(keys)
         }
-        var scores = (queries * self.scale).matmul(keys.transposed(0, 1, 3, 2))
-        if let mask {
-            scores = scores + mask
-        }
-        scores = softMax(scores.asType(.float32), axis: -1).asType(scores.dtype)
-        let output = matmul(scores, values).transposed(0, 2, 1, 3).reshaped(B, L, -1)
+        let output = MLXFast.scaledDotProductAttention(
+            queries: queries, keys: keys, values: values, scale: scale, mask: mask
+        )
+        .transposed(0, 2, 1, 3)
+        .reshaped(B, L, -1)
         return (wo(output), (keys, values))
     }
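
For context (not part of the commit), a minimal standalone sketch of why this is a drop-in swap: the fused MLXFast.scaledDotProductAttention call, with the signature taken from the diff above, computes the same result as the manual scores/softmax/matmul path it removes, and it also accepts keys/values with fewer heads than the queries, which is why the manual `repeats` expansion could be dropped. The shapes, random inputs, and tolerance below are illustrative assumptions; MLXRandom.normal and allClose are assumed only for generating test inputs and comparing outputs, and the sketch uses equal query and key/value head counts so the manual path needs no repetition.

import Foundation
import MLX
import MLXFast
import MLXNN
import MLXRandom

// Illustrative shapes: batch 1, 8 heads, sequence length 4, head dimension 16.
let (B, H, L, D) = (1, 8, 4, 16)
let queries = MLXRandom.normal([B, H, L, D])
let keys = MLXRandom.normal([B, H, L, D])
let values = MLXRandom.normal([B, H, L, D])
let scale = pow(Float(D), -0.5)
let mask: MLXArray? = nil

// Fused path (adopted by this commit).
let fused = MLXFast.scaledDotProductAttention(
    queries: queries, keys: keys, values: values, scale: scale, mask: mask)

// Manual path (removed by this commit).
var scores = (queries * scale).matmul(keys.transposed(0, 1, 3, 2))
if let mask {
    scores = scores + mask
}
scores = softMax(scores.asType(.float32), axis: -1).asType(scores.dtype)
let manual = matmul(scores, values)

// Both outputs are [B, H, L, D]; the fused kernel avoids materializing `scores`.
print(allClose(fused, manual, atol: 1e-4).item(Bool.self))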