Use fast (#38)
* update to latest mlx swift and use fast norms * gpu usage -> memory usage
This commit is contained in:
@@ -85,7 +85,7 @@ struct ContentView: View {
|
|||||||
.toolbar {
|
.toolbar {
|
||||||
ToolbarItem {
|
ToolbarItem {
|
||||||
Label(
|
Label(
|
||||||
"GPU Usage: \(deviceStat.gpuUsage.activeMemory.formatted(.byteCount(style: .memory)))",
|
"Memory Usage: \(deviceStat.gpuUsage.activeMemory.formatted(.byteCount(style: .memory)))",
|
||||||
systemImage: "info.circle.fill"
|
systemImage: "info.circle.fill"
|
||||||
)
|
)
|
||||||
.labelStyle(.titleAndIcon)
|
.labelStyle(.titleAndIcon)
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import MLXNN
|
|||||||
|
|
||||||
// specialized norm for gemma
|
// specialized norm for gemma
|
||||||
private class RMSNorm: Module, UnaryLayer {
|
private class RMSNorm: Module, UnaryLayer {
|
||||||
|
|
||||||
let weight: MLXArray
|
let weight: MLXArray
|
||||||
let eps: Float
|
let eps: Float
|
||||||
|
|
||||||
@@ -19,13 +18,8 @@ private class RMSNorm: Module, UnaryLayer {
|
|||||||
super.init()
|
super.init()
|
||||||
}
|
}
|
||||||
|
|
||||||
func norm(_ x: MLXArray) -> MLXArray {
|
|
||||||
x * rsqrt(x.square().mean(axis: -1, keepDims: true) + eps)
|
|
||||||
}
|
|
||||||
|
|
||||||
public func callAsFunction(_ x: MLXArray) -> MLXArray {
|
public func callAsFunction(_ x: MLXArray) -> MLXArray {
|
||||||
let output = norm(x.asType(Float.self)).asType(x.dtype)
|
return MLXFast.rmsNorm(x, weight: 1.0 + self.weight, eps: self.eps)
|
||||||
return (1 + weight) * output
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,12 +7,6 @@ import MLXNN
|
|||||||
|
|
||||||
// https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/phi.py
|
// https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/phi.py
|
||||||
|
|
||||||
private class LayerNorm: MLXNN.LayerNorm {
|
|
||||||
override func callAsFunction(_ x: MLXArray) -> MLXArray {
|
|
||||||
super.callAsFunction(x.asType(Float.self)).asType(x.dtype)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private class PhiAttention: Module {
|
private class PhiAttention: Module {
|
||||||
|
|
||||||
let args: PhiConfiguration
|
let args: PhiConfiguration
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"originHash" : "73a943caf561dd1482e57053ba01456f25dea3267b2a84a996e42284e17aa6fc",
|
"originHash" : "da53546673b6d05016b6e5640c18814c7dba5b5af8db34715afe6d633037c758",
|
||||||
"pins" : [
|
"pins" : [
|
||||||
{
|
{
|
||||||
"identity" : "gzipswift",
|
"identity" : "gzipswift",
|
||||||
@@ -16,7 +16,7 @@
|
|||||||
"location" : "https://github.com/ml-explore/mlx-swift",
|
"location" : "https://github.com/ml-explore/mlx-swift",
|
||||||
"state" : {
|
"state" : {
|
||||||
"branch" : "main",
|
"branch" : "main",
|
||||||
"revision" : "24e71937e12efe01a0d28a429a703036fae2ff8a"
|
"revision" : "5e51224ac869366017859dc0b07f6d2dc51b3bae"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -83,5 +83,5 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"version" : 2
|
"version" : 3
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user