// Copyright © 2024 Apple Inc.

import ArgumentParser
import Foundation
import LLM
import MLX
import MLXRandom
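
// Example invocations (assuming the executable is named `llm-tool`; adjust
// to match the actual build product):
//
//   llm-tool sync --prompt "compare python and swift" --max-tokens 200
//   llm-tool async --temperature 0.8 --seed 42
//   llm-tool --prompt "hello"      // `sync` is the default subcommand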
@main
struct LLMTool: AsyncParsableCommand {
    static var configuration = CommandConfiguration(
        abstract: "Command line tool for generating text using Llama models",
        subcommands: [SyncGenerator.self, AsyncGenerator.self],
        defaultSubcommand: SyncGenerator.self)
}
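
/// Example of a synchronous generator: tokens are produced by iterating a
/// `TokenIterator` directly on the calling task and are decoded and printed
/// as they arrive.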
struct SyncGenerator: AsyncParsableCommand {

    static var configuration = CommandConfiguration(
        commandName: "sync",
        abstract: "Synchronous generator"
    )
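
    // ArgumentParser derives flag names from the property names below, e.g.
    // `maxTokens` becomes `--max-tokens` (plus `-m` via `.shortAndLong`).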

    @Option(name: .long, help: "Name of the Hugging Face model")
    var model: String = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"

    @Option(name: .shortAndLong, help: "The message to be processed by the model")
    var prompt = "compare python and swift"

    @Option(name: .shortAndLong, help: "Maximum number of tokens to generate")
    var maxTokens = 100

    @Option(name: .shortAndLong, help: "The sampling temperature")
    var temperature: Float = 0.6

    @Option(name: .long, help: "The PRNG seed")
    var seed: UInt64 = 0

    @MainActor
    func run() async throws {
        MLXRandom.seed(seed)

        let modelConfiguration = ModelConfiguration.configuration(id: model)
        let (model, tokenizer) = try await load(configuration: modelConfiguration)

        // apply any model-specific prompt formatting, then tokenize
        let prompt = modelConfiguration.prepare(prompt: self.prompt)
        let promptTokens = tokenizer.encode(text: prompt)

        print("Starting generation ...")
        print(self.prompt, terminator: "")

        var start = Date.timeIntervalSinceReferenceDate
        var promptTime: TimeInterval = 0

        // collect the tokens and keep track of how much of the string
        // we have printed already
        var tokens = [Int]()
        var printed = 0
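
        // stream tokens one at a time; producing the first token also
        // requires processing the entire prompt, so it is timed separately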
        for token in TokenIterator(prompt: MLXArray(promptTokens), model: model, temp: temperature)
        {
            if tokens.isEmpty {
                // eval() forces the lazy evaluation of the first token,
                // which includes processing the entire prompt
                eval(token)
                let now = Date.timeIntervalSinceReferenceDate
                promptTime = now - start
                start = now
            }

            let t = token.item(Int.self)
            if t == tokenizer.unknownTokenId || t == tokenizer.eosTokenId {
                break
            }
            tokens.append(t)

            // print any new parts of the string
            let fullOutput = tokenizer.decode(tokens: tokens)
            let emitLength = fullOutput.count - printed
            let suffix = fullOutput.suffix(emitLength)
            print(suffix, terminator: "")
            fflush(stdout)

            printed = fullOutput.count

            if tokens.count == maxTokens {
                break
            }
        }

        print()
        print("------")
        let now = Date.timeIntervalSinceReferenceDate
        let generateTime = now - start
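
        // tokens.count - 1: the first token's latency is attributed to
        // prompt processing, so only the rest count as generation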
        print(
            """
            Prompt tokens per second: \((Double(promptTokens.count) / promptTime).formatted())
            Generation tokens per second: \((Double(tokens.count - 1) / generateTime).formatted())
            """)
    }
}

/// Example of an async generator.
///
/// Note that all of the computation is done on another thread and TokenIds (Int32) are sent
/// rather than MLXArray.
struct AsyncGenerator: AsyncParsableCommand {

    static var configuration = CommandConfiguration(
        commandName: "async",
        abstract: "Asynchronous generator"
    )

    @Option(name: .long, help: "Name of the Hugging Face model")
    var model: String = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"

    @Option(name: .shortAndLong, help: "The message to be processed by the model")
    var prompt = "compare python and swift"

    @Option(name: .shortAndLong, help: "Maximum number of tokens to generate")
    var maxTokens = 100

    @Option(name: .shortAndLong, help: "The sampling temperature")
    var temperature: Float = 0.6

    @Option(name: .long, help: "The PRNG seed")
    var seed: UInt64 = 0

    @MainActor
    func run() async throws {
        MLXRandom.seed(seed)

        let modelConfiguration = ModelConfiguration.configuration(id: model)
        let (model, tokenizer) = try await load(configuration: modelConfiguration)

        // apply any model-specific prompt formatting, then tokenize
        let prompt = modelConfiguration.prepare(prompt: self.prompt)
        let promptTokens = tokenizer.encode(text: prompt)

        print("Starting generation ...")
        print(self.prompt, terminator: "")

        var start = Date.timeIntervalSinceReferenceDate
        var promptTime: TimeInterval = 0

        // collect the tokens and keep track of how much of the string
        // we have printed already
        var tokens = [Int]()
        var printed = 0
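
        // generate() runs the model on a separate task and streams token ids
        // (rather than MLXArrays) back through the channel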
        let (task, channel) = generate(
            prompt: MLXArray(promptTokens), model: model, temp: temperature)

        for await token in channel {
            if tokens.isEmpty {
                // the first token arrives once the full prompt has been processed
                let now = Date.timeIntervalSinceReferenceDate
                promptTime = now - start
                start = now
            }

            if token == tokenizer.unknownTokenId || token == tokenizer.eosTokenId {
                break
            }
            tokens.append(token)

            // print any new parts of the string
            let fullOutput = tokenizer.decode(tokens: tokens)
            let emitLength = fullOutput.count - printed
            let suffix = fullOutput.suffix(emitLength)
            print(suffix, terminator: "")
            fflush(stdout)

            printed = fullOutput.count

            if tokens.count == maxTokens {
                break
            }
        }

        // tell the task to stop
        task.cancel()

        print()
        print("------")
        let now = Date.timeIntervalSinceReferenceDate
        let generateTime = now - start

        print(
            """
            Prompt tokens per second: \((Double(promptTokens.count) / promptTime).formatted())
            Generation tokens per second: \((Double(tokens.count - 1) / generateTime).formatted())
            """)

        // wait for the task to complete -- since it is running async, it might
        // be in the middle of running the model
        try? await Task.sleep(for: .milliseconds(500))
    }
}