swift-format!

David Koski
2024-03-01 14:47:43 -08:00
parent 82f6a969d4
commit 2157333905
3 changed files with 34 additions and 30 deletions

View File

@@ -14,7 +14,8 @@ struct LLMError: Error {
 
 /// Load and return the model and tokenizer
 public func load(
-    hub: HubApi = HubApi(), configuration: ModelConfiguration, progressHandler: @escaping (Progress) -> Void = { _ in }
+    hub: HubApi = HubApi(), configuration: ModelConfiguration,
+    progressHandler: @escaping (Progress) -> Void = { _ in }
 ) async throws -> (LLMModel, Tokenizer) {
     // note: this doesn't have a way to pass the HubApi
     let tokenizer = try await loadTokenizer(configuration: configuration)

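For context, a call site for this API could look like the following sketch (hypothetical usage, not part of this commit; the model id is taken from the registry entries further down, and the handler receives Foundation's Progress):

    // Hypothetical call site for load(hub:configuration:progressHandler:),
    // relying on the default HubApi argument shown in the diff above.
    let configuration = ModelConfiguration(id: "mlx-community/phi-2-hf-4bit-mlx")
    let (model, tokenizer) = try await load(configuration: configuration) { progress in
        // report download progress of the model weights
        print("downloading: \(Int(progress.fractionCompleted * 100))%")
    }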
View File

@@ -19,7 +19,9 @@ public struct ModelConfiguration {
     /// format
     private let preparePrompt: ((String) -> String)?
 
-    public init(id: String, overrideTokenizer: String? = nil, preparePrompt: ((String) -> String)? = nil) {
+    public init(
+        id: String, overrideTokenizer: String? = nil, preparePrompt: ((String) -> String)? = nil
+    ) {
         self.id = id
         self.overrideTokenizer = overrideTokenizer
         self.preparePrompt = preparePrompt
@@ -29,7 +31,7 @@ public struct ModelConfiguration {
         preparePrompt?(prompt) ?? prompt
     }
 
-    public static var registry = [String:ModelConfiguration]()
+    public static var registry = [String: ModelConfiguration]()
 
     public static func register(configurations: [ModelConfiguration]) {
         bootstrap()
@@ -56,15 +58,13 @@ extension ModelConfiguration {
     static let codeLlama13b4bit = ModelConfiguration(
         id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX",
-        overrideTokenizer: "PreTrainedTokenizer")
-    { prompt in
+        overrideTokenizer: "PreTrainedTokenizer"
+    ) { prompt in
         // given the prompt: func sortArray(_ array: [Int]) -> String { <FILL_ME> }
         // the python code produces this (via its custom tokenizer):
         // <PRE> func sortArray(_ array: [Int]) -> String { <SUF> } <MID>
-        "<PRE> " +
-            prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") +
-            " <MID>"
+        "<PRE> " + prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") + " <MID>"
     }
 
     static let phi4bit = ModelConfiguration(id: "mlx-community/phi-2-hf-4bit-mlx") { prompt in
@@ -73,9 +73,10 @@ extension ModelConfiguration {
 
     static let gemma2bQuantized = ModelConfiguration(
         id: "mlx-community/quantized-gemma-2b-it",
-        overrideTokenizer: "PreTrainedTokenizer") { prompt in
-        "<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
-    }
+        overrideTokenizer: "PreTrainedTokenizer"
+    ) { prompt in
+        "<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
+    }
 
     private enum BootstrapState {
         case idle

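The reformatted codeLlama13b4bit closure above performs the infill rewrite described in its own comments; a worked example of that string transform (input and output copied from those comments, the surrounding setup is hypothetical):

    // The <FILL_ME> marker in the prompt becomes <SUF>, and the whole
    // prompt is wrapped in <PRE> ... <MID> for the fill-in-the-middle task.
    let prompt = "func sortArray(_ array: [Int]) -> String { <FILL_ME> }"
    let prepared = "<PRE> " + prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") + " <MID>"
    // prepared == "<PRE> func sortArray(_ array: [Int]) -> String { <SUF> } <MID>"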
View File

@@ -57,7 +57,8 @@ struct SyncGenerator: AsyncParsableCommand {
         var tokens = [Int]()
         var printed = 0
 
-        for token in TokenIterator(prompt: MLXArray(promptTokens), model: model, temp: temperature) {
+        for token in TokenIterator(prompt: MLXArray(promptTokens), model: model, temp: temperature)
+        {
             if tokens.isEmpty {
                 eval(token)
                 let now = Date.timeIntervalSinceReferenceDate
@@ -145,7 +146,8 @@ struct AsyncGenerator: AsyncParsableCommand {
         var tokens = [Int]()
         var printed = 0
 
-        let (task, channel) = generate(prompt: MLXArray(promptTokens), model: model, temp: temperature)
+        let (task, channel) = generate(
+            prompt: MLXArray(promptTokens), model: model, temp: temperature)
 
         for await token in channel {
             if tokens.isEmpty {