apply swift-format
This commit is contained in:
@@ -60,8 +60,8 @@ private class Attention: Module {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public func callAsFunction(
|
public func callAsFunction(
|
||||||
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil) -> (MLXArray, (MLXArray, MLXArray))
|
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
|
||||||
{
|
) -> (MLXArray, (MLXArray, MLXArray)) {
|
||||||
let (B, L) = (x.dim(0), x.dim(1))
|
let (B, L) = (x.dim(0), x.dim(1))
|
||||||
|
|
||||||
var queries = wq(x)
|
var queries = wq(x)
|
||||||
@@ -134,8 +134,8 @@ private class TransformerBlock: Module {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public func callAsFunction(
|
public func callAsFunction(
|
||||||
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil) -> (MLXArray, (MLXArray, MLXArray))
|
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
|
||||||
{
|
) -> (MLXArray, (MLXArray, MLXArray)) {
|
||||||
var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
|
var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
|
||||||
let h = x + r
|
let h = x + r
|
||||||
r = mlp(postAttentionLayerNorm(h))
|
r = mlp(postAttentionLayerNorm(h))
|
||||||
@@ -164,8 +164,8 @@ public class Qwen2ModelInner: Module {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
|
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
|
||||||
MLXArray, [(MLXArray, MLXArray)])
|
MLXArray, [(MLXArray, MLXArray)]
|
||||||
{
|
) {
|
||||||
var h = embedTokens(inputs)
|
var h = embedTokens(inputs)
|
||||||
|
|
||||||
var mask: MLXArray? = nil
|
var mask: MLXArray? = nil
|
||||||
@@ -199,8 +199,8 @@ public class Qwen2Model: Module, LLMModel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
|
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
|
||||||
MLXArray, [(MLXArray, MLXArray)])
|
MLXArray, [(MLXArray, MLXArray)]
|
||||||
{
|
) {
|
||||||
let (out, cache) = model(inputs, cache: cache)
|
let (out, cache) = model(inputs, cache: cache)
|
||||||
return (lmHead(out), cache)
|
return (lmHead(out), cache)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user