swift-format!
This commit is contained in:
@@ -14,7 +14,8 @@ struct LLMError: Error {
|
|||||||
|
|
||||||
/// Load and return the model and tokenizer
|
/// Load and return the model and tokenizer
|
||||||
public func load(
|
public func load(
|
||||||
hub: HubApi = HubApi(), configuration: ModelConfiguration, progressHandler: @escaping (Progress) -> Void = { _ in }
|
hub: HubApi = HubApi(), configuration: ModelConfiguration,
|
||||||
|
progressHandler: @escaping (Progress) -> Void = { _ in }
|
||||||
) async throws -> (LLMModel, Tokenizer) {
|
) async throws -> (LLMModel, Tokenizer) {
|
||||||
// note: this doesn't have a way to pass the HubApi
|
// note: this doesn't have a way to pass the HubApi
|
||||||
let tokenizer = try await loadTokenizer(configuration: configuration)
|
let tokenizer = try await loadTokenizer(configuration: configuration)
|
||||||
|
|||||||
@@ -10,35 +10,37 @@ import Foundation
|
|||||||
/// implementation, if needed.
|
/// implementation, if needed.
|
||||||
public struct ModelConfiguration {
|
public struct ModelConfiguration {
|
||||||
public let id: String
|
public let id: String
|
||||||
|
|
||||||
/// overrides for TokenizerModel/knownTokenizers -- useful before swift-transformers is updated
|
/// overrides for TokenizerModel/knownTokenizers -- useful before swift-transformers is updated
|
||||||
public let overrideTokenizer: String?
|
public let overrideTokenizer: String?
|
||||||
|
|
||||||
/// custom preparation logic for the prompt. custom tokenizers provide more capability, but this
|
/// custom preparation logic for the prompt. custom tokenizers provide more capability, but this
|
||||||
/// allows some minor formatting changes, e.g. wrapping the user input in the expected prompt
|
/// allows some minor formatting changes, e.g. wrapping the user input in the expected prompt
|
||||||
/// format
|
/// format
|
||||||
private let preparePrompt: ((String) -> String)?
|
private let preparePrompt: ((String) -> String)?
|
||||||
|
|
||||||
public init(id: String, overrideTokenizer: String? = nil, preparePrompt: ((String) -> String)? = nil) {
|
public init(
|
||||||
|
id: String, overrideTokenizer: String? = nil, preparePrompt: ((String) -> String)? = nil
|
||||||
|
) {
|
||||||
self.id = id
|
self.id = id
|
||||||
self.overrideTokenizer = overrideTokenizer
|
self.overrideTokenizer = overrideTokenizer
|
||||||
self.preparePrompt = preparePrompt
|
self.preparePrompt = preparePrompt
|
||||||
}
|
}
|
||||||
|
|
||||||
public func prepare(prompt: String) -> String {
|
public func prepare(prompt: String) -> String {
|
||||||
preparePrompt?(prompt) ?? prompt
|
preparePrompt?(prompt) ?? prompt
|
||||||
}
|
}
|
||||||
|
|
||||||
public static var registry = [String:ModelConfiguration]()
|
public static var registry = [String: ModelConfiguration]()
|
||||||
|
|
||||||
public static func register(configurations: [ModelConfiguration]) {
|
public static func register(configurations: [ModelConfiguration]) {
|
||||||
bootstrap()
|
bootstrap()
|
||||||
|
|
||||||
for c in configurations {
|
for c in configurations {
|
||||||
registry[c.id] = c
|
registry[c.id] = c
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static func configuration(id: String) -> ModelConfiguration {
|
public static func configuration(id: String) -> ModelConfiguration {
|
||||||
bootstrap()
|
bootstrap()
|
||||||
|
|
||||||
@@ -51,40 +53,39 @@ public struct ModelConfiguration {
|
|||||||
}
|
}
|
||||||
|
|
||||||
extension ModelConfiguration {
|
extension ModelConfiguration {
|
||||||
|
|
||||||
static let mistral7B4bit = ModelConfiguration(id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx")
|
static let mistral7B4bit = ModelConfiguration(id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx")
|
||||||
|
|
||||||
static let codeLlama13b4bit = ModelConfiguration(
|
static let codeLlama13b4bit = ModelConfiguration(
|
||||||
id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX",
|
id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX",
|
||||||
overrideTokenizer: "PreTrainedTokenizer")
|
overrideTokenizer: "PreTrainedTokenizer"
|
||||||
{ prompt in
|
) { prompt in
|
||||||
// given the prompt: func sortArray(_ array: [Int]) -> String { <FILL_ME> }
|
// given the prompt: func sortArray(_ array: [Int]) -> String { <FILL_ME> }
|
||||||
// the python code produces this (via its custom tokenizer):
|
// the python code produces this (via its custom tokenizer):
|
||||||
// <PRE> func sortArray(_ array: [Int]) -> String { <SUF> } <MID>
|
// <PRE> func sortArray(_ array: [Int]) -> String { <SUF> } <MID>
|
||||||
|
|
||||||
"<PRE> " +
|
"<PRE> " + prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") + " <MID>"
|
||||||
prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") +
|
|
||||||
" <MID>"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static let phi4bit = ModelConfiguration(id: "mlx-community/phi-2-hf-4bit-mlx") { prompt in
|
static let phi4bit = ModelConfiguration(id: "mlx-community/phi-2-hf-4bit-mlx") { prompt in
|
||||||
"Instruct: \(prompt). Output: "
|
"Instruct: \(prompt). Output: "
|
||||||
}
|
}
|
||||||
|
|
||||||
static let gemma2bQuantized = ModelConfiguration(
|
static let gemma2bQuantized = ModelConfiguration(
|
||||||
id: "mlx-community/quantized-gemma-2b-it",
|
id: "mlx-community/quantized-gemma-2b-it",
|
||||||
overrideTokenizer: "PreTrainedTokenizer") { prompt in
|
overrideTokenizer: "PreTrainedTokenizer"
|
||||||
"<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
|
) { prompt in
|
||||||
}
|
"<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
|
||||||
|
}
|
||||||
|
|
||||||
private enum BootstrapState {
|
private enum BootstrapState {
|
||||||
case idle
|
case idle
|
||||||
case bootstrapping
|
case bootstrapping
|
||||||
case bootstrapped
|
case bootstrapped
|
||||||
}
|
}
|
||||||
|
|
||||||
static private var bootstrapState = BootstrapState.idle
|
static private var bootstrapState = BootstrapState.idle
|
||||||
|
|
||||||
static func bootstrap() {
|
static func bootstrap() {
|
||||||
switch bootstrapState {
|
switch bootstrapState {
|
||||||
case .idle:
|
case .idle:
|
||||||
@@ -99,7 +100,7 @@ extension ModelConfiguration {
|
|||||||
|
|
||||||
case .bootstrapping:
|
case .bootstrapping:
|
||||||
break
|
break
|
||||||
|
|
||||||
case .bootstrapped:
|
case .bootstrapped:
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ struct SyncGenerator: AsyncParsableCommand {
|
|||||||
|
|
||||||
let modelConfiguration = ModelConfiguration.configuration(id: model)
|
let modelConfiguration = ModelConfiguration.configuration(id: model)
|
||||||
let (model, tokenizer) = try await load(configuration: modelConfiguration)
|
let (model, tokenizer) = try await load(configuration: modelConfiguration)
|
||||||
|
|
||||||
let prompt = modelConfiguration.prepare(prompt: self.prompt)
|
let prompt = modelConfiguration.prepare(prompt: self.prompt)
|
||||||
let promptTokens = tokenizer.encode(text: prompt)
|
let promptTokens = tokenizer.encode(text: prompt)
|
||||||
|
|
||||||
@@ -57,7 +57,8 @@ struct SyncGenerator: AsyncParsableCommand {
|
|||||||
var tokens = [Int]()
|
var tokens = [Int]()
|
||||||
var printed = 0
|
var printed = 0
|
||||||
|
|
||||||
for token in TokenIterator(prompt: MLXArray(promptTokens), model: model, temp: temperature) {
|
for token in TokenIterator(prompt: MLXArray(promptTokens), model: model, temp: temperature)
|
||||||
|
{
|
||||||
if tokens.isEmpty {
|
if tokens.isEmpty {
|
||||||
eval(token)
|
eval(token)
|
||||||
let now = Date.timeIntervalSinceReferenceDate
|
let now = Date.timeIntervalSinceReferenceDate
|
||||||
@@ -130,7 +131,7 @@ struct AsyncGenerator: AsyncParsableCommand {
|
|||||||
|
|
||||||
let modelConfiguration = ModelConfiguration.configuration(id: model)
|
let modelConfiguration = ModelConfiguration.configuration(id: model)
|
||||||
let (model, tokenizer) = try await load(configuration: modelConfiguration)
|
let (model, tokenizer) = try await load(configuration: modelConfiguration)
|
||||||
|
|
||||||
let prompt = modelConfiguration.prepare(prompt: self.prompt)
|
let prompt = modelConfiguration.prepare(prompt: self.prompt)
|
||||||
let promptTokens = tokenizer.encode(text: prompt)
|
let promptTokens = tokenizer.encode(text: prompt)
|
||||||
|
|
||||||
@@ -145,7 +146,8 @@ struct AsyncGenerator: AsyncParsableCommand {
|
|||||||
var tokens = [Int]()
|
var tokens = [Int]()
|
||||||
var printed = 0
|
var printed = 0
|
||||||
|
|
||||||
let (task, channel) = generate(prompt: MLXArray(promptTokens), model: model, temp: temperature)
|
let (task, channel) = generate(
|
||||||
|
prompt: MLXArray(promptTokens), model: model, temp: temperature)
|
||||||
|
|
||||||
for await token in channel {
|
for await token in channel {
|
||||||
if tokens.isEmpty {
|
if tokens.isEmpty {
|
||||||
|
|||||||
Reference in New Issue
Block a user