From ac273a14eadb7a57083f781bc493dcf5b60d3280 Mon Sep 17 00:00:00 2001
From: David Koski <46639364+davidkoski@users.noreply.github.com>
Date: Thu, 14 Mar 2024 13:18:40 -0700
Subject: [PATCH] fix float types in Phi (use float16) (#25)

- per suggestions in #23 ensure that the values that go into the cache are
  float16
---
 Libraries/LLM/Phi.swift | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/Libraries/LLM/Phi.swift b/Libraries/LLM/Phi.swift
index 9a68e29..4c9763b 100644
--- a/Libraries/LLM/Phi.swift
+++ b/Libraries/LLM/Phi.swift
@@ -73,14 +73,12 @@ private class PhiAttention: Module {
             keys = rope(keys)
         }
 
-        queries = queries.asType(Float.self)
-        keys = keys.asType(Float.self)
-
         // Finally perform the attention computation
         let scale = sqrt(1 / Float(queries.dim(-1)))
         let output = MLXFast.scaledDotProductAttention(
-            queries: queries, keys: keys, values: values, scale: scale, mask: mask
+            queries: queries.asType(.float32), keys: keys, values: values, scale: scale, mask: mask
         )
+        .asType(values.dtype)
         .transposed(0, 2, 1, 3)
         .reshaped(B, L, -1)
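
For context, the dtype handling after this patch reduces to the shape sketched
below. This is a minimal sketch, not the repository code: the standalone
attendHalfPrecision wrapper and its parameter list are hypothetical, while the
MLXFast.scaledDotProductAttention, asType, dim, and dtype calls are exactly the
ones visible in the diff.

import Foundation
import MLX
import MLXFast

// Hypothetical standalone wrapper illustrating the patch; the real code is
// inline in PhiAttention in Libraries/LLM/Phi.swift.
func attendHalfPrecision(
    queries: MLXArray, keys: MLXArray, values: MLXArray, mask: MLXArray? = nil
) -> MLXArray {
    // Same scale as the patched code: 1/sqrt(headDim).
    let scale = sqrt(1 / Float(queries.dim(-1)))

    // Upcast only the queries, and only for this single op; keys and values
    // stay float16, so what goes into the KV cache stays float16.
    return MLXFast.scaledDotProductAttention(
        queries: queries.asType(.float32), keys: keys, values: values,
        scale: scale, mask: mask
    )
    // Cast the result back to values.dtype (float16) so everything
    // downstream keeps running in half precision.
    .asType(values.dtype)
}

The net effect, per the commit message and the discussion in #23: the float32
upcast is confined to the attention computation itself, and the keys and values
that flow into the cache are never converted away from float16.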