adopt MLXFast.scaledDotProductAttention (#23)

David Koski
2024-03-12 14:04:43 -07:00
committed by GitHub
parent a94bf79d7e
commit 0fb74cbfdc
7 changed files with 39 additions and 82 deletions
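
Every model file below gets the same refactor: the hand-rolled attention (explicit key/value repetition for grouped-query attention, a scaled matmul against the transposed keys, an additive mask, softmax, and a second matmul) is collapsed into a single call to MLXFast.scaledDotProductAttention. The sketch below shows the shape of that change only, assuming queries/keys/values have already been projected, reshaped to (B, heads, L, headDim) / (B, kvHeads, L, headDim), and run through RoPE; the attentionOutput function and its parameter list are illustrative, not code from this repository.

import MLX
import MLXFast

// Sketch only: illustrates the pattern this commit applies in each Attention module.
// Assumes q/k/v are already (B, heads, L, headDim) / (B, kvHeads, L, headDim).
func attentionOutput(
    queries: MLXArray, keys: MLXArray, values: MLXArray,
    mask: MLXArray?, scale: Float, B: Int, L: Int
) -> MLXArray {
    // Before: manual attention with explicit GQA key/value repetition
    // (repeats = heads / kvHeads):
    //   keys = MLXArray.repeat(keys, count: repeats, axis: 1)
    //   values = MLXArray.repeat(values, count: repeats, axis: 1)
    //   var scores = (queries * scale).matmul(keys.transposed(0, 1, 3, 2))
    //   if let mask { scores = scores + mask }
    //   scores = softMax(scores.asType(.float32), axis: -1).asType(scores.dtype)
    //   let output = matmul(scores, values).transposed(0, 2, 1, 3).reshaped(B, L, -1)

    // After: one fused kernel call, then the same transpose/reshape back to (B, L, hidden).
    return MLXFast.scaledDotProductAttention(
        queries: queries, keys: keys, values: values, scale: scale, mask: mask
    )
    .transposed(0, 2, 1, 3)
    .reshaped(B, L, -1)
}

Because the fused call handles the grouped-query case without manual repetition, the repeats property and the MLXArray.repeat bookkeeping are deleted from each Attention module, which accounts for most of the 82 removed lines.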

View File

@@ -43,12 +43,17 @@ workflows:
     when:
       and:
         - matches:
-            pattern: "^(?!pull/)[-\\w]+$"
+            pattern: "^(?!pull/)[-_./\\w]+$"
             value: << pipeline.git.branch >>
         - not: << pipeline.parameters.nightly_build >>
         - not: << pipeline.parameters.weekly_build >>
     jobs:
-      - mac_build_and_test
+      - hold:
+          type: approval
+      - apple/authenticate:
+          context: pr-approval
+      - mac_build_and_test:
+          requires: [ hold ]
   prb:
     when:

View File

@@ -2,6 +2,7 @@
 import Foundation
 import MLX
+import MLXFast
 import MLXNN

 // port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/gemma.py
@@ -31,7 +32,6 @@ private class RMSNorm: Module, UnaryLayer {
 private class Attention: Module {

     let args: GemmaConfiguration
-    let repeats: Int
     let scale: Float

     @ModuleInfo(key: "q_proj") var wq: Linear
@@ -48,8 +48,6 @@ private class Attention: Module {
         let heads = args.attentionHeads
         let kvHeads = args.kvHeads

-        self.repeats = heads / kvHeads
-
         let headDim = args.headDimensions
         self.scale = pow(Float(headDim), -0.5)
@@ -76,11 +74,6 @@ private class Attention: Module {
         keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
         values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)

-        if repeats > 1 {
-            keys = MLXArray.repeat(keys, count: repeats, axis: 1)
-            values = MLXArray.repeat(values, count: repeats, axis: 1)
-        }
-
         if let (keyCache, valueCache) = cache {
             queries = rope(queries, offset: keyCache.dim(2))
             keys = rope(keys, offset: keyCache.dim(2))
@@ -91,14 +84,11 @@ private class Attention: Module {
             keys = rope(keys)
         }

-        var scores = (queries * self.scale).matmul(keys.transposed(0, 1, 3, 2))
-        if let mask {
-            scores = scores + mask
-        }
-
-        scores = softMax(scores.asType(.float32), axis: -1).asType(scores.dtype)
-        let output = matmul(scores, values).transposed(0, 2, 1, 3).reshaped(B, L, -1)
+        let output = MLXFast.scaledDotProductAttention(
+            queries: queries, keys: keys, values: values, scale: scale, mask: mask
+        )
+        .transposed(0, 2, 1, 3)
+        .reshaped(B, L, -1)

         return (wo(output), (keys, values))
     }

View File

@@ -2,6 +2,7 @@
 import Foundation
 import MLX
+import MLXFast
 import MLXNN

 // port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/llama.py
@@ -9,7 +10,6 @@ import MLXNN
 private class Attention: Module {

     let args: LlamaConfiguration
-    let repeats: Int
     let scale: Float

     @ModuleInfo(key: "q_proj") var wq: Linear
@@ -26,8 +26,6 @@ private class Attention: Module {
         let heads = args.attentionHeads
         let kvHeads = args.kvHeads

-        self.repeats = heads / kvHeads
-
         let headDim = args.hiddenSize / heads
         self.scale = pow(Float(headDim), -0.5)
@@ -69,11 +67,6 @@ private class Attention: Module {
         keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
         values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)

-        if repeats > 1 {
-            keys = MLXArray.repeat(keys, count: repeats, axis: 1)
-            values = MLXArray.repeat(values, count: repeats, axis: 1)
-        }
-
         if let (keyCache, valueCache) = cache {
             queries = rope(queries, offset: keyCache.dim(2))
             keys = rope(keys, offset: keyCache.dim(2))
@@ -84,14 +77,11 @@ private class Attention: Module {
             keys = rope(keys)
         }

-        var scores = (queries * self.scale).matmul(keys.transposed(0, 1, 3, 2))
-        if let mask {
-            scores = scores + mask
-        }
-
-        scores = softMax(scores.asType(.float32), axis: -1).asType(scores.dtype)
-        let output = matmul(scores, values).transposed(0, 2, 1, 3).reshaped(B, L, -1)
+        let output = MLXFast.scaledDotProductAttention(
+            queries: queries, keys: keys, values: values, scale: scale, mask: mask
+        )
+        .transposed(0, 2, 1, 3)
+        .reshaped(B, L, -1)

         return (wo(output), (keys, values))
     }

View File

@@ -2,6 +2,7 @@
 import Foundation
 import MLX
+import MLXFast
 import MLXNN

 // https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/phi.py
@@ -17,7 +18,6 @@ private class PhiAttention: Module {
     let args: PhiConfiguration
     let heads: Int
     let headDim: Int
-    let repeats: Int

     @ModuleInfo(key: "q_proj") var wq: Linear
     @ModuleInfo(key: "k_proj") var wk: Linear
@@ -33,7 +33,6 @@
         self.heads = args.attentionHeads
         self.headDim = args.hiddenSize / heads
         let kvHeads = args.kvHeads
-        self.repeats = heads / kvHeads

         if headDim * heads != hiddenSize {
             fatalError("hidden_size must be divisible by num_heads")
@@ -63,11 +62,6 @@ private class PhiAttention: Module {
         keys = keys.reshaped(B, L, args.kvHeads, headDim).transposed(0, 2, 1, 3)
         values = values.reshaped(B, L, args.kvHeads, headDim).transposed(0, 2, 1, 3)

-        if repeats > 1 {
-            keys = MLXArray.repeat(keys, count: repeats, axis: 1)
-            values = MLXArray.repeat(values, count: repeats, axis: 1)
-        }
-
         // Add RoPE to the queries and keys and combine them with the cache
         if let (keyCache, valueCache) = cache {
             queries = rope(queries, offset: keyCache.dim(2))
@@ -84,15 +78,13 @@ private class PhiAttention: Module {
         // Finally perform the attention computation
         let scale = sqrt(1 / Float(queries.dim(-1)))

-        var scores = (queries * scale).matmul(keys.transposed(0, 1, 3, 2))
-        if let mask {
-            scores = scores + mask
-        }
-
-        scores = softMax(scores, axis: -1).asType(values.dtype)
-        let valuesHat = matmul(scores, values).transposed(0, 2, 1, 3).reshaped(B, L, -1)
-        return (dense(valuesHat), (keys, values))
+        let output = MLXFast.scaledDotProductAttention(
+            queries: queries, keys: keys, values: values, scale: scale, mask: mask
+        )
+        .transposed(0, 2, 1, 3)
+        .reshaped(B, L, -1)
+        return (dense(output), (keys, values))
     }
 }

View File

@@ -7,13 +7,13 @@
 import Foundation
 import MLX
+import MLXFast
 import MLXNN

 // port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/qwen2.py

 private class Attention: Module {

     let args: Qwen2Configuration
-    let repeats: Int
     let scale: Float

     @ModuleInfo(key: "q_proj") var wq: Linear
@@ -30,8 +30,6 @@ private class Attention: Module {
         let heads = args.attentionHeads
         let kvHeads = args.kvHeads

-        self.repeats = heads / kvHeads
-
         let headDim = args.hiddenSize / heads
         self.scale = pow(Float(headDim), -0.5)
@@ -73,11 +71,6 @@ private class Attention: Module {
         keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
         values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)

-        if repeats > 1 {
-            keys = MLXArray.repeat(keys, count: repeats, axis: 1)
-            values = MLXArray.repeat(values, count: repeats, axis: 1)
-        }
-
         if let (keyCache, valueCache) = cache {
             queries = rope(queries, offset: keyCache.dim(2))
             keys = rope(keys, offset: keyCache.dim(2))
@@ -88,14 +81,11 @@ private class Attention: Module {
             keys = rope(keys)
         }

-        var scores = (queries * scale).matmul(keys.transposed(0, 1, 3, 2))
-        if let mask {
-            scores = scores + mask
-        }
-
-        scores = softMax(scores.asType(.float32), axis: -1).asType(scores.dtype)
-        let output = matmul(scores, values).transposed(0, 2, 1, 3).reshaped(B, L, -1)
+        let output = MLXFast.scaledDotProductAttention(
+            queries: queries, keys: keys, values: values, scale: scale, mask: mask
+        )
+        .transposed(0, 2, 1, 3)
+        .reshaped(B, L, -1)

         return (wo(output), (keys, values))
     }

View File

@@ -7,13 +7,13 @@
 import Foundation
 import MLX
+import MLXFast
 import MLXNN

 // port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/starcoder2.py

 private class Attention: Module {

     let args: Starcoder2Configuration
-    let repeats: Int
     let scale: Float

     @ModuleInfo(key: "q_proj") var wq: Linear
@@ -30,8 +30,6 @@ private class Attention: Module {
         let heads = args.attentionHeads
         let kvHeads = args.kvHeads

-        self.repeats = heads / kvHeads
-
         let headDim = args.hiddenSize / heads
         self.scale = pow(Float(headDim), -0.5)
@@ -57,11 +55,6 @@ private class Attention: Module {
         keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
         values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)

-        if repeats > 1 {
-            keys = MLXArray.repeat(keys, count: repeats, axis: 1)
-            values = MLXArray.repeat(values, count: repeats, axis: 1)
-        }
-
         if let (keyCache, valueCache) = cache {
             queries = rope(queries, offset: keyCache.dim(2))
             keys = rope(keys, offset: keyCache.dim(2))
@@ -72,14 +65,11 @@ private class Attention: Module {
             keys = rope(keys)
         }

-        var scores = (queries * scale).matmul(keys.transposed(0, 1, 3, 2))
-        if let mask {
-            scores = scores + mask
-        }
-
-        scores = softMax(scores.asType(.float32), axis: -1).asType(scores.dtype)
-        let output = matmul(scores, values).transposed(0, 2, 1, 3).reshaped(B, L, -1)
+        let output = MLXFast.scaledDotProductAttention(
+            queries: queries, keys: keys, values: values, scale: scale, mask: mask
+        )
+        .transposed(0, 2, 1, 3)
+        .reshaped(B, L, -1)

         return (wo(output), (keys, values))
     }

View File

@@ -15,7 +15,7 @@
       "location" : "https://github.com/ml-explore/mlx-swift",
       "state" : {
         "branch" : "main",
-        "revision" : "eb249b04b1188b72c122223e5c640a41745a61b9"
+        "revision" : "948000ceaa27c343f4dd5ce40f727f221bf45c6e"
       }
     },
     {