swift-format!

This commit is contained in:
David Koski
2024-03-01 14:47:43 -08:00
parent 82f6a969d4
commit 2157333905
3 changed files with 34 additions and 30 deletions

View File

@@ -14,7 +14,8 @@ struct LLMError: Error {
/// Load and return the model and tokenizer /// Load and return the model and tokenizer
public func load( public func load(
hub: HubApi = HubApi(), configuration: ModelConfiguration, progressHandler: @escaping (Progress) -> Void = { _ in } hub: HubApi = HubApi(), configuration: ModelConfiguration,
progressHandler: @escaping (Progress) -> Void = { _ in }
) async throws -> (LLMModel, Tokenizer) { ) async throws -> (LLMModel, Tokenizer) {
// note: this doesn't have a way to pass the HubApi // note: this doesn't have a way to pass the HubApi
let tokenizer = try await loadTokenizer(configuration: configuration) let tokenizer = try await loadTokenizer(configuration: configuration)

View File

@@ -10,35 +10,37 @@ import Foundation
/// implementation, if needed. /// implementation, if needed.
public struct ModelConfiguration { public struct ModelConfiguration {
public let id: String public let id: String
/// overrides for TokenizerModel/knownTokenizers -- useful before swift-transformers is updated /// overrides for TokenizerModel/knownTokenizers -- useful before swift-transformers is updated
public let overrideTokenizer: String? public let overrideTokenizer: String?
/// custom preparation logic for the prompt. custom tokenizers provide more capability, but this /// custom preparation logic for the prompt. custom tokenizers provide more capability, but this
/// allows some minor formtting changes, e.g. wrapping the user input in the expected prompt /// allows some minor formtting changes, e.g. wrapping the user input in the expected prompt
/// format /// format
private let preparePrompt: ((String) -> String)? private let preparePrompt: ((String) -> String)?
public init(id: String, overrideTokenizer: String? = nil, preparePrompt: ((String) -> String)? = nil) { public init(
id: String, overrideTokenizer: String? = nil, preparePrompt: ((String) -> String)? = nil
) {
self.id = id self.id = id
self.overrideTokenizer = overrideTokenizer self.overrideTokenizer = overrideTokenizer
self.preparePrompt = preparePrompt self.preparePrompt = preparePrompt
} }
public func prepare(prompt: String) -> String { public func prepare(prompt: String) -> String {
preparePrompt?(prompt) ?? prompt preparePrompt?(prompt) ?? prompt
} }
public static var registry = [String:ModelConfiguration]() public static var registry = [String: ModelConfiguration]()
public static func register(configurations: [ModelConfiguration]) { public static func register(configurations: [ModelConfiguration]) {
bootstrap() bootstrap()
for c in configurations { for c in configurations {
registry[c.id] = c registry[c.id] = c
} }
} }
public static func configuration(id: String) -> ModelConfiguration { public static func configuration(id: String) -> ModelConfiguration {
bootstrap() bootstrap()
@@ -51,40 +53,39 @@ public struct ModelConfiguration {
} }
extension ModelConfiguration { extension ModelConfiguration {
static let mistral7B4bit = ModelConfiguration(id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx") static let mistral7B4bit = ModelConfiguration(id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx")
static let codeLlama13b4bit = ModelConfiguration( static let codeLlama13b4bit = ModelConfiguration(
id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX", id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX",
overrideTokenizer: "PreTrainedTokenizer") overrideTokenizer: "PreTrainedTokenizer"
{ prompt in ) { prompt in
// given the prompt: func sortArray(_ array: [Int]) -> String { <FILL_ME> } // given the prompt: func sortArray(_ array: [Int]) -> String { <FILL_ME> }
// the python code produces this (via its custom tokenizer): // the python code produces this (via its custom tokenizer):
// <PRE> func sortArray(_ array: [Int]) -> String { <SUF> } <MID> // <PRE> func sortArray(_ array: [Int]) -> String { <SUF> } <MID>
"<PRE> " + "<PRE> " + prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") + " <MID>"
prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") +
" <MID>"
} }
static let phi4bit = ModelConfiguration(id: "mlx-community/phi-2-hf-4bit-mlx") { prompt in static let phi4bit = ModelConfiguration(id: "mlx-community/phi-2-hf-4bit-mlx") { prompt in
"Instruct: \(prompt). Output: " "Instruct: \(prompt). Output: "
} }
static let gemma2bQuantized = ModelConfiguration( static let gemma2bQuantized = ModelConfiguration(
id: "mlx-community/quantized-gemma-2b-it", id: "mlx-community/quantized-gemma-2b-it",
overrideTokenizer: "PreTrainedTokenizer") { prompt in overrideTokenizer: "PreTrainedTokenizer"
"<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model" ) { prompt in
} "<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
}
private enum BootstrapState { private enum BootstrapState {
case idle case idle
case bootstrapping case bootstrapping
case bootstrapped case bootstrapped
} }
static private var bootstrapState = BootstrapState.idle static private var bootstrapState = BootstrapState.idle
static func bootstrap() { static func bootstrap() {
switch bootstrapState { switch bootstrapState {
case .idle: case .idle:
@@ -99,7 +100,7 @@ extension ModelConfiguration {
case .bootstrapping: case .bootstrapping:
break break
case .bootstrapped: case .bootstrapped:
break break
} }

View File

@@ -42,7 +42,7 @@ struct SyncGenerator: AsyncParsableCommand {
let modelConfiguration = ModelConfiguration.configuration(id: model) let modelConfiguration = ModelConfiguration.configuration(id: model)
let (model, tokenizer) = try await load(configuration: modelConfiguration) let (model, tokenizer) = try await load(configuration: modelConfiguration)
let prompt = modelConfiguration.prepare(prompt: self.prompt) let prompt = modelConfiguration.prepare(prompt: self.prompt)
let promptTokens = tokenizer.encode(text: prompt) let promptTokens = tokenizer.encode(text: prompt)
@@ -57,7 +57,8 @@ struct SyncGenerator: AsyncParsableCommand {
var tokens = [Int]() var tokens = [Int]()
var printed = 0 var printed = 0
for token in TokenIterator(prompt: MLXArray(promptTokens), model: model, temp: temperature) { for token in TokenIterator(prompt: MLXArray(promptTokens), model: model, temp: temperature)
{
if tokens.isEmpty { if tokens.isEmpty {
eval(token) eval(token)
let now = Date.timeIntervalSinceReferenceDate let now = Date.timeIntervalSinceReferenceDate
@@ -130,7 +131,7 @@ struct AsyncGenerator: AsyncParsableCommand {
let modelConfiguration = ModelConfiguration.configuration(id: model) let modelConfiguration = ModelConfiguration.configuration(id: model)
let (model, tokenizer) = try await load(configuration: modelConfiguration) let (model, tokenizer) = try await load(configuration: modelConfiguration)
let prompt = modelConfiguration.prepare(prompt: self.prompt) let prompt = modelConfiguration.prepare(prompt: self.prompt)
let promptTokens = tokenizer.encode(text: prompt) let promptTokens = tokenizer.encode(text: prompt)
@@ -145,7 +146,8 @@ struct AsyncGenerator: AsyncParsableCommand {
var tokens = [Int]() var tokens = [Int]()
var printed = 0 var printed = 0
let (task, channel) = generate(prompt: MLXArray(promptTokens), model: model, temp: temperature) let (task, channel) = generate(
prompt: MLXArray(promptTokens), model: model, temp: temperature)
for await token in channel { for await token in channel {
if tokens.isEmpty { if tokens.isEmpty {