apply swift-format

This commit is contained in:
David Koski
2024-03-03 18:40:49 -08:00
parent 0f454999a4
commit dfc9f2fc01
3 changed files with 13 additions and 13 deletions

View File

@@ -84,7 +84,7 @@ extension ModelConfiguration {
) { prompt in ) { prompt in
"<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model" "<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
} }
public static let qwen205b4bit = ModelConfiguration( public static let qwen205b4bit = ModelConfiguration(
id: "mlx-community/Qwen1.5-0.5B-Chat-4bit", id: "mlx-community/Qwen1.5-0.5B-Chat-4bit",
overrideTokenizer: "PreTrainedTokenizer" overrideTokenizer: "PreTrainedTokenizer"

View File

@@ -42,7 +42,7 @@ private class Attention: Module {
let ropeScale: Float let ropeScale: Float
if let ropeScaling = args.ropeScaling, ropeScaling["type"] == .string("linear"), if let ropeScaling = args.ropeScaling, ropeScaling["type"] == .string("linear"),
let factor = ropeScaling["factor"] let factor = ropeScaling["factor"]
{ {
switch factor { switch factor {
case .string: case .string:
@@ -60,8 +60,8 @@ private class Attention: Module {
} }
public func callAsFunction( public func callAsFunction(
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil) -> (MLXArray, (MLXArray, MLXArray)) _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
{ ) -> (MLXArray, (MLXArray, MLXArray)) {
let (B, L) = (x.dim(0), x.dim(1)) let (B, L) = (x.dim(0), x.dim(1))
var queries = wq(x) var queries = wq(x)
@@ -134,8 +134,8 @@ private class TransformerBlock: Module {
} }
public func callAsFunction( public func callAsFunction(
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil) -> (MLXArray, (MLXArray, MLXArray)) _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
{ ) -> (MLXArray, (MLXArray, MLXArray)) {
var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache) var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
let h = x + r let h = x + r
r = mlp(postAttentionLayerNorm(h)) r = mlp(postAttentionLayerNorm(h))
@@ -164,8 +164,8 @@ public class Qwen2ModelInner: Module {
} }
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> ( public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
MLXArray, [(MLXArray, MLXArray)]) MLXArray, [(MLXArray, MLXArray)]
{ ) {
var h = embedTokens(inputs) var h = embedTokens(inputs)
var mask: MLXArray? = nil var mask: MLXArray? = nil
@@ -199,8 +199,8 @@ public class Qwen2Model: Module, LLMModel {
} }
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> ( public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
MLXArray, [(MLXArray, MLXArray)]) MLXArray, [(MLXArray, MLXArray)]
{ ) {
let (out, cache) = model(inputs, cache: cache) let (out, cache) = model(inputs, cache: cache)
return (lmHead(out), cache) return (lmHead(out), cache)
} }

View File

@@ -42,9 +42,9 @@ struct SyncGenerator: AsyncParsableCommand {
let modelConfiguration = ModelConfiguration.configuration(id: model) let modelConfiguration = ModelConfiguration.configuration(id: model)
let (model, tokenizer) = try await load(configuration: modelConfiguration) let (model, tokenizer) = try await load(configuration: modelConfiguration)
print("Model loaded -> \(self.model)") print("Model loaded -> \(self.model)")
let prompt = modelConfiguration.prepare(prompt: self.prompt) let prompt = modelConfiguration.prepare(prompt: self.prompt)
let promptTokens = tokenizer.encode(text: prompt) let promptTokens = tokenizer.encode(text: prompt)
@@ -133,7 +133,7 @@ struct AsyncGenerator: AsyncParsableCommand {
let modelConfiguration = ModelConfiguration.configuration(id: model) let modelConfiguration = ModelConfiguration.configuration(id: model)
let (model, tokenizer) = try await load(configuration: modelConfiguration) let (model, tokenizer) = try await load(configuration: modelConfiguration)
print("Model loaded -> \(self.model)") print("Model loaded -> \(self.model)")
let prompt = modelConfiguration.prepare(prompt: self.prompt) let prompt = modelConfiguration.prepare(prompt: self.prompt)