From c7919cf7fe3efd6dd754c19a1d976de3d9a3db14 Mon Sep 17 00:00:00 2001 From: David Koski Date: Mon, 26 Feb 2024 14:09:48 -0800 Subject: [PATCH] fix rmsnorm for gemma --- Libraries/LLM/Gemma.swift | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/Libraries/LLM/Gemma.swift b/Libraries/LLM/Gemma.swift index ae3a7fa..0e06851 100644 --- a/Libraries/LLM/Gemma.swift +++ b/Libraries/LLM/Gemma.swift @@ -19,10 +19,7 @@ private class RMSNorm: Module, UnaryLayer { } func norm(_ x: MLXArray) -> MLXArray { - let S = 1.0 / sqrt(Float(x.dim(-1))) - - let n = (x * S).square().sum(axis: -1, keepDims: true) - return rsqrt(n + eps) + x * rsqrt(x.square().mean(axis: -1, keepDims: true) + eps) } public func callAsFunction(_ x: MLXArray) -> MLXArray {