apply swift-format

This commit is contained in:
David Koski
2024-03-03 18:40:49 -08:00
parent 0f454999a4
commit dfc9f2fc01
3 changed files with 13 additions and 13 deletions

View File

@@ -42,7 +42,7 @@ private class Attention: Module {
let ropeScale: Float let ropeScale: Float
if let ropeScaling = args.ropeScaling, ropeScaling["type"] == .string("linear"), if let ropeScaling = args.ropeScaling, ropeScaling["type"] == .string("linear"),
let factor = ropeScaling["factor"] let factor = ropeScaling["factor"]
{ {
switch factor { switch factor {
case .string: case .string:
@@ -60,8 +60,8 @@ private class Attention: Module {
} }
public func callAsFunction( public func callAsFunction(
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil) -> (MLXArray, (MLXArray, MLXArray)) _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
{ ) -> (MLXArray, (MLXArray, MLXArray)) {
let (B, L) = (x.dim(0), x.dim(1)) let (B, L) = (x.dim(0), x.dim(1))
var queries = wq(x) var queries = wq(x)
@@ -134,8 +134,8 @@ private class TransformerBlock: Module {
} }
public func callAsFunction( public func callAsFunction(
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil) -> (MLXArray, (MLXArray, MLXArray)) _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
{ ) -> (MLXArray, (MLXArray, MLXArray)) {
var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache) var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
let h = x + r let h = x + r
r = mlp(postAttentionLayerNorm(h)) r = mlp(postAttentionLayerNorm(h))
@@ -164,8 +164,8 @@ public class Qwen2ModelInner: Module {
} }
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> ( public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
MLXArray, [(MLXArray, MLXArray)]) MLXArray, [(MLXArray, MLXArray)]
{ ) {
var h = embedTokens(inputs) var h = embedTokens(inputs)
var mask: MLXArray? = nil var mask: MLXArray? = nil
@@ -199,8 +199,8 @@ public class Qwen2Model: Module, LLMModel {
} }
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> ( public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
MLXArray, [(MLXArray, MLXArray)]) MLXArray, [(MLXArray, MLXArray)]
{ ) {
let (out, cache) = model(inputs, cache: cache) let (out, cache) = model(inputs, cache: cache)
return (lmHead(out), cache) return (lmHead(out), cache)
} }