From dfc9f2fc019e8857836da28817325330222266f5 Mon Sep 17 00:00:00 2001
From: David Koski
Date: Sun, 3 Mar 2024 18:40:49 -0800
Subject: [PATCH] apply swift-format

---
 Libraries/LLM/Models.swift   |  2 +-
 Libraries/LLM/Qwen2.swift    | 18 +++++++++---------
 Tools/llm-tool/LLMTool.swift |  6 +++---
 3 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/Libraries/LLM/Models.swift b/Libraries/LLM/Models.swift
index 309cf85..64c515e 100644
--- a/Libraries/LLM/Models.swift
+++ b/Libraries/LLM/Models.swift
@@ -84,7 +84,7 @@ extension ModelConfiguration {
     ) { prompt in
         "<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
     }
-    
+
     public static let qwen205b4bit = ModelConfiguration(
         id: "mlx-community/Qwen1.5-0.5B-Chat-4bit",
         overrideTokenizer: "PreTrainedTokenizer"
diff --git a/Libraries/LLM/Qwen2.swift b/Libraries/LLM/Qwen2.swift
index 82bfb83..5d627b0 100644
--- a/Libraries/LLM/Qwen2.swift
+++ b/Libraries/LLM/Qwen2.swift
@@ -42,7 +42,7 @@ private class Attention: Module {
 
         let ropeScale: Float
         if let ropeScaling = args.ropeScaling, ropeScaling["type"] == .string("linear"),
-           let factor = ropeScaling["factor"]
+            let factor = ropeScaling["factor"]
         {
             switch factor {
             case .string:
@@ -60,8 +60,8 @@ private class Attention: Module {
     }
 
     public func callAsFunction(
-        _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil) -> (MLXArray, (MLXArray, MLXArray))
-    {
+        _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
+    ) -> (MLXArray, (MLXArray, MLXArray)) {
         let (B, L) = (x.dim(0), x.dim(1))
 
         var queries = wq(x)
@@ -134,8 +134,8 @@ private class TransformerBlock: Module {
     }
 
     public func callAsFunction(
-        _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil) -> (MLXArray, (MLXArray, MLXArray))
-    {
+        _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
+    ) -> (MLXArray, (MLXArray, MLXArray)) {
         var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
         let h = x + r
         r = mlp(postAttentionLayerNorm(h))
@@ -164,8 +164,8 @@ public class Qwen2ModelInner: Module {
     }
 
     public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
-        MLXArray, [(MLXArray, MLXArray)])
-    {
+        MLXArray, [(MLXArray, MLXArray)]
+    ) {
         var h = embedTokens(inputs)
 
         var mask: MLXArray? = nil
@@ -199,8 +199,8 @@ public class Qwen2Model: Module, LLMModel {
     }
 
     public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
-        MLXArray, [(MLXArray, MLXArray)])
-    {
+        MLXArray, [(MLXArray, MLXArray)]
+    ) {
         let (out, cache) = model(inputs, cache: cache)
         return (lmHead(out), cache)
     }
diff --git a/Tools/llm-tool/LLMTool.swift b/Tools/llm-tool/LLMTool.swift
index e0035e2..61f130e 100644
--- a/Tools/llm-tool/LLMTool.swift
+++ b/Tools/llm-tool/LLMTool.swift
@@ -42,9 +42,9 @@ struct SyncGenerator: AsyncParsableCommand {
 
         let modelConfiguration = ModelConfiguration.configuration(id: model)
         let (model, tokenizer) = try await load(configuration: modelConfiguration)
-        
+
         print("Model loaded -> \(self.model)")
-        
+
         let prompt = modelConfiguration.prepare(prompt: self.prompt)
         let promptTokens = tokenizer.encode(text: prompt)
 
@@ -133,7 +133,7 @@ struct AsyncGenerator: AsyncParsableCommand {
 
         let modelConfiguration = ModelConfiguration.configuration(id: model)
         let (model, tokenizer) = try await load(configuration: modelConfiguration)
-        
+
         print("Model loaded -> \(self.model)")
 
         let prompt = modelConfiguration.prepare(prompt: self.prompt)
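
Usage note: changes of this shape (trailing whitespace stripped, long function signatures rewrapped so the closing parenthesis and return type sit on their own line ahead of the opening brace) are what the apple/swift-format tool produces. A minimal sketch of how a patch like this could be regenerated, assuming swift-format is installed and picks up the repository's configuration from a .swift-format file (the patch itself does not show that configuration):

    # Sketch: rewrite the touched sources in place, recursing into directories.
    swift-format format --in-place --recursive Libraries Tools

    # Sketch: check formatting without modifying files (e.g. in CI).
    swift-format lint --recursive Libraries Tools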