From 0fb74cbfdc767b252411b4d6fdbdc70c54f28f02 Mon Sep 17 00:00:00 2001
From: David Koski <46639364+davidkoski@users.noreply.github.com>
Date: Tue, 12 Mar 2024 14:04:43 -0700
Subject: [PATCH] adopt MLXFast.scaledDotProductAttention (#23)

---
 .circleci/config.yml                       |  9 ++++++--
 Libraries/LLM/Gemma.swift                  | 22 +++++--------------
 Libraries/LLM/Llama.swift                  | 22 +++++--------------
 Libraries/LLM/Phi.swift                    | 22 ++++++-------------
 Libraries/LLM/Qwen2.swift                  | 22 +++++--------------
 Libraries/LLM/Starcoder2.swift             | 22 +++++--------------
 .../xcshareddata/swiftpm/Package.resolved  |  2 +-
 7 files changed, 39 insertions(+), 82 deletions(-)

diff --git a/.circleci/config.yml b/.circleci/config.yml
index 328335c..7ab611f 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -43,12 +43,17 @@ workflows:
     when:
       and:
         - matches:
-            pattern: "^(?!pull/)[-\\w]+$"
+            pattern: "^(?!pull/)[-_./\\w]+$"
             value: << pipeline.git.branch >>
         - not: << pipeline.parameters.nightly_build >>
        - not: << pipeline.parameters.weekly_build >>
     jobs:
-      - mac_build_and_test
+      - hold:
+          type: approval
+      - apple/authenticate:
+          context: pr-approval
+      - mac_build_and_test:
+          requires: [ hold ]
 
   prb:
     when:
diff --git a/Libraries/LLM/Gemma.swift b/Libraries/LLM/Gemma.swift
index 0e06851..fef085c 100644
--- a/Libraries/LLM/Gemma.swift
+++ b/Libraries/LLM/Gemma.swift
@@ -2,6 +2,7 @@
 
 import Foundation
 import MLX
+import MLXFast
 import MLXNN
 
 // port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/gemma.py
@@ -31,7 +32,6 @@ private class RMSNorm: Module, UnaryLayer {
 private class Attention: Module {
 
     let args: GemmaConfiguration
-    let repeats: Int
     let scale: Float
 
     @ModuleInfo(key: "q_proj") var wq: Linear
@@ -48,8 +48,6 @@ private class Attention: Module {
         let heads = args.attentionHeads
         let kvHeads = args.kvHeads
 
-        self.repeats = heads / kvHeads
-
         let headDim = args.headDimensions
         self.scale = pow(Float(headDim), -0.5)
 
@@ -76,11 +74,6 @@ private class Attention: Module {
         keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
         values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
 
-        if repeats > 1 {
-            keys = MLXArray.repeat(keys, count: repeats, axis: 1)
-            values = MLXArray.repeat(values, count: repeats, axis: 1)
-        }
-
         if let (keyCache, valueCache) = cache {
             queries = rope(queries, offset: keyCache.dim(2))
             keys = rope(keys, offset: keyCache.dim(2))
@@ -91,14 +84,11 @@ private class Attention: Module {
             keys = rope(keys)
         }
 
-        var scores = (queries * self.scale).matmul(keys.transposed(0, 1, 3, 2))
-        if let mask {
-            scores = scores + mask
-        }
-
-        scores = softMax(scores.asType(.float32), axis: -1).asType(scores.dtype)
-
-        let output = matmul(scores, values).transposed(0, 2, 1, 3).reshaped(B, L, -1)
+        let output = MLXFast.scaledDotProductAttention(
+            queries: queries, keys: keys, values: values, scale: scale, mask: mask
+        )
+        .transposed(0, 2, 1, 3)
+        .reshaped(B, L, -1)
 
         return (wo(output), (keys, values))
     }
diff --git a/Libraries/LLM/Llama.swift b/Libraries/LLM/Llama.swift
index 90d2296..273b5f7 100644
--- a/Libraries/LLM/Llama.swift
+++ b/Libraries/LLM/Llama.swift
@@ -2,6 +2,7 @@
 
 import Foundation
 import MLX
+import MLXFast
 import MLXNN
 
 // port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/llama.py
@@ -9,7 +10,6 @@ import MLXNN
 private class Attention: Module {
 
     let args: LlamaConfiguration
-    let repeats: Int
     let scale: Float
 
     @ModuleInfo(key: "q_proj") var wq: Linear
@@ -26,8 +26,6 @@ private class Attention: Module {
         let heads = args.attentionHeads
         let kvHeads = args.kvHeads
 
-        self.repeats = heads / kvHeads
-
         let headDim = args.hiddenSize / heads
         self.scale = pow(Float(headDim), -0.5)
 
@@ -69,11 +67,6 @@ private class Attention: Module {
         keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
         values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
 
-        if repeats > 1 {
-            keys = MLXArray.repeat(keys, count: repeats, axis: 1)
-            values = MLXArray.repeat(values, count: repeats, axis: 1)
-        }
-
         if let (keyCache, valueCache) = cache {
             queries = rope(queries, offset: keyCache.dim(2))
             keys = rope(keys, offset: keyCache.dim(2))
@@ -84,14 +77,11 @@ private class Attention: Module {
             keys = rope(keys)
         }
 
-        var scores = (queries * self.scale).matmul(keys.transposed(0, 1, 3, 2))
-        if let mask {
-            scores = scores + mask
-        }
-
-        scores = softMax(scores.asType(.float32), axis: -1).asType(scores.dtype)
-
-        let output = matmul(scores, values).transposed(0, 2, 1, 3).reshaped(B, L, -1)
+        let output = MLXFast.scaledDotProductAttention(
+            queries: queries, keys: keys, values: values, scale: scale, mask: mask
+        )
+        .transposed(0, 2, 1, 3)
+        .reshaped(B, L, -1)
 
         return (wo(output), (keys, values))
     }
diff --git a/Libraries/LLM/Phi.swift b/Libraries/LLM/Phi.swift
index e4a55eb..9a68e29 100644
--- a/Libraries/LLM/Phi.swift
+++ b/Libraries/LLM/Phi.swift
@@ -2,6 +2,7 @@
 
 import Foundation
 import MLX
+import MLXFast
 import MLXNN
 
 // https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/phi.py
@@ -17,7 +18,6 @@ private class PhiAttention: Module {
     let args: PhiConfiguration
     let heads: Int
     let headDim: Int
-    let repeats: Int
 
     @ModuleInfo(key: "q_proj") var wq: Linear
     @ModuleInfo(key: "k_proj") var wk: Linear
@@ -33,7 +33,6 @@ private class PhiAttention: Module {
         self.heads = args.attentionHeads
         self.headDim = args.hiddenSize / heads
         let kvHeads = args.kvHeads
-        self.repeats = heads / kvHeads
 
         if headDim * heads != hiddenSize {
             fatalError("hidden_size must be divisible by num_heads")
@@ -63,11 +62,6 @@ private class PhiAttention: Module {
         keys = keys.reshaped(B, L, args.kvHeads, headDim).transposed(0, 2, 1, 3)
         values = values.reshaped(B, L, args.kvHeads, headDim).transposed(0, 2, 1, 3)
 
-        if repeats > 1 {
-            keys = MLXArray.repeat(keys, count: repeats, axis: 1)
-            values = MLXArray.repeat(values, count: repeats, axis: 1)
-        }
-
         // Add RoPE to the queries and keys and combine them with the cache
         if let (keyCache, valueCache) = cache {
             queries = rope(queries, offset: keyCache.dim(2))
@@ -84,15 +78,13 @@ private class PhiAttention: Module {
 
         // Finally perform the attention computation
         let scale = sqrt(1 / Float(queries.dim(-1)))
-        var scores = (queries * scale).matmul(keys.transposed(0, 1, 3, 2))
-        if let mask {
-            scores = scores + mask
-        }
+        let output = MLXFast.scaledDotProductAttention(
+            queries: queries, keys: keys, values: values, scale: scale, mask: mask
+        )
+        .transposed(0, 2, 1, 3)
+        .reshaped(B, L, -1)
 
-        scores = softMax(scores, axis: -1).asType(values.dtype)
-        let valuesHat = matmul(scores, values).transposed(0, 2, 1, 3).reshaped(B, L, -1)
-
-        return (dense(valuesHat), (keys, values))
+        return (dense(output), (keys, values))
     }
 }
 
diff --git a/Libraries/LLM/Qwen2.swift b/Libraries/LLM/Qwen2.swift
index 5d627b0..3f6753a 100644
--- a/Libraries/LLM/Qwen2.swift
+++ b/Libraries/LLM/Qwen2.swift
@@ -7,13 +7,13 @@
 
 import Foundation
 import MLX
+import MLXFast
 import MLXNN
 
 // port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/qwen2.py
 
 private class Attention: Module {
     let args: Qwen2Configuration
-    let repeats: Int
     let scale: Float
 
     @ModuleInfo(key: "q_proj") var wq: Linear
@@ -30,8 +30,6 @@ private class Attention: Module {
         let heads = args.attentionHeads
         let kvHeads = args.kvHeads
 
-        self.repeats = heads / kvHeads
-
         let headDim = args.hiddenSize / heads
         self.scale = pow(Float(headDim), -0.5)
 
@@ -73,11 +71,6 @@ private class Attention: Module {
         keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
         values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
 
-        if repeats > 1 {
-            keys = MLXArray.repeat(keys, count: repeats, axis: 1)
-            values = MLXArray.repeat(values, count: repeats, axis: 1)
-        }
-
         if let (keyCache, valueCache) = cache {
             queries = rope(queries, offset: keyCache.dim(2))
             keys = rope(keys, offset: keyCache.dim(2))
@@ -88,14 +81,11 @@ private class Attention: Module {
             keys = rope(keys)
         }
 
-        var scores = (queries * scale).matmul(keys.transposed(0, 1, 3, 2))
-        if let mask {
-            scores = scores + mask
-        }
-
-        scores = softMax(scores.asType(.float32), axis: -1).asType(scores.dtype)
-
-        let output = matmul(scores, values).transposed(0, 2, 1, 3).reshaped(B, L, -1)
+        let output = MLXFast.scaledDotProductAttention(
+            queries: queries, keys: keys, values: values, scale: scale, mask: mask
+        )
+        .transposed(0, 2, 1, 3)
+        .reshaped(B, L, -1)
 
         return (wo(output), (keys, values))
     }
diff --git a/Libraries/LLM/Starcoder2.swift b/Libraries/LLM/Starcoder2.swift
index c577988..b9bdae0 100644
--- a/Libraries/LLM/Starcoder2.swift
+++ b/Libraries/LLM/Starcoder2.swift
@@ -7,13 +7,13 @@
 
 import Foundation
 import MLX
+import MLXFast
 import MLXNN
 
 // port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/starcoder2.py
 
 private class Attention: Module {
     let args: Starcoder2Configuration
-    let repeats: Int
     let scale: Float
 
     @ModuleInfo(key: "q_proj") var wq: Linear
@@ -30,8 +30,6 @@ private class Attention: Module {
         let heads = args.attentionHeads
         let kvHeads = args.kvHeads
 
-        self.repeats = heads / kvHeads
-
         let headDim = args.hiddenSize / heads
         self.scale = pow(Float(headDim), -0.5)
 
@@ -57,11 +55,6 @@ private class Attention: Module {
         keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
         values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
 
-        if repeats > 1 {
-            keys = MLXArray.repeat(keys, count: repeats, axis: 1)
-            values = MLXArray.repeat(values, count: repeats, axis: 1)
-        }
-
         if let (keyCache, valueCache) = cache {
             queries = rope(queries, offset: keyCache.dim(2))
             keys = rope(keys, offset: keyCache.dim(2))
@@ -72,14 +65,11 @@ private class Attention: Module {
             keys = rope(keys)
         }
 
-        var scores = (queries * scale).matmul(keys.transposed(0, 1, 3, 2))
-        if let mask {
-            scores = scores + mask
-        }
-
-        scores = softMax(scores.asType(.float32), axis: -1).asType(scores.dtype)
-
-        let output = matmul(scores, values).transposed(0, 2, 1, 3).reshaped(B, L, -1)
+        let output = MLXFast.scaledDotProductAttention(
+            queries: queries, keys: keys, values: values, scale: scale, mask: mask
+        )
+        .transposed(0, 2, 1, 3)
+        .reshaped(B, L, -1)
 
         return (wo(output), (keys, values))
     }
diff --git a/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
index 09ab51b..4b4f79e 100644
--- a/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
+++ b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
@@ -15,7 +15,7 @@
       "location" : "https://github.com/ml-explore/mlx-swift",
       "state" : {
         "branch" : "main",
-        "revision" : "eb249b04b1188b72c122223e5c640a41745a61b9"
+        "revision" : "948000ceaa27c343f4dd5ce40f727f221bf45c6e"
       }
     },
     {