swift-format!

David Koski
2024-03-01 14:47:43 -08:00
parent 82f6a969d4
commit 2157333905
3 changed files with 34 additions and 30 deletions

View File

@@ -14,7 +14,8 @@ struct LLMError: Error {
 /// Load and return the model and tokenizer
 public func load(
-    hub: HubApi = HubApi(), configuration: ModelConfiguration, progressHandler: @escaping (Progress) -> Void = { _ in }
+    hub: HubApi = HubApi(), configuration: ModelConfiguration,
+    progressHandler: @escaping (Progress) -> Void = { _ in }
 ) async throws -> (LLMModel, Tokenizer) {
     // note: this doesn't have a way to pass the HubApi
     let tokenizer = try await loadTokenizer(configuration: configuration)
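As a quick check of the reformatted signature, a call site still reads naturally (a sketch, assuming this repo's module context; the preset choice and the progress handling are illustrative, not part of this commit):

    // Load a model and tokenizer, reporting download progress.
    // `.phi4bit` is one of the presets defined in ModelConfiguration below.
    let (model, tokenizer) = try await load(configuration: .phi4bit) { progress in
        print("download: \(Int(progress.fractionCompleted * 100))%")
    }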

View File

@@ -19,7 +19,9 @@ public struct ModelConfiguration {
     /// format
     private let preparePrompt: ((String) -> String)?

-    public init(id: String, overrideTokenizer: String? = nil, preparePrompt: ((String) -> String)? = nil) {
+    public init(
+        id: String, overrideTokenizer: String? = nil, preparePrompt: ((String) -> String)? = nil
+    ) {
         self.id = id
         self.overrideTokenizer = overrideTokenizer
         self.preparePrompt = preparePrompt
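The multi-line initializer does not change call sites; a custom configuration with a prompt transform would look roughly like this (a sketch; the model id and template are invented for illustration):

    let custom = ModelConfiguration(
        id: "mlx-community/example-model-4bit",
        overrideTokenizer: "PreTrainedTokenizer"
    ) { prompt in
        // Wrap the raw prompt in an instruction-style template (illustrative).
        "### Instruction:\n\(prompt)\n### Response:"
    }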
@@ -56,15 +58,13 @@ extension ModelConfiguration {
     static let codeLlama13b4bit = ModelConfiguration(
         id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX",
-        overrideTokenizer: "PreTrainedTokenizer")
-    { prompt in
+        overrideTokenizer: "PreTrainedTokenizer"
+    ) { prompt in
         // given the prompt: func sortArray(_ array: [Int]) -> String { <FILL_ME> }
         // the python code produces this (via its custom tokenizer):
         // <PRE> func sortArray(_ array: [Int]) -> String { <SUF> } <MID>
-        "<PRE> " +
-        prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") +
-        " <MID>"
+        "<PRE> " + prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") + " <MID>"
     }

     static let phi4bit = ModelConfiguration(id: "mlx-community/phi-2-hf-4bit-mlx") { prompt in
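Joining the concatenation onto one line is behavior-preserving; worked through for the prompt in the comments (a sketch using only Foundation's replacingOccurrences):

    let prompt = "func sortArray(_ array: [Int]) -> String { <FILL_ME> }"
    let prepared =
        "<PRE> " + prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") + " <MID>"
    // prepared == "<PRE> func sortArray(_ array: [Int]) -> String { <SUF> } <MID>"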
@@ -73,7 +73,8 @@ extension ModelConfiguration {
     static let gemma2bQuantized = ModelConfiguration(
         id: "mlx-community/quantized-gemma-2b-it",
-        overrideTokenizer: "PreTrainedTokenizer") { prompt in
+        overrideTokenizer: "PreTrainedTokenizer"
+    ) { prompt in
         "<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
     }
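The Gemma closure itself is untouched; applied to a prompt, it produces the turn-delimited string (a sketch of the template in isolation):

    let template: (String) -> String = { prompt in
        "<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
    }
    // template("hello") == "<start_of_turn>user hello<end_of_turn><start_of_turn>model"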

View File

@@ -57,7 +57,8 @@ struct SyncGenerator: AsyncParsableCommand {
         var tokens = [Int]()
         var printed = 0

-        for token in TokenIterator(prompt: MLXArray(promptTokens), model: model, temp: temperature) {
+        for token in TokenIterator(prompt: MLXArray(promptTokens), model: model, temp: temperature)
+        {
             if tokens.isEmpty {
                 eval(token)
                 let now = Date.timeIntervalSinceReferenceDate
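The brace now sits on its own line, but the iteration is unchanged; a minimal generation loop over TokenIterator looks like this (a sketch: promptTokens, model, and temperature as set up in this command; each token is assumed to be a single-element MLXArray, and maxTokens is an invented stopping bound):

    var tokens = [Int]()
    for token in TokenIterator(prompt: MLXArray(promptTokens), model: model, temp: temperature)
    {
        eval(token)  // force evaluation before reading the scalar
        tokens.append(token.item(Int.self))  // assumes MLXArray.item(_:) extraction
        if tokens.count >= maxTokens { break }  // maxTokens: assumed bound
    }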
@@ -145,7 +146,8 @@ struct AsyncGenerator: AsyncParsableCommand {
         var tokens = [Int]()
         var printed = 0

-        let (task, channel) = generate(prompt: MLXArray(promptTokens), model: model, temp: temperature)
+        let (task, channel) = generate(
+            prompt: MLXArray(promptTokens), model: model, temp: temperature)
         for await token in channel {
             if tokens.isEmpty {
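Same idea on the async path; consuming the channel returned by the wrapped generate call (a sketch: the channel is assumed to yield Int token ids, and maxTokens and the cancel call are illustrative cleanup, not code from this commit):

    let (task, channel) = generate(
        prompt: MLXArray(promptTokens), model: model, temp: temperature)
    var tokens = [Int]()
    for await token in channel {
        tokens.append(token)  // assuming Int token ids
        if tokens.count >= maxTokens { break }
    }
    task.cancel()  // assumed cleanup once we stop consuming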