fix rmsnorm for gemma

This commit is contained in:
David Koski
2024-02-26 14:09:48 -08:00
parent a2ff291608
commit c7919cf7fe

View File

@@ -19,10 +19,7 @@ private class RMSNorm: Module, UnaryLayer {
} }
func norm(_ x: MLXArray) -> MLXArray { func norm(_ x: MLXArray) -> MLXArray {
let S = 1.0 / sqrt(Float(x.dim(-1))) x * rsqrt(x.square().mean(axis: -1, keepDims: true) + eps)
let n = (x * S).square().sum(axis: -1, keepDims: true)
return rsqrt(n + eps)
} }
public func callAsFunction(_ x: MLXArray) -> MLXArray { public func callAsFunction(_ x: MLXArray) -> MLXArray {