swift-format!

David Koski
2024-03-01 14:47:43 -08:00
parent 82f6a969d4
commit 2157333905
3 changed files with 34 additions and 30 deletions

@@ -14,7 +14,8 @@ struct LLMError: Error {

 /// Load and return the model and tokenizer
 public func load(
-    hub: HubApi = HubApi(), configuration: ModelConfiguration, progressHandler: @escaping (Progress) -> Void = { _ in }
+    hub: HubApi = HubApi(), configuration: ModelConfiguration,
+    progressHandler: @escaping (Progress) -> Void = { _ in }
 ) async throws -> (LLMModel, Tokenizer) {
     // note: this doesn't have a way to pass the HubApi
     let tokenizer = try await loadTokenizer(configuration: configuration)

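As a reading aid (not part of this commit), here is a minimal sketch of how the reformatted load function might be called from inside the library module. Only the load(configuration:progressHandler:) signature above and the ModelConfiguration.configuration(id:) API and phi-2 id from the next file are taken from the diff; the wrapper function and the progress print are illustrative assumptions.

// Hedged sketch: load a registered model and tokenizer from an async context.
func loadPhi() async throws -> (LLMModel, Tokenizer) {
    // resolve a configuration registered later in this same commit
    let configuration = ModelConfiguration.configuration(id: "mlx-community/phi-2-hf-4bit-mlx")
    return try await load(configuration: configuration) { progress in
        // illustrative progress reporting only
        print("download: \(Int(progress.fractionCompleted * 100))%")
    }
}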
@@ -10,35 +10,37 @@ import Foundation

 /// implementation, if needed.
 public struct ModelConfiguration {
     public let id: String

     /// overrides for TokenizerModel/knownTokenizers -- useful before swift-transformers is updated
     public let overrideTokenizer: String?

     /// custom preparation logic for the prompt. custom tokenizers provide more capability, but this
     /// allows some minor formatting changes, e.g. wrapping the user input in the expected prompt
     /// format
     private let preparePrompt: ((String) -> String)?

-    public init(id: String, overrideTokenizer: String? = nil, preparePrompt: ((String) -> String)? = nil) {
+    public init(
+        id: String, overrideTokenizer: String? = nil, preparePrompt: ((String) -> String)? = nil
+    ) {
         self.id = id
         self.overrideTokenizer = overrideTokenizer
         self.preparePrompt = preparePrompt
     }

     public func prepare(prompt: String) -> String {
         preparePrompt?(prompt) ?? prompt
     }

-    public static var registry = [String:ModelConfiguration]()
+    public static var registry = [String: ModelConfiguration]()

     public static func register(configurations: [ModelConfiguration]) {
         bootstrap()

         for c in configurations {
             registry[c.id] = c
         }
     }

     public static func configuration(id: String) -> ModelConfiguration {
         bootstrap()
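To illustrate the registry API in the hunk above, a small hedged sketch of registering and resolving a configuration. The model id and prompt template here are hypothetical; only register(configurations:), configuration(id:), and prepare(prompt:) come from the code shown.

// Hedged sketch with a hypothetical model id and prompt format.
let example = ModelConfiguration(
    id: "example-org/example-model-4bit",  // hypothetical id, not a real repo
    overrideTokenizer: "PreTrainedTokenizer"
) { prompt in
    // wrap the user input in an assumed instruction format
    "### Instruction: \(prompt)\n### Response: "
}

ModelConfiguration.register(configurations: [example])

let resolved = ModelConfiguration.configuration(id: "example-org/example-model-4bit")
let prepared = resolved.prepare(prompt: "Write a haiku about Swift.")
// prepared == "### Instruction: Write a haiku about Swift.\n### Response: "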
@@ -51,40 +53,39 @@ public struct ModelConfiguration {
 }

 extension ModelConfiguration {
     static let mistral7B4bit = ModelConfiguration(id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx")

     static let codeLlama13b4bit = ModelConfiguration(
         id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX",
-        overrideTokenizer: "PreTrainedTokenizer")
-    { prompt in
+        overrideTokenizer: "PreTrainedTokenizer"
+    ) { prompt in
         // given the prompt: func sortArray(_ array: [Int]) -> String { <FILL_ME> }
         // the python code produces this (via its custom tokenizer):
         // <PRE> func sortArray(_ array: [Int]) -> String { <SUF> } <MID>
-        "<PRE> " +
-            prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") +
-            " <MID>"
+        "<PRE> " + prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") + " <MID>"
     }

     static let phi4bit = ModelConfiguration(id: "mlx-community/phi-2-hf-4bit-mlx") { prompt in
         "Instruct: \(prompt). Output: "
     }

     static let gemma2bQuantized = ModelConfiguration(
         id: "mlx-community/quantized-gemma-2b-it",
-        overrideTokenizer: "PreTrainedTokenizer") { prompt in
-        "<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
-    }
+        overrideTokenizer: "PreTrainedTokenizer"
+    ) { prompt in
+        "<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
+    }

     private enum BootstrapState {
         case idle
         case bootstrapping
         case bootstrapped
     }

     static private var bootstrapState = BootstrapState.idle

     static func bootstrap() {
         switch bootstrapState {
         case .idle:
@@ -99,7 +100,7 @@ extension ModelConfiguration {
         case .bootstrapping:
             break

         case .bootstrapped:
             break
         }
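Finally, to make the CodeLlama prompt rewrite concrete, a hedged sketch of what prepare(prompt:) yields for the fill-in-the-middle example quoted in the comments above; the expected output string is taken directly from those comments, and the snippet assumes it runs inside the module where codeLlama13b4bit is visible.

// Hedged sketch: the <FILL_ME> -> <PRE>/<SUF>/<MID> rewrite shown above.
let fillIn = "func sortArray(_ array: [Int]) -> String { <FILL_ME> }"
let rewritten = ModelConfiguration.codeLlama13b4bit.prepare(prompt: fillIn)
// rewritten == "<PRE> func sortArray(_ array: [Int]) -> String { <SUF> } <MID>"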