adopt MLXFast.scaledDotProductAttention (#23)

David Koski
2024-03-12 14:04:43 -07:00
committed by GitHub
parent a94bf79d7e
commit 0fb74cbfdc
7 changed files with 39 additions and 82 deletions

@@ -2,6 +2,7 @@
 
 import Foundation
 import MLX
+import MLXFast
 import MLXNN
 
 // https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/phi.py
@@ -17,7 +18,6 @@ private class PhiAttention: Module {
     let args: PhiConfiguration
     let heads: Int
     let headDim: Int
-    let repeats: Int
 
     @ModuleInfo(key: "q_proj") var wq: Linear
     @ModuleInfo(key: "k_proj") var wk: Linear
@@ -33,7 +33,6 @@ private class PhiAttention: Module {
         self.heads = args.attentionHeads
         self.headDim = args.hiddenSize / heads
         let kvHeads = args.kvHeads
-        self.repeats = heads / kvHeads
 
         if headDim * heads != hiddenSize {
             fatalError("hidden_size must be divisible by num_heads")
@@ -63,11 +62,6 @@
         keys = keys.reshaped(B, L, args.kvHeads, headDim).transposed(0, 2, 1, 3)
         values = values.reshaped(B, L, args.kvHeads, headDim).transposed(0, 2, 1, 3)
 
-        if repeats > 1 {
-            keys = MLXArray.repeat(keys, count: repeats, axis: 1)
-            values = MLXArray.repeat(values, count: repeats, axis: 1)
-        }
-
         // Add RoPE to the queries and keys and combine them with the cache
         if let (keyCache, valueCache) = cache {
             queries = rope(queries, offset: keyCache.dim(2))
@@ -84,15 +78,13 @@
         // Finally perform the attention computation
         let scale = sqrt(1 / Float(queries.dim(-1)))
-        var scores = (queries * scale).matmul(keys.transposed(0, 1, 3, 2))
-        if let mask {
-            scores = scores + mask
-        }
-        scores = softMax(scores, axis: -1).asType(values.dtype)
-        let valuesHat = matmul(scores, values).transposed(0, 2, 1, 3).reshaped(B, L, -1)
+        let output = MLXFast.scaledDotProductAttention(
+            queries: queries, keys: keys, values: values, scale: scale, mask: mask
+        )
+        .transposed(0, 2, 1, 3)
+        .reshaped(B, L, -1)
 
-        return (dense(valuesHat), (keys, values))
+        return (dense(output), (keys, values))
     }
 }
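
For context, a minimal standalone sketch of the trade this diff makes, built only from API calls that appear in the diff itself; the function names manualAttention and fusedAttention are illustrative, not part of the commit. The fused MLXFast.scaledDotProductAttention call computes the same masked, scaled softmax attention that the removed lines assembled from separate ops, and it accepts keys and values with fewer heads than the queries, which is why the MLXArray.repeat expansion for grouped-query attention can be deleted as well.

import Foundation
import MLX
import MLXFast

// Sketch (not part of the commit) contrasting the removed manual path
// with the fused call that replaces it. Shapes follow the diff:
// queries is [B, heads, L, headDim]; keys/values are
// [B, kvHeads, L, headDim]; mask is an optional additive mask.
func manualAttention(
    queries: MLXArray, keys: MLXArray, values: MLXArray,
    mask: MLXArray?, repeats: Int
) -> MLXArray {
    var keys = keys
    var values = values
    // Grouped-query attention: tile K/V heads up to the query head count.
    if repeats > 1 {
        keys = MLXArray.repeat(keys, count: repeats, axis: 1)
        values = MLXArray.repeat(values, count: repeats, axis: 1)
    }
    let scale = sqrt(1 / Float(queries.dim(-1)))  // 1 / sqrt(headDim)
    var scores = (queries * scale).matmul(keys.transposed(0, 1, 3, 2))
    if let mask {
        scores = scores + mask  // e.g. a causal mask added to the logits
    }
    scores = softMax(scores, axis: -1).asType(values.dtype)
    return matmul(scores, values)  // [B, heads, L, headDim]
}

func fusedAttention(
    queries: MLXArray, keys: MLXArray, values: MLXArray, mask: MLXArray?
) -> MLXArray {
    let scale = sqrt(1 / Float(queries.dim(-1)))
    // Scaling, masking, softmax, and both matmuls in one fused call.
    // keys/values may keep kvHeads < heads; no MLXArray.repeat needed.
    return MLXFast.scaledDotProductAttention(
        queries: queries, keys: keys, values: values, scale: scale, mask: mask)
}

Beyond the net line-count savings across the seven changed model files, folding the scale, mask, softMax, and two matmuls into a single call means the [B, heads, L, L] score array no longer has to round-trip through separate ops, and the K/V tensors are no longer physically repeated per query head.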