swift-format!
@@ -14,7 +14,8 @@ struct LLMError: Error {
 /// Load and return the model and tokenizer
 public func load(
-    hub: HubApi = HubApi(), configuration: ModelConfiguration, progressHandler: @escaping (Progress) -> Void = { _ in }
+    hub: HubApi = HubApi(), configuration: ModelConfiguration,
+    progressHandler: @escaping (Progress) -> Void = { _ in }
 ) async throws -> (LLMModel, Tokenizer) {
     // note: this doesn't have a way to pass the HubApi
     let tokenizer = try await loadTokenizer(configuration: configuration)
 
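For context, a minimal sketch of a call site for the reformatted signature above. This is hypothetical, not part of the commit: it assumes the caller is in the same module, and `.phi4bit` is one of the configurations defined later in this diff. `hub:` and `progressHandler:` keep their defaults unless supplied.

import Foundation

func loadPhi() async throws {
    // hub: defaults to HubApi(); progressHandler defaults to a no-op closure
    let (model, tokenizer) = try await load(configuration: .phi4bit) { progress in
        print("downloading: \(Int(progress.fractionCompleted * 100))%")
    }
    _ = (model, tokenizer)
}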
@@ -19,7 +19,9 @@ public struct ModelConfiguration {
     /// format
     private let preparePrompt: ((String) -> String)?
 
-    public init(id: String, overrideTokenizer: String? = nil, preparePrompt: ((String) -> String)? = nil) {
+    public init(
+        id: String, overrideTokenizer: String? = nil, preparePrompt: ((String) -> String)? = nil
+    ) {
         self.id = id
         self.overrideTokenizer = overrideTokenizer
         self.preparePrompt = preparePrompt
@@ -56,15 +58,13 @@ extension ModelConfiguration {
 
     static let codeLlama13b4bit = ModelConfiguration(
         id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX",
-        overrideTokenizer: "PreTrainedTokenizer")
-    { prompt in
+        overrideTokenizer: "PreTrainedTokenizer"
+    ) { prompt in
         // given the prompt: func sortArray(_ array: [Int]) -> String { <FILL_ME> }
         // the python code produces this (via its custom tokenizer):
         // <PRE> func sortArray(_ array: [Int]) -> String { <SUF> } <MID>
 
-        "<PRE> " +
-            prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") +
-            " <MID>"
+        "<PRE> " + prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") + " <MID>"
     }
 
     static let phi4bit = ModelConfiguration(id: "mlx-community/phi-2-hf-4bit-mlx") { prompt in
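As a standalone check (not part of the commit), the reformatted one-line preparePrompt body produces exactly the fill-in-the-middle string described in the comments in the hunk above:

import Foundation

let prompt = "func sortArray(_ array: [Int]) -> String { <FILL_ME> }"
let prepared =
    "<PRE> " + prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") + " <MID>"
print(prepared)
// <PRE> func sortArray(_ array: [Int]) -> String { <SUF> } <MID>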
@@ -73,7 +73,8 @@ extension ModelConfiguration {
 
     static let gemma2bQuantized = ModelConfiguration(
         id: "mlx-community/quantized-gemma-2b-it",
-        overrideTokenizer: "PreTrainedTokenizer") { prompt in
+        overrideTokenizer: "PreTrainedTokenizer"
+    ) { prompt in
         "<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
     }
 
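Likewise, a quick sketch (not in the commit) of what the gemma2bQuantized closure returns for a sample prompt:

let prompt = "Why is the sky blue?"
let prepared = "<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
// <start_of_turn>user Why is the sky blue?<end_of_turn><start_of_turn>model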
@@ -57,7 +57,8 @@ struct SyncGenerator: AsyncParsableCommand {
         var tokens = [Int]()
         var printed = 0
 
-        for token in TokenIterator(prompt: MLXArray(promptTokens), model: model, temp: temperature) {
+        for token in TokenIterator(prompt: MLXArray(promptTokens), model: model, temp: temperature)
+        {
             if tokens.isEmpty {
                 eval(token)
                 let now = Date.timeIntervalSinceReferenceDate
@@ -145,7 +146,8 @@ struct AsyncGenerator: AsyncParsableCommand {
         var tokens = [Int]()
         var printed = 0
 
-        let (task, channel) = generate(prompt: MLXArray(promptTokens), model: model, temp: temperature)
+        let (task, channel) = generate(
+            prompt: MLXArray(promptTokens), model: model, temp: temperature)
 
         for await token in channel {
             if tokens.isEmpty {