mlx-swift-examples/Tools/llm-tool/LLMTool.swift
// Copyright © 2024 Apple Inc.

import ArgumentParser
import Foundation
import LLM
import MLX
import MLXRandom

@main
struct LLMTool: AsyncParsableCommand {
    static var configuration = CommandConfiguration(
        abstract: "Command line tool for generating text using Llama models",
        subcommands: [SyncGenerator.self, AsyncGenerator.self],
        defaultSubcommand: SyncGenerator.self)
}
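
// Example invocation of the default (sync) subcommand. The `mlx-run` helper
// script is assumed from the mlx-swift-examples repository; the flag names
// are derived by ArgumentParser from the @Option declarations below:
//
//   ./mlx-run llm-tool sync --model mlx-community/Mistral-7B-v0.1-hf-4bit-mlx \
//       --prompt "compare python and swift" --max-tokens 100 --temperature 0.6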
struct SyncGenerator: AsyncParsableCommand {

    static var configuration = CommandConfiguration(
        commandName: "sync",
        abstract: "Synchronous generator"
    )

    @Option(name: .long, help: "Name of the Hugging Face model")
    var model: String = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"

    @Option(name: .shortAndLong, help: "The message to be processed by the model")
    var prompt = "compare python and swift"

    @Option(name: .shortAndLong, help: "Maximum number of tokens to generate")
    var maxTokens = 100

    @Option(name: .shortAndLong, help: "The sampling temperature")
    var temperature: Float = 0.6

    @Option(name: .long, help: "The PRNG seed")
    var seed: UInt64 = 0
    @MainActor
    func run() async throws {
        MLXRandom.seed(seed)

        let modelConfiguration = ModelConfiguration.configuration(id: model)
        let (model, tokenizer) = try await load(configuration: modelConfiguration)

        print("Model loaded -> \(self.model)")

        let prompt = modelConfiguration.prepare(prompt: self.prompt)
        let promptTokens = tokenizer.encode(text: prompt)

        print("Starting generation ...")
        print(self.prompt, terminator: "")

        var start = Date.timeIntervalSinceReferenceDate
        var promptTime: TimeInterval = 0

        // collect the tokens and keep track of how much of the string
        // we have printed already
        var tokens = [Int]()
        var printed = 0

        for token in TokenIterator(
            prompt: MLXArray(promptTokens), model: model, temp: temperature)
        {
            // the first token marks the end of prompt processing -- force its
            // evaluation and use it to split the timing into prompt and
            // generation phases
            if tokens.isEmpty {
                eval(token)
                let now = Date.timeIntervalSinceReferenceDate
                promptTime = now - start
                start = now
            }

            let t = token.item(Int.self)
            if t == tokenizer.unknownTokenId || t == tokenizer.eosTokenId {
                break
            }
            tokens.append(t)

            // print any new parts of the string
            let fullOutput = tokenizer.decode(tokens: tokens)
            let emitLength = fullOutput.count - printed
            let suffix = fullOutput.suffix(emitLength)
            print(suffix, terminator: "")
            fflush(stdout)

            printed = fullOutput.count

            if tokens.count == maxTokens {
                break
            }
        }

        print()
        print("------")

        let now = Date.timeIntervalSinceReferenceDate
        let generateTime = now - start

        // the first token is attributed to prompt processing, hence count - 1
        print(
            """
            Prompt tokens per second: \((Double(promptTokens.count) / promptTime).formatted())
            Generation tokens per second: \((Double(tokens.count - 1) / generateTime).formatted())
            """)
    }
}
/// Example of an async generator.
///
/// Note that all of the computation is done on another thread and TokenIds
/// (Int32) are sent back rather than MLXArray.
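///
/// A minimal sketch of how such a (task, channel) pair could be built with
/// `AsyncStream` (illustrative only: the real `generate` lives in the LLM
/// package, and the `LLMModel` type and exact signature are assumptions):
///
///     func generate(prompt: MLXArray, model: LLMModel, temp: Float)
///         -> (Task<Void, Never>, AsyncStream<Int>)
///     {
///         let (channel, continuation) = AsyncStream.makeStream(of: Int.self)
///         let task = Task.detached {
///             // step the model, yielding each sampled token id, e.g.
///             // continuation.yield(tokenId)
///             continuation.finish()
///         }
///         return (task, channel)
///     }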
struct AsyncGenerator: AsyncParsableCommand {

    static var configuration = CommandConfiguration(
        commandName: "async",
        abstract: "Asynchronous generator"
    )

    @Option(name: .long, help: "Name of the Hugging Face model")
    var model: String = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"

    @Option(name: .shortAndLong, help: "The message to be processed by the model")
    var prompt = "compare python and swift"

    @Option(name: .shortAndLong, help: "Maximum number of tokens to generate")
    var maxTokens = 100

    @Option(name: .shortAndLong, help: "The sampling temperature")
    var temperature: Float = 0.6

    @Option(name: .long, help: "The PRNG seed")
    var seed: UInt64 = 0
    @MainActor
    func run() async throws {
        MLXRandom.seed(seed)

        let modelConfiguration = ModelConfiguration.configuration(id: model)
        let (model, tokenizer) = try await load(configuration: modelConfiguration)

        print("Model loaded -> \(self.model)")

        let prompt = modelConfiguration.prepare(prompt: self.prompt)
        let promptTokens = tokenizer.encode(text: prompt)

        print("Starting generation ...")
        print(self.prompt, terminator: "")

        var start = Date.timeIntervalSinceReferenceDate
        var promptTime: TimeInterval = 0

        // collect the tokens and keep track of how much of the string
        // we have printed already
        var tokens = [Int]()
        var printed = 0

        let (task, channel) = generate(
            prompt: MLXArray(promptTokens), model: model, temp: temperature)

        for await token in channel {
            // the first token marks the end of prompt processing, so use it
            // to split the timing into prompt and generation phases
            if tokens.isEmpty {
                let now = Date.timeIntervalSinceReferenceDate
                promptTime = now - start
                start = now
            }

            if token == tokenizer.unknownTokenId || token == tokenizer.eosTokenId {
                break
            }
            tokens.append(token)

            // print any new parts of the string
            let fullOutput = tokenizer.decode(tokens: tokens)
            let emitLength = fullOutput.count - printed
            let suffix = fullOutput.suffix(emitLength)
            print(suffix, terminator: "")
            fflush(stdout)

            printed = fullOutput.count

            if tokens.count == maxTokens {
                break
            }
        }

        // tell the task to stop
        task.cancel()

        print()
        print("------")

        let now = Date.timeIntervalSinceReferenceDate
        let generateTime = now - start

        // the first token is attributed to prompt processing, hence count - 1
        print(
            """
            Prompt tokens per second: \((Double(promptTokens.count) / promptTime).formatted())
            Generation tokens per second: \((Double(tokens.count - 1) / generateTime).formatted())
            """)

        // wait for the task to complete -- since it is running async, it might
        // be in the middle of running the model
        try? await Task.sleep(for: .milliseconds(500))
    }
}
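
// Example invocation of the async subcommand (again assuming the `mlx-run`
// helper script from the repository; unspecified flags keep their defaults):
//
//   ./mlx-run llm-tool async --max-tokens 200 --temperature 0.8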