Merge pull request #11 from maiqingqiang/feat-qwen2
feat: Qwen2 support
@@ -31,6 +31,7 @@ public enum ModelType: String, Codable {
     case llama
     case phi
     case gemma
+    case qwen2
 
     func createModel(configuration: URL) throws -> LLMModel {
         switch self {
@@ -46,6 +47,10 @@ public enum ModelType: String, Codable {
             let configuration = try JSONDecoder().decode(
                 GemmaConfiguration.self, from: Data(contentsOf: configuration))
             return GemmaModel(configuration)
+        case .qwen2:
+            let configuration = try JSONDecoder().decode(
+                Qwen2Configuration.self, from: Data(contentsOf: configuration))
+            return Qwen2Model(configuration)
         }
     }
 }
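For context, a minimal sketch of how this factory is exercised; the local config.json path is hypothetical, while `ModelType.qwen2` and `createModel(configuration:)` come from the hunk above:

    import Foundation

    // hypothetical path to a downloaded checkpoint's config.json
    let configURL = URL(fileURLWithPath: "/models/Qwen1.5-0.5B-Chat-4bit/config.json")
    let model = try ModelType.qwen2.createModel(configuration: configURL)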
@@ -84,6 +84,13 @@ extension ModelConfiguration {
     ) { prompt in
         "<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
     }
 
+    public static let qwen205b4bit = ModelConfiguration(
+        id: "mlx-community/Qwen1.5-0.5B-Chat-4bit",
+        overrideTokenizer: "PreTrainedTokenizer"
+    ) { prompt in
+        "<|im_start|>user \(prompt)<|im_end|><|im_start|>assistant"
+    }
+
     private enum BootstrapState {
         case idle
@@ -102,6 +109,7 @@ extension ModelConfiguration {
                 codeLlama13b4bit,
                 phi4bit,
                 gemma2bQuantized,
+                qwen205b4bit,
             ])
             bootstrapState = .bootstrapped
 
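The closure is the ChatML template that Qwen1.5 chat checkpoints expect. A minimal sketch of the rendered prompt, assuming `prepare(prompt:)` (used by the command-line tools later in this diff) applies this closure:

    let configuration = ModelConfiguration.qwen205b4bit
    let prepared = configuration.prepare(prompt: "What is MLX?")
    // "<|im_start|>user What is MLX?<|im_end|><|im_start|>assistant"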
Libraries/LLM/Qwen2.swift (new file, 263 lines)
@@ -0,0 +1,263 @@
+//
+//  Qwen2.swift
+//  LLM
+//
+//  Created by John Mai on 2024/3/3.
+//
+
+import Foundation
+import MLX
+import MLXNN
+
+// port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/qwen2.py
+
+private class Attention: Module {
+    let args: Qwen2Configuration
+    let repeats: Int
+    let scale: Float
+
+    @ModuleInfo(key: "q_proj") var wq: Linear
+    @ModuleInfo(key: "k_proj") var wk: Linear
+    @ModuleInfo(key: "v_proj") var wv: Linear
+    @ModuleInfo(key: "o_proj") var wo: Linear
+
+    let rope: RoPE
+
+    public init(_ args: Qwen2Configuration) {
+        self.args = args
+
+        let dim = args.hiddenSize
+        let heads = args.attentionHeads
+        let kvHeads = args.kvHeads
+
+        self.repeats = heads / kvHeads
+
+        let headDim = args.hiddenSize / heads
+        self.scale = pow(Float(headDim), -0.5)
+
+        _wq.wrappedValue = Linear(dim, heads * headDim, bias: true)
+        _wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: true)
+        _wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: true)
+        _wo.wrappedValue = Linear(heads * headDim, dim, bias: false)
+
+        let ropeScale: Float
+        if let ropeScaling = args.ropeScaling, ropeScaling["type"] == .string("linear"),
+            let factor = ropeScaling["factor"]
+        {
+            switch factor {
+            case .string:
+                fatalError("ropeScaling.factor must be a float")
+            case .float(let v):
+                ropeScale = 1 / v
+            }
+        } else {
+            ropeScale = 1
+        }
+
+        self.rope = RoPE(
+            dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta,
+            scale: ropeScale)
+    }
+
+    public func callAsFunction(
+        _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil) -> (MLXArray, (MLXArray, MLXArray))
+    {
+        let (B, L) = (x.dim(0), x.dim(1))
+
+        var queries = wq(x)
+        var keys = wk(x)
+        var values = wv(x)
+
+        // prepare the queries, keys and values for the attention computation
+        queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3)
+        keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
+        values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
+
+        if repeats > 1 {
+            keys = MLXArray.repeat(keys, count: repeats, axis: 1)
+            values = MLXArray.repeat(values, count: repeats, axis: 1)
+        }
+
+        if let (keyCache, valueCache) = cache {
+            queries = rope(queries, offset: keyCache.dim(2))
+            keys = rope(keys, offset: keyCache.dim(2))
+            keys = concatenated([keyCache, keys], axis: 2)
+            values = concatenated([valueCache, values], axis: 2)
+        } else {
+            queries = rope(queries)
+            keys = rope(keys)
+        }
+
+        var scores = (queries * scale).matmul(keys.transposed(0, 1, 3, 2))
+        if let mask {
+            scores = scores + mask
+        }
+
+        scores = softMax(scores.asType(.float32), axis: -1).asType(scores.dtype)
+
+        let output = matmul(scores, values).transposed(0, 2, 1, 3).reshaped(B, L, -1)
+
+        return (wo(output), (keys, values))
+    }
+}
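The `repeats` handling above is grouped-query attention: each key/value head is tiled across `heads / kvHeads` query heads before scores are computed. A standalone shape sketch with hypothetical sizes, assuming MLX's `zeros(_:)` factory:

    // B = 1, kvHeads = 2, L = 10, headDim = 64; with 8 query heads, repeats = 4
    let keys = zeros([1, 2, 10, 64])
    let expanded = MLXArray.repeat(keys, count: 4, axis: 1)
    // expanded.shape == [1, 8, 10, 64]: one K/V head per query head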
+
+private class MLP: Module, UnaryLayer {
+    @ModuleInfo(key: "gate_proj") var gate: Linear
+    @ModuleInfo(key: "down_proj") var down: Linear
+    @ModuleInfo(key: "up_proj") var up: Linear
+
+    public init(dimensions: Int, hiddenDimensions: Int) {
+        _gate.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
+        _down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false)
+        _up.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
+    }
+
+    public func callAsFunction(_ x: MLXArray) -> MLXArray {
+        down(silu(gate(x)) * up(x))
+    }
+}
+
+private class TransformerBlock: Module {
+    @ModuleInfo(key: "self_attn") var attention: Attention
+    let mlp: MLP
+
+    @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
+    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
+
+    public init(_ args: Qwen2Configuration) {
+        _attention.wrappedValue = Attention(args)
+        self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
+        _inputLayerNorm.wrappedValue = RMSNorm(
+            dimensions: args.hiddenSize, eps: args.rmsNormEps)
+        _postAttentionLayerNorm.wrappedValue = RMSNorm(
+            dimensions: args.hiddenSize, eps: args.rmsNormEps)
+    }
+
+    public func callAsFunction(
+        _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil) -> (MLXArray, (MLXArray, MLXArray))
+    {
+        var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
+        let h = x + r
+        r = mlp(postAttentionLayerNorm(h))
+        let out = h + r
+        return (out, cache)
+    }
+}
+
+public class Qwen2ModelInner: Module {
+    @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
+
+    fileprivate let layers: [TransformerBlock]
+    let norm: RMSNorm
+
+    public init(_ args: Qwen2Configuration) {
+        precondition(args.vocabularySize > 0)
+
+        _embedTokens.wrappedValue = Embedding(
+            embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
+
+        self.layers = (0 ..< args.hiddenLayers)
+            .map { _ in
+                TransformerBlock(args)
+            }
+        self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
+    }
+
+    public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
+        MLXArray, [(MLXArray, MLXArray)])
+    {
+        var h = embedTokens(inputs)
+
+        var mask: MLXArray? = nil
+        if h.dim(1) > 1 {
+            mask = MultiHeadAttention.createAdditiveCausalMask(h.dim(1))
+            mask = mask?.asType(h.dtype)
+        }
+
+        var newCache = [(MLXArray, MLXArray)]()
+
+        for (i, layer) in layers.enumerated() {
+            var cacheUpdate: (MLXArray, MLXArray)
+            (h, cacheUpdate) = layer(h, mask: mask, cache: cache?[i])
+            newCache.append(cacheUpdate)
+        }
+
+        return (norm(h), newCache)
+    }
+}
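The causal mask is built only when the input carries more than one token: prompt evaluation must block every position from attending forward, while single-token decode steps attend over the whole cache and need no mask. A minimal sketch of the helper used above:

    let mask = MultiHeadAttention.createAdditiveCausalMask(4)
    // a (4, 4) additive mask: zeros on and below the diagonal, large negative
    // values above it, so softMax zeroes out attention to future positions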
+
+public class Qwen2Model: Module, LLMModel {
+    public let vocabularySize: Int
+    let model: Qwen2ModelInner
+
+    @ModuleInfo(key: "lm_head") var lmHead: Linear
+
+    public init(_ args: Qwen2Configuration) {
+        self.vocabularySize = args.vocabularySize
+        self.model = Qwen2ModelInner(args)
+        _lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
+    }
+
+    public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
+        MLXArray, [(MLXArray, MLXArray)])
+    {
+        let (out, cache) = model(inputs, cache: cache)
+        return (lmHead(out), cache)
+    }
+}
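A minimal sketch of driving the model directly, assuming `model` is a loaded `Qwen2Model`, hypothetical token ids, and greedy `argMax` in place of the repo's sampling code:

    let promptTokens = MLXArray([151644, 872, 1234]).reshaped(1, 3)
    var (logits, cache) = model(promptTokens, cache: nil)
    // greedy next-token pick from the last position's logits
    let next = argMax(logits[0, logits.dim(1) - 1], axis: -1)
    // later steps feed one token at a time, reusing the growing KV cache
    (logits, cache) = model(next.reshaped(1, 1), cache: cache)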
+
+public struct Qwen2Configuration: Codable {
+    var hiddenSize: Int
+    var hiddenLayers: Int
+    var intermediateSize: Int
+    var attentionHeads: Int
+    var rmsNormEps: Float
+    var vocabularySize: Int
+    var kvHeads: Int
+    var ropeTheta: Float = 1_000_000
+    var ropeTraditional: Bool = false
+    var ropeScaling: [String: StringOrNumber]? = nil
+
+    enum CodingKeys: String, CodingKey {
+        case hiddenSize = "hidden_size"
+        case hiddenLayers = "num_hidden_layers"
+        case intermediateSize = "intermediate_size"
+        case attentionHeads = "num_attention_heads"
+        case rmsNormEps = "rms_norm_eps"
+        case vocabularySize = "vocab_size"
+        case kvHeads = "num_key_value_heads"
+        case ropeTheta = "rope_theta"
+        case ropeTraditional = "rope_traditional"
+        case ropeScaling = "rope_scaling"
+    }
+
+    public init(from decoder: Decoder) throws {
+        // custom implementation to handle optional keys with required values
+        let container: KeyedDecodingContainer<Qwen2Configuration.CodingKeys> =
+            try decoder.container(
+                keyedBy: Qwen2Configuration.CodingKeys.self)
+
+        self.hiddenSize = try container.decode(
+            Int.self, forKey: Qwen2Configuration.CodingKeys.hiddenSize)
+        self.hiddenLayers = try container.decode(
+            Int.self, forKey: Qwen2Configuration.CodingKeys.hiddenLayers)
+        self.intermediateSize = try container.decode(
+            Int.self, forKey: Qwen2Configuration.CodingKeys.intermediateSize)
+        self.attentionHeads = try container.decode(
+            Int.self, forKey: Qwen2Configuration.CodingKeys.attentionHeads)
+        self.rmsNormEps = try container.decode(
+            Float.self, forKey: Qwen2Configuration.CodingKeys.rmsNormEps)
+        self.vocabularySize = try container.decode(
+            Int.self, forKey: Qwen2Configuration.CodingKeys.vocabularySize)
+        self.kvHeads = try container.decode(Int.self, forKey: Qwen2Configuration.CodingKeys.kvHeads)
+        self.ropeTheta =
+            try container.decodeIfPresent(
+                Float.self, forKey: Qwen2Configuration.CodingKeys.ropeTheta)
+            ?? 1_000_000
+        self.ropeTraditional =
+            try container.decodeIfPresent(
+                Bool.self, forKey: Qwen2Configuration.CodingKeys.ropeTraditional) ?? false
+        self.ropeScaling = try container.decodeIfPresent(
+            [String: StringOrNumber].self, forKey: Qwen2Configuration.CodingKeys.ropeScaling)
+    }
+}
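The hand-written `init(from:)` lets a checkpoint's config.json omit the RoPE keys: the required fields use `decode`, while `rope_theta`, `rope_traditional`, and `rope_scaling` fall back to the defaults via `decodeIfPresent`. A sketch with an abbreviated, hypothetical config:

    import Foundation

    // hypothetical, abbreviated config.json with no rope_* keys
    let json = """
        {"hidden_size": 1024, "num_hidden_layers": 24, "intermediate_size": 2816,
         "num_attention_heads": 16, "rms_norm_eps": 1e-6, "vocab_size": 151936,
         "num_key_value_heads": 16}
        """.data(using: .utf8)!
    let config = try JSONDecoder().decode(Qwen2Configuration.self, from: json)
    // config.ropeTheta == 1_000_000; config.ropeTraditional == false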
@@ -116,4 +116,5 @@ public func discardUnhandledMerges(tokenizerData: Config) -> Config {
 let replacementTokenizers = [
     "CodeLlamaTokenizer": "LlamaTokenizer",
     "GemmaTokenizer": "PreTrainedTokenizer",
+    "Qwen2Tokenizer": "PreTrainedTokenizer",
 ]
@@ -42,7 +42,9 @@ struct SyncGenerator: AsyncParsableCommand {
 
         let modelConfiguration = ModelConfiguration.configuration(id: model)
         let (model, tokenizer) = try await load(configuration: modelConfiguration)
+
+        print("Model loaded -> \(self.model)")
 
         let prompt = modelConfiguration.prepare(prompt: self.prompt)
         let promptTokens = tokenizer.encode(text: prompt)
 
@@ -131,6 +133,8 @@ struct AsyncGenerator: AsyncParsableCommand {
 
         let modelConfiguration = ModelConfiguration.configuration(id: model)
         let (model, tokenizer) = try await load(configuration: modelConfiguration)
+
+        print("Model loaded -> \(self.model)")
 
         let prompt = modelConfiguration.prepare(prompt: self.prompt)
         let promptTokens = tokenizer.encode(text: prompt)
@@ -7,6 +7,7 @@
 	objects = {
 
 /* Begin PBXBuildFile section */
+		52A776182B94B5EE00AA6E80 /* Qwen2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 52A776172B94B5EE00AA6E80 /* Qwen2.swift */; };
 		C3288D762B6D9313009FF608 /* LinearModelTraining.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3288D752B6D9313009FF608 /* LinearModelTraining.swift */; };
 		C3288D7B2B6D9339009FF608 /* ArgumentParser in Frameworks */ = {isa = PBXBuildFile; productRef = C3288D7A2B6D9339009FF608 /* ArgumentParser */; };
 		C34E48F52B696F0B00FCB841 /* LLMTool.swift in Sources */ = {isa = PBXBuildFile; fileRef = C34E48F42B696F0B00FCB841 /* LLMTool.swift */; };
@@ -180,6 +181,7 @@
 /* End PBXCopyFilesBuildPhase section */
 
 /* Begin PBXFileReference section */
+		52A776172B94B5EE00AA6E80 /* Qwen2.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Qwen2.swift; sourceTree = "<group>"; };
 		C325DE3F2B648CDB00628871 /* README.md */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = "<group>"; };
 		C3288D732B6D9313009FF608 /* LinearModelTraining */ = {isa = PBXFileReference; explicitFileType = "compiled.mach-o.executable"; includeInIndex = 0; path = LinearModelTraining; sourceTree = BUILT_PRODUCTS_DIR; };
 		C3288D752B6D9313009FF608 /* LinearModelTraining.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LinearModelTraining.swift; sourceTree = "<group>"; };
@@ -363,6 +365,7 @@
 				C34E48ED2B696E6500FCB841 /* Load.swift */,
 				C3E786AA2B8D1AEC0004D037 /* Evaluate.swift */,
 				C3E786AC2B8D4AF50004D037 /* Tokenizer.swift */,
+				52A776172B94B5EE00AA6E80 /* Qwen2.swift */,
 			);
 			path = LLM;
 			sourceTree = "<group>";
@@ -829,6 +832,7 @@
 				C3A8B3AC2B9283150002EFB8 /* Models.swift in Sources */,
 				C3E786AB2B8D1AEC0004D037 /* Evaluate.swift in Sources */,
 				C38935CC2B869C870037B833 /* Llama.swift in Sources */,
+				52A776182B94B5EE00AA6E80 /* Qwen2.swift in Sources */,
 			);
 			runOnlyForDeploymentPostprocessing = 0;
 		};
@@ -55,6 +55,10 @@
             argument = "--model mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX"
             isEnabled = "NO">
          </CommandLineArgument>
+         <CommandLineArgument
+            argument = "--model mlx-community/Qwen1.5-0.5B-Chat-4bit"
+            isEnabled = "YES">
+         </CommandLineArgument>
          <CommandLineArgument
             argument = "--prompt 'func sortArray(_ array: [Int]) -> String { <FILL_ME> }'"
             isEnabled = "NO">
@@ -69,7 +73,7 @@
          </CommandLineArgument>
          <CommandLineArgument
             argument = "--model mlx-community/phi-2-hf-4bit-mlx"
-            isEnabled = "YES">
+            isEnabled = "NO">
          </CommandLineArgument>
       </CommandLineArguments>
    </LaunchAction>