adopt MLXFast.scaledDotProductAttention (#23)
@@ -43,12 +43,17 @@ workflows:
     when:
       and:
         - matches:
-            pattern: "^(?!pull/)[-\\w]+$"
+            pattern: "^(?!pull/)[-_./\\w]+$"
            value: << pipeline.git.branch >>
         - not: << pipeline.parameters.nightly_build >>
         - not: << pipeline.parameters.weekly_build >>
     jobs:
-      - mac_build_and_test
+      - hold:
+          type: approval
+      - apple/authenticate:
+          context: pr-approval
+      - mac_build_and_test:
+          requires: [ hold ]
 
   prb:
     when:
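The branch filter above widens the character class from [-\w] to [-_./\w], so dot- and slash-separated branch names (for example release/0.2.1) now satisfy the filter while the (?!pull/) lookahead continues to exclude pull/<n> refs, and mac_build_and_test is now gated behind a manual approval job (hold) alongside an apple/authenticate step using the pr-approval context. A quick, hypothetical way to exercise the pattern with Foundation's ICU regex support (not part of this commit):

import Foundation

// Hypothetical sanity check of the updated branch filter (not from the commit).
// [-\w] already matched letters, digits, "_" and "-"; [-_./\w] additionally
// admits "." and "/", and (?!pull/) still rejects GitHub-style pull/<n> refs.
let pattern = "^(?!pull/)[-_./\\w]+$"

for branch in ["main", "release/0.2.1", "feature/sdpa-adoption", "pull/23"] {
    let accepted = branch.range(of: pattern, options: .regularExpression) != nil
    print(branch, accepted)  // pull/23 -> false, the others -> true
}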
@@ -2,6 +2,7 @@
 
 import Foundation
 import MLX
+import MLXFast
 import MLXNN
 
 // port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/gemma.py
@@ -31,7 +32,6 @@ private class RMSNorm: Module, UnaryLayer {
 private class Attention: Module {
 
     let args: GemmaConfiguration
-    let repeats: Int
     let scale: Float
 
     @ModuleInfo(key: "q_proj") var wq: Linear
@@ -48,8 +48,6 @@ private class Attention: Module {
         let heads = args.attentionHeads
         let kvHeads = args.kvHeads
 
-        self.repeats = heads / kvHeads
-
         let headDim = args.headDimensions
         self.scale = pow(Float(headDim), -0.5)
 
@@ -76,11 +74,6 @@ private class Attention: Module {
         keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
         values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
 
-        if repeats > 1 {
-            keys = MLXArray.repeat(keys, count: repeats, axis: 1)
-            values = MLXArray.repeat(values, count: repeats, axis: 1)
-        }
-
         if let (keyCache, valueCache) = cache {
             queries = rope(queries, offset: keyCache.dim(2))
             keys = rope(keys, offset: keyCache.dim(2))
@@ -91,14 +84,11 @@ private class Attention: Module {
             keys = rope(keys)
         }
 
-        var scores = (queries * self.scale).matmul(keys.transposed(0, 1, 3, 2))
-        if let mask {
-            scores = scores + mask
-        }
-
-        scores = softMax(scores.asType(.float32), axis: -1).asType(scores.dtype)
-
-        let output = matmul(scores, values).transposed(0, 2, 1, 3).reshaped(B, L, -1)
+        let output = MLXFast.scaledDotProductAttention(
+            queries: queries, keys: keys, values: values, scale: scale, mask: mask
+        )
+        .transposed(0, 2, 1, 3)
+        .reshaped(B, L, -1)
 
         return (wo(output), (keys, values))
     }
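In each model the hand-written attention (scale the queries, add the optional mask, softmax, matmul with the values) collapses into a single MLXFast.scaledDotProductAttention call, and because that fused kernel accepts keys and values with fewer heads than the queries, the repeats property and the MLXArray.repeat expansion of the KV heads are removed as well. A minimal sketch of the equivalence, assuming MLXRandom.normal, allClose, and an optional mask parameter behave as used here (the manual path is copied from the removed code):

import Foundation
import MLX
import MLXFast
import MLXNN
import MLXRandom

// Grouped-KV shapes: 8 query heads attending over 2 KV heads.
let (B, L, heads, kvHeads, D) = (1, 8, 8, 2, 64)
let scale = pow(Float(D), -0.5)

let queries = MLXRandom.normal([B, heads, L, D])
var keys = MLXRandom.normal([B, kvHeads, L, D])
var values = MLXRandom.normal([B, kvHeads, L, D])

// New path: hand the un-expanded KV heads straight to the fused kernel.
let fused = MLXFast.scaledDotProductAttention(
    queries: queries, keys: keys, values: values, scale: scale, mask: nil)

// Old path, as in the removed code: expand KV heads, then scale/softmax/matmul.
let repeats = heads / kvHeads
keys = MLXArray.repeat(keys, count: repeats, axis: 1)
values = MLXArray.repeat(values, count: repeats, axis: 1)
var scores = (queries * scale).matmul(keys.transposed(0, 1, 3, 2))
scores = softMax(scores.asType(.float32), axis: -1).asType(scores.dtype)
let manual = matmul(scores, values)

// Should print true (up to numerical tolerance).
print(allClose(fused, manual, atol: 1e-4).item(Bool.self))

The same replacement is applied to the Llama, Phi, Qwen2, and Starcoder2 ports below; only Phi differs in computing its scale inline from queries.dim(-1) rather than storing it.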
@@ -2,6 +2,7 @@
 
 import Foundation
 import MLX
+import MLXFast
 import MLXNN
 
 // port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/llama.py
@@ -9,7 +10,6 @@ import MLXNN
 private class Attention: Module {
 
     let args: LlamaConfiguration
-    let repeats: Int
     let scale: Float
 
     @ModuleInfo(key: "q_proj") var wq: Linear
@@ -26,8 +26,6 @@ private class Attention: Module {
         let heads = args.attentionHeads
         let kvHeads = args.kvHeads
 
-        self.repeats = heads / kvHeads
-
         let headDim = args.hiddenSize / heads
         self.scale = pow(Float(headDim), -0.5)
 
@@ -69,11 +67,6 @@ private class Attention: Module {
         keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
         values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
 
-        if repeats > 1 {
-            keys = MLXArray.repeat(keys, count: repeats, axis: 1)
-            values = MLXArray.repeat(values, count: repeats, axis: 1)
-        }
-
         if let (keyCache, valueCache) = cache {
             queries = rope(queries, offset: keyCache.dim(2))
             keys = rope(keys, offset: keyCache.dim(2))
@@ -84,14 +77,11 @@ private class Attention: Module {
             keys = rope(keys)
         }
 
-        var scores = (queries * self.scale).matmul(keys.transposed(0, 1, 3, 2))
-        if let mask {
-            scores = scores + mask
-        }
-
-        scores = softMax(scores.asType(.float32), axis: -1).asType(scores.dtype)
-
-        let output = matmul(scores, values).transposed(0, 2, 1, 3).reshaped(B, L, -1)
+        let output = MLXFast.scaledDotProductAttention(
+            queries: queries, keys: keys, values: values, scale: scale, mask: mask
+        )
+        .transposed(0, 2, 1, 3)
+        .reshaped(B, L, -1)
 
         return (wo(output), (keys, values))
     }
@@ -2,6 +2,7 @@
 
 import Foundation
 import MLX
+import MLXFast
 import MLXNN
 
 // https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/phi.py
@@ -17,7 +18,6 @@ private class PhiAttention: Module {
     let args: PhiConfiguration
     let heads: Int
     let headDim: Int
-    let repeats: Int
 
     @ModuleInfo(key: "q_proj") var wq: Linear
     @ModuleInfo(key: "k_proj") var wk: Linear
@@ -33,7 +33,6 @@ private class PhiAttention: Module {
         self.heads = args.attentionHeads
         self.headDim = args.hiddenSize / heads
         let kvHeads = args.kvHeads
-        self.repeats = heads / kvHeads
 
         if headDim * heads != hiddenSize {
             fatalError("hidden_size must be divisible by num_heads")
@@ -63,11 +62,6 @@ private class PhiAttention: Module {
         keys = keys.reshaped(B, L, args.kvHeads, headDim).transposed(0, 2, 1, 3)
         values = values.reshaped(B, L, args.kvHeads, headDim).transposed(0, 2, 1, 3)
 
-        if repeats > 1 {
-            keys = MLXArray.repeat(keys, count: repeats, axis: 1)
-            values = MLXArray.repeat(values, count: repeats, axis: 1)
-        }
-
         // Add RoPE to the queries and keys and combine them with the cache
         if let (keyCache, valueCache) = cache {
             queries = rope(queries, offset: keyCache.dim(2))
@@ -84,15 +78,13 @@ private class PhiAttention: Module {
 
         // Finally perform the attention computation
         let scale = sqrt(1 / Float(queries.dim(-1)))
-        var scores = (queries * scale).matmul(keys.transposed(0, 1, 3, 2))
-        if let mask {
-            scores = scores + mask
-        }
-
-        scores = softMax(scores, axis: -1).asType(values.dtype)
+        let output = MLXFast.scaledDotProductAttention(
+            queries: queries, keys: keys, values: values, scale: scale, mask: mask
+        )
+        .transposed(0, 2, 1, 3)
+        .reshaped(B, L, -1)
 
-        let valuesHat = matmul(scores, values).transposed(0, 2, 1, 3).reshaped(B, L, -1)
+        return (dense(output), (keys, values))
 
-        return (dense(valuesHat), (keys, values))
     }
 }
@@ -7,13 +7,13 @@
 
 import Foundation
 import MLX
+import MLXFast
 import MLXNN
 
 // port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/qwen2.py
 
 private class Attention: Module {
     let args: Qwen2Configuration
-    let repeats: Int
     let scale: Float
 
     @ModuleInfo(key: "q_proj") var wq: Linear
@@ -30,8 +30,6 @@ private class Attention: Module {
         let heads = args.attentionHeads
         let kvHeads = args.kvHeads
 
-        self.repeats = heads / kvHeads
-
         let headDim = args.hiddenSize / heads
         self.scale = pow(Float(headDim), -0.5)
 
@@ -73,11 +71,6 @@ private class Attention: Module {
         keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
         values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
 
-        if repeats > 1 {
-            keys = MLXArray.repeat(keys, count: repeats, axis: 1)
-            values = MLXArray.repeat(values, count: repeats, axis: 1)
-        }
-
         if let (keyCache, valueCache) = cache {
             queries = rope(queries, offset: keyCache.dim(2))
             keys = rope(keys, offset: keyCache.dim(2))
@@ -88,14 +81,11 @@ private class Attention: Module {
             keys = rope(keys)
         }
 
-        var scores = (queries * scale).matmul(keys.transposed(0, 1, 3, 2))
-        if let mask {
-            scores = scores + mask
-        }
-
-        scores = softMax(scores.asType(.float32), axis: -1).asType(scores.dtype)
-
-        let output = matmul(scores, values).transposed(0, 2, 1, 3).reshaped(B, L, -1)
+        let output = MLXFast.scaledDotProductAttention(
+            queries: queries, keys: keys, values: values, scale: scale, mask: mask
+        )
+        .transposed(0, 2, 1, 3)
+        .reshaped(B, L, -1)
 
         return (wo(output), (keys, values))
     }
@@ -7,13 +7,13 @@
 
 import Foundation
 import MLX
+import MLXFast
 import MLXNN
 
 // port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/starcoder2.py
 
 private class Attention: Module {
     let args: Starcoder2Configuration
-    let repeats: Int
     let scale: Float
 
     @ModuleInfo(key: "q_proj") var wq: Linear
@@ -30,8 +30,6 @@ private class Attention: Module {
         let heads = args.attentionHeads
         let kvHeads = args.kvHeads
 
-        self.repeats = heads / kvHeads
-
         let headDim = args.hiddenSize / heads
         self.scale = pow(Float(headDim), -0.5)
 
@@ -57,11 +55,6 @@ private class Attention: Module {
         keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
         values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
 
-        if repeats > 1 {
-            keys = MLXArray.repeat(keys, count: repeats, axis: 1)
-            values = MLXArray.repeat(values, count: repeats, axis: 1)
-        }
-
         if let (keyCache, valueCache) = cache {
             queries = rope(queries, offset: keyCache.dim(2))
             keys = rope(keys, offset: keyCache.dim(2))
@@ -72,14 +65,11 @@ private class Attention: Module {
             keys = rope(keys)
         }
 
-        var scores = (queries * scale).matmul(keys.transposed(0, 1, 3, 2))
-        if let mask {
-            scores = scores + mask
-        }
-
-        scores = softMax(scores.asType(.float32), axis: -1).asType(scores.dtype)
-
-        let output = matmul(scores, values).transposed(0, 2, 1, 3).reshaped(B, L, -1)
+        let output = MLXFast.scaledDotProductAttention(
+            queries: queries, keys: keys, values: values, scale: scale, mask: mask
+        )
+        .transposed(0, 2, 1, 3)
+        .reshaped(B, L, -1)
 
         return (wo(output), (keys, values))
     }
@@ -15,7 +15,7 @@
       "location" : "https://github.com/ml-explore/mlx-swift",
       "state" : {
         "branch" : "main",
-        "revision" : "eb249b04b1188b72c122223e5c640a41745a61b9"
+        "revision" : "948000ceaa27c343f4dd5ce40f727f221bf45c6e"
       }
     },
     {