prepare for lora branch (#47)
- remove async llm generation -- this is just doubling our work and does not match the style used in the example applications
- package generation parameters into a struct
- refactor command line arguments into distinct pieces based on their use -- this will be reusable in the lora commands (see the reuse sketch below)
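For context, the diff below splits the old monolithic LLMArguments into ModelArguments, GenerateArguments, and MemoryArguments, which EvaluateCommand composes via @OptionGroup. A minimal sketch of how a future lora command might reuse those same pieces -- the command name LoRATrainCommand and its body are hypothetical placeholders, only the argument structs and the load/start/report pattern come from this diff:

import ArgumentParser

// Sketch only: a hypothetical future command reusing the argument packages
// introduced in this commit (ModelArguments, MemoryArguments, GenerateArguments).
struct LoRATrainCommand: AsyncParsableCommand {

    static var configuration = CommandConfiguration(
        commandName: "lora",
        abstract: "fine-tune a model with LoRA (illustrative placeholder)"
    )

    // Reuse the same argument packages as `eval` via @OptionGroup.
    @OptionGroup var args: ModelArguments
    @OptionGroup var memory: MemoryArguments
    @OptionGroup var generate: GenerateArguments

    @MainActor
    mutating func run() async throws {
        // Same load/start pattern as EvaluateCommand below.
        let (model, tokenizer, modelConfiguration) = try await memory.start(args.load)
        print("Model loaded -> \(modelConfiguration.id)")

        // ... LoRA-specific training would go here ...
        _ = (model, tokenizer)

        memory.reportMemoryStatistics()
    }
}

Because each struct is a ParsableArguments, any number of subcommands can opt in to exactly the option groups they need without duplicating flag definitions.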
@@ -10,16 +10,27 @@ import Tokenizers
@main
struct LLMTool: AsyncParsableCommand {
    static var configuration = CommandConfiguration(
        abstract: "Command line tool for generating text using Llama models",
        subcommands: [SyncGenerator.self, AsyncGenerator.self],
        defaultSubcommand: SyncGenerator.self)
        abstract: "Command line tool for generating text and manipulating LLMs",
        subcommands: [EvaluateCommand.self],
        defaultSubcommand: EvaluateCommand.self)
}

struct LLMArguments: ParsableArguments {
/// Command line arguments for loading a model.
struct ModelArguments: ParsableArguments {

    @Option(name: .long, help: "Name of the huggingface model")
    var model: String = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"

    func load() async throws -> (LLMModel, Tokenizer, ModelConfiguration) {
        let modelConfiguration = ModelConfiguration.configuration(id: model)
        let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration)
        return (model, tokenizer, modelConfiguration)
    }
}

/// Command line arguments for controlling generation of text.
struct GenerateArguments: ParsableArguments {

    @Option(name: .shortAndLong, help: "The message to be processed by the model")
    var prompt = "compare python and swift"
@@ -29,19 +40,67 @@ struct LLMArguments: ParsableArguments {
    @Option(name: .shortAndLong, help: "The sampling temperature")
    var temperature: Float = 0.6

    @Option(name: .shortAndLong, help: "The top p sampling")
    @Option(name: .long, help: "The top p sampling")
    var topP: Float = 0.9

    @Option(name: .shortAndLong, help: "The penalty factor for repeating tokens")
    @Option(name: .long, help: "The penalty factor for repeating tokens")
    var repetitionPenalty: Float = 1.0

    @Option(name: .shortAndLong, help: "The number of tokens to consider for repetition penalty")
    @Option(name: .long, help: "The number of tokens to consider for repetition penalty")
    var repetitionContextSize: Int = 20

    @Option(name: .long, help: "The PRNG seed")
    var seed: UInt64 = 0

    @Flag(help: "Show memory stats")
    var generateParameters: GenerateParameters {
        GenerateParameters(
            temperature: temperature, topP: topP, repetitionPenalty: repetitionPenalty,
            repetitionContextSize: repetitionContextSize)
    }

    func tokenizePrompt(configuration: ModelConfiguration, tokenizer: Tokenizer) -> (String, [Int])
    {
        MLXRandom.seed(seed)

        let prompt = configuration.prepare(prompt: self.prompt)
        let promptTokens = tokenizer.encode(text: prompt)

        return (prompt, promptTokens)
    }

    func generate(promptTokens: [Int], model: LLMModel, tokenizer: Tokenizer) async
        -> GenerateResult
    {
        // track how much we have printed
        var printed = 0

        return await LLM.generate(
            promptTokens: promptTokens, parameters: generateParameters,
            model: model, tokenizer: tokenizer
        ) { tokens in

            // print any new parts of the string
            let fullOutput = tokenizer.decode(tokens: tokens)
            let emitLength = fullOutput.count - printed
            let suffix = fullOutput.suffix(emitLength)
            print(suffix, terminator: "")
            fflush(stdout)

            printed = fullOutput.count

            if tokens.count >= maxTokens {
                return .stop
            } else {
                return .more
            }
        }
    }
}

/// Argument package for adjusting and reporting memory use.
struct MemoryArguments: ParsableArguments {

    @Flag(name: .long, help: "Show memory stats")
    var memoryStats = false

    @Option(name: .long, help: "Maximum cache size in M")
@@ -52,9 +111,7 @@ struct LLMArguments: ParsableArguments {
    var startMemory: GPU.Snapshot?

    mutating func load() async throws -> (LLMModel, Tokenizer, ModelConfiguration) {
        MLXRandom.seed(seed)

    mutating func start<L>(_ load: () async throws -> L) async throws -> L {
        if let cacheSize {
            GPU.set(cacheLimit: cacheSize * 1024 * 1024)
        }
@@ -63,20 +120,29 @@ struct LLMArguments: ParsableArguments {
            GPU.set(memoryLimit: memorySize * 1024 * 1024)
        }

        let modelConfiguration = ModelConfiguration.configuration(id: model)
        let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration)

        let result = try await load()
        startMemory = GPU.snapshot()

        return (model, tokenizer, modelConfiguration)
        return result
    }

    func tokenizePropmpt(configuration: ModelConfiguration, tokenizer: Tokenizer) -> (String, [Int])
    {
        let prompt = configuration.prepare(prompt: self.prompt)
        let promptTokens = tokenizer.encode(text: prompt)
    mutating func start() {
        if let cacheSize {
            GPU.set(cacheLimit: cacheSize * 1024 * 1024)
        }

        return (prompt, promptTokens)
        if let memorySize {
            GPU.set(memoryLimit: memorySize * 1024 * 1024)
        }

        startMemory = GPU.snapshot()
    }

    func reportCurrent() {
        if memoryStats {
            let memory = GPU.snapshot()
            print(memory.description)
        }
    }

    func reportMemoryStatistics() {
@@ -106,164 +172,36 @@ struct LLMArguments: ParsableArguments {
    }
}

struct SyncGenerator: AsyncParsableCommand {
struct EvaluateCommand: AsyncParsableCommand {

    static var configuration = CommandConfiguration(
        commandName: "sync",
        abstract: "Synchronous generator"
        commandName: "eval",
        abstract: "evaluate prompt and generate text"
    )

    @OptionGroup var args: LLMArguments
    @OptionGroup var args: ModelArguments
    @OptionGroup var memory: MemoryArguments
    @OptionGroup var generate: GenerateArguments

    @MainActor
    mutating func run() async throws {
        let (model, tokenizer, modelConfiguration) = try await args.load()
        let (model, tokenizer, modelConfiguration) = try await memory.start(args.load)

        print("Model loaded -> \(modelConfiguration.id)")

        let (prompt, promptTokens) = args.tokenizePropmpt(
        let (prompt, promptTokens) = generate.tokenizePrompt(
            configuration: modelConfiguration, tokenizer: tokenizer)

        print("Starting generation ...")
        print(prompt, terminator: "")

        var start = Date.timeIntervalSinceReferenceDate
        var promptTime: TimeInterval = 0

        // collect the tokens and keep track of how much of the string
        // we have printed already
        var tokens = [Int]()
        var printed = 0

        for token in TokenIterator(
            prompt: MLXArray(promptTokens), model: model, temp: args.temperature, topP: args.topP,
            repetitionPenalty: args.repetitionPenalty,
            repetitionContextSize: args.repetitionContextSize)
        {
            if tokens.isEmpty {
                eval(token)
                let now = Date.timeIntervalSinceReferenceDate
                promptTime = now - start
                start = now
            }

            let t = token.item(Int.self)
            if t == tokenizer.unknownTokenId || t == tokenizer.eosTokenId {
                break
            }
            tokens.append(t)

            // print any new parts of the string
            let fullOutput = tokenizer.decode(tokens: tokens)
            let emitLength = fullOutput.count - printed
            let suffix = fullOutput.suffix(emitLength)
            print(suffix, terminator: "")
            fflush(stdout)

            printed = fullOutput.count

            if tokens.count == args.maxTokens {
                break
            }
        }
        let result = await generate.generate(
            promptTokens: promptTokens, model: model, tokenizer: tokenizer)

        print()
        print("------")
        let now = Date.timeIntervalSinceReferenceDate
        let generateTime = now - start
        print(result.summary())

        print(
            """
            Prompt Tokens per second: \((Double(promptTokens.count) / promptTime).formatted())
            Generation tokens per second: \((Double(tokens.count - 1) / generateTime).formatted())
            """)

        args.reportMemoryStatistics()
    }
}

/// Example of an async generator.
///
/// Note that all of the computation is done on another thread and TokenId (Int32) are sent
/// rather than MLXArray.
struct AsyncGenerator: AsyncParsableCommand {

    static var configuration = CommandConfiguration(
        commandName: "async",
        abstract: "async generator"
    )

    @OptionGroup var args: LLMArguments

    @MainActor
    mutating func run() async throws {
        let (model, tokenizer, modelConfiguration) = try await args.load()

        print("Model loaded -> \(modelConfiguration.id)")

        let (prompt, promptTokens) = args.tokenizePropmpt(
            configuration: modelConfiguration, tokenizer: tokenizer)

        print("Starting generation ...")
        print(prompt, terminator: "")

        var start = Date.timeIntervalSinceReferenceDate
        var promptTime: TimeInterval = 0

        // collect the tokens and keep track of how much of the string
        // we have printed already
        var tokens = [Int]()
        var printed = 0

        let (task, channel) = generate(
            prompt: MLXArray(promptTokens), model: model, temp: args.temperature, topP: args.topP,
            repetitionPenalty: args.repetitionPenalty,
            repetitionContextSize: args.repetitionContextSize)

        for await token in channel {
            if tokens.isEmpty {
                let now = Date.timeIntervalSinceReferenceDate
                promptTime = now - start
                start = now
            }

            if token == tokenizer.unknownTokenId || token == tokenizer.eosTokenId {
                break
            }
            tokens.append(token)

            // print any new parts of the string
            let fullOutput = tokenizer.decode(tokens: tokens)
            let emitLength = fullOutput.count - printed
            let suffix = fullOutput.suffix(emitLength)
            print(suffix, terminator: "")
            fflush(stdout)

            printed = fullOutput.count

            if tokens.count == args.maxTokens {
                break
            }
        }

        // tell the task to stop
        task.cancel()

        print()
        print("------")
        let now = Date.timeIntervalSinceReferenceDate
        let generateTime = now - start

        print(
            """
            Prompt Tokens per second: \((Double(promptTokens.count) / promptTime).formatted())
            Generation tokens per second: \((Double(tokens.count - 1) / generateTime).formatted())
            """)

        args.reportMemoryStatistics()

        // wait for the task to complete -- since it is running async, it might
        // be in the middle of running the model
        try? await Task.sleep(for: .milliseconds(500))
        memory.reportMemoryStatistics()
    }
}