apply swift-format
This commit is contained in:
@@ -84,7 +84,7 @@ extension ModelConfiguration {
|
|||||||
) { prompt in
|
) { prompt in
|
||||||
"<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
|
"<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
|
||||||
}
|
}
|
||||||
|
|
||||||
public static let qwen205b4bit = ModelConfiguration(
|
public static let qwen205b4bit = ModelConfiguration(
|
||||||
id: "mlx-community/Qwen1.5-0.5B-Chat-4bit",
|
id: "mlx-community/Qwen1.5-0.5B-Chat-4bit",
|
||||||
overrideTokenizer: "PreTrainedTokenizer"
|
overrideTokenizer: "PreTrainedTokenizer"
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ private class Attention: Module {
|
|||||||
|
|
||||||
let ropeScale: Float
|
let ropeScale: Float
|
||||||
if let ropeScaling = args.ropeScaling, ropeScaling["type"] == .string("linear"),
|
if let ropeScaling = args.ropeScaling, ropeScaling["type"] == .string("linear"),
|
||||||
let factor = ropeScaling["factor"]
|
let factor = ropeScaling["factor"]
|
||||||
{
|
{
|
||||||
switch factor {
|
switch factor {
|
||||||
case .string:
|
case .string:
|
||||||
@@ -60,8 +60,8 @@ private class Attention: Module {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public func callAsFunction(
|
public func callAsFunction(
|
||||||
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil) -> (MLXArray, (MLXArray, MLXArray))
|
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
|
||||||
{
|
) -> (MLXArray, (MLXArray, MLXArray)) {
|
||||||
let (B, L) = (x.dim(0), x.dim(1))
|
let (B, L) = (x.dim(0), x.dim(1))
|
||||||
|
|
||||||
var queries = wq(x)
|
var queries = wq(x)
|
||||||
@@ -134,8 +134,8 @@ private class TransformerBlock: Module {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public func callAsFunction(
|
public func callAsFunction(
|
||||||
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil) -> (MLXArray, (MLXArray, MLXArray))
|
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
|
||||||
{
|
) -> (MLXArray, (MLXArray, MLXArray)) {
|
||||||
var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
|
var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
|
||||||
let h = x + r
|
let h = x + r
|
||||||
r = mlp(postAttentionLayerNorm(h))
|
r = mlp(postAttentionLayerNorm(h))
|
||||||
@@ -164,8 +164,8 @@ public class Qwen2ModelInner: Module {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
|
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
|
||||||
MLXArray, [(MLXArray, MLXArray)])
|
MLXArray, [(MLXArray, MLXArray)]
|
||||||
{
|
) {
|
||||||
var h = embedTokens(inputs)
|
var h = embedTokens(inputs)
|
||||||
|
|
||||||
var mask: MLXArray? = nil
|
var mask: MLXArray? = nil
|
||||||
@@ -199,8 +199,8 @@ public class Qwen2Model: Module, LLMModel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
|
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
|
||||||
MLXArray, [(MLXArray, MLXArray)])
|
MLXArray, [(MLXArray, MLXArray)]
|
||||||
{
|
) {
|
||||||
let (out, cache) = model(inputs, cache: cache)
|
let (out, cache) = model(inputs, cache: cache)
|
||||||
return (lmHead(out), cache)
|
return (lmHead(out), cache)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,9 +42,9 @@ struct SyncGenerator: AsyncParsableCommand {
|
|||||||
|
|
||||||
let modelConfiguration = ModelConfiguration.configuration(id: model)
|
let modelConfiguration = ModelConfiguration.configuration(id: model)
|
||||||
let (model, tokenizer) = try await load(configuration: modelConfiguration)
|
let (model, tokenizer) = try await load(configuration: modelConfiguration)
|
||||||
|
|
||||||
print("Model loaded -> \(self.model)")
|
print("Model loaded -> \(self.model)")
|
||||||
|
|
||||||
let prompt = modelConfiguration.prepare(prompt: self.prompt)
|
let prompt = modelConfiguration.prepare(prompt: self.prompt)
|
||||||
let promptTokens = tokenizer.encode(text: prompt)
|
let promptTokens = tokenizer.encode(text: prompt)
|
||||||
|
|
||||||
@@ -133,7 +133,7 @@ struct AsyncGenerator: AsyncParsableCommand {
|
|||||||
|
|
||||||
let modelConfiguration = ModelConfiguration.configuration(id: model)
|
let modelConfiguration = ModelConfiguration.configuration(id: model)
|
||||||
let (model, tokenizer) = try await load(configuration: modelConfiguration)
|
let (model, tokenizer) = try await load(configuration: modelConfiguration)
|
||||||
|
|
||||||
print("Model loaded -> \(self.model)")
|
print("Model loaded -> \(self.model)")
|
||||||
|
|
||||||
let prompt = modelConfiguration.prepare(prompt: self.prompt)
|
let prompt = modelConfiguration.prepare(prompt: self.prompt)
|
||||||
|
|||||||
Reference in New Issue
Block a user