diff --git a/Applications/LLMEval/ContentView.swift b/Applications/LLMEval/ContentView.swift index 78ff069..6fd9ec5 100644 --- a/Applications/LLMEval/ContentView.swift +++ b/Applications/LLMEval/ContentView.swift @@ -85,7 +85,7 @@ struct ContentView: View { .toolbar { ToolbarItem { Label( - "GPU Usage: \(deviceStat.gpuUsage.activeMemory.formatted(.byteCount(style: .memory)))", + "Memory Usage: \(deviceStat.gpuUsage.activeMemory.formatted(.byteCount(style: .memory)))", systemImage: "info.circle.fill" ) .labelStyle(.titleAndIcon) diff --git a/Libraries/LLM/Gemma.swift b/Libraries/LLM/Gemma.swift index fef085c..5be0144 100644 --- a/Libraries/LLM/Gemma.swift +++ b/Libraries/LLM/Gemma.swift @@ -9,7 +9,6 @@ import MLXNN // specialized norm for gemma private class RMSNorm: Module, UnaryLayer { - let weight: MLXArray let eps: Float @@ -19,13 +18,8 @@ private class RMSNorm: Module, UnaryLayer { super.init() } - func norm(_ x: MLXArray) -> MLXArray { - x * rsqrt(x.square().mean(axis: -1, keepDims: true) + eps) - } - public func callAsFunction(_ x: MLXArray) -> MLXArray { - let output = norm(x.asType(Float.self)).asType(x.dtype) - return (1 + weight) * output + return MLXFast.rmsNorm(x, weight: 1.0 + self.weight, eps: self.eps) } } diff --git a/Libraries/LLM/Phi.swift b/Libraries/LLM/Phi.swift index 4c9763b..31facd5 100644 --- a/Libraries/LLM/Phi.swift +++ b/Libraries/LLM/Phi.swift @@ -7,12 +7,6 @@ import MLXNN // https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/phi.py -private class LayerNorm: MLXNN.LayerNorm { - override func callAsFunction(_ x: MLXArray) -> MLXArray { - super.callAsFunction(x.asType(Float.self)).asType(x.dtype) - } -} - private class PhiAttention: Module { let args: PhiConfiguration diff --git a/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved index 6819684..8f53509 100644 --- a/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved +++ b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved @@ -1,5 +1,5 @@ { - "originHash" : "73a943caf561dd1482e57053ba01456f25dea3267b2a84a996e42284e17aa6fc", + "originHash" : "da53546673b6d05016b6e5640c18814c7dba5b5af8db34715afe6d633037c758", "pins" : [ { "identity" : "gzipswift", @@ -16,7 +16,7 @@ "location" : "https://github.com/ml-explore/mlx-swift", "state" : { "branch" : "main", - "revision" : "24e71937e12efe01a0d28a429a703036fae2ff8a" + "revision" : "5e51224ac869366017859dc0b07f6d2dc51b3bae" } }, { @@ -83,5 +83,5 @@ } } ], - "version" : 2 + "version" : 3 }