// Copyright © 2024 Apple Inc.

import ArgumentParser
import Foundation
import LLM
import MLX
import MLXRandom

/// Root command for the tool.
///
/// Dispatches to the `sync` or `async` generator subcommand; `sync` is the
/// default when no subcommand is named.
///
/// NOTE(review): `@main` belongs here, on the root command — in the original
/// it was attached to `SyncGenerator`, which made this configuration (and the
/// `async` subcommand) unreachable.
@main
struct LLMTool: AsyncParsableCommand {
    static var configuration = CommandConfiguration(
        abstract: "Command line tool for generating text using Llama models",
        subcommands: [SyncGenerator.self, AsyncGenerator.self],
        defaultSubcommand: SyncGenerator.self)
}

/// Synchronous generator: tokens are produced on the calling task via
/// `TokenIterator` and printed incrementally as they are decoded.
struct SyncGenerator: AsyncParsableCommand {

    static var configuration = CommandConfiguration(
        commandName: "sync",
        abstract: "Synchronous generator"
    )

    @Option(name: .long, help: "Name of the huggingface model")
    var model: String = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"

    @Option(name: .shortAndLong, help: "The message to be processed by the model")
    var prompt = "compare swift and python"

    @Option(name: .shortAndLong, help: "Maximum number of tokens to generate")
    var maxTokens = 100

    @Option(name: .shortAndLong, help: "The sampling temperature")
    var temperature: Float = 0.0

    @Option(name: .long, help: "The PRNG seed")
    var seed: UInt64 = 0

    /// Loads the model, generates up to `maxTokens` tokens from `prompt`,
    /// streams the decoded text to stdout, then prints timing statistics.
    @MainActor
    func run() async throws {
        MLXRandom.seed(seed)

        let (model, tokenizer) = try await load(name: model)

        print("Starting generation ...")
        print(prompt, terminator: "")

        var start = Date.timeIntervalSinceReferenceDate
        var promptTime: TimeInterval = 0

        let prompt = MLXArray(tokenizer.encode(text: prompt))

        // collect the tokens and keep track of how much of the string
        // we have printed already
        var tokens = [Int]()
        var printed = 0

        for token in TokenIterator(prompt: prompt, model: model, temp: temperature) {
            if tokens.isEmpty {
                // the first token's computation includes the prompt evaluation;
                // force it with eval() so promptTime is measured accurately
                eval(token)
                let now = Date.timeIntervalSinceReferenceDate
                promptTime = now - start
                start = now
            }

            let t = token.item(Int.self)
            if t == tokenizer.unknownTokenId {
                break
            }
            tokens.append(t)

            // print any new parts of the string
            let fullOutput = tokenizer.decode(tokens: tokens)
            let emitLength = fullOutput.count - printed
            let suffix = fullOutput.suffix(emitLength)
            print(suffix, terminator: "")
            fflush(stdout)

            printed = fullOutput.count

            if tokens.count == maxTokens {
                break
            }
        }

        print()
        print("------")
        let now = Date.timeIntervalSinceReferenceDate
        let generateTime = now - start

        // tokens.count - 1: presumably the first generated token is charged to
        // promptTime (the clock restarts after it) — TODO confirm
        print(
            """
            Prompt Tokens per second: \((Double(prompt.size) / promptTime).formatted())
            Generation tokens per second: \((Double(tokens.count - 1) / generateTime).formatted())
            """)
    }
}

/// Example of an async generator.
///
/// Note that all of the computation is done on another thread and TokenId (Int32) are sent
/// rather than MLXArray.
struct AsyncGenerator: AsyncParsableCommand {

    static var configuration = CommandConfiguration(
        commandName: "async",
        abstract: "async generator"
    )

    @Option(name: .long, help: "Name of the huggingface model")
    var model: String = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"

    @Option(name: .shortAndLong, help: "The message to be processed by the model")
    var prompt = "compare swift and python"

    @Option(name: .shortAndLong, help: "Maximum number of tokens to generate")
    var maxTokens = 100

    @Option(name: .shortAndLong, help: "The sampling temperature")
    var temperature: Float = 0.0

    @Option(name: .long, help: "The PRNG seed")
    var seed: UInt64 = 0

    /// Loads the model, consumes token ids from the channel produced by
    /// `generate(prompt:model:temp:)`, streams decoded text to stdout, then
    /// prints timing statistics and cancels the producer task.
    @MainActor
    func run() async throws {
        MLXRandom.seed(seed)

        let (model, tokenizer) = try await load(name: model)

        print("Starting generation ...")
        print(prompt, terminator: "")

        var start = Date.timeIntervalSinceReferenceDate
        var promptTime: TimeInterval = 0

        let prompt = MLXArray(tokenizer.encode(text: prompt))

        // collect the tokens and keep track of how much of the string
        // we have printed already
        var tokens = [Int]()
        var printed = 0

        let (task, channel) = generate(prompt: prompt, model: model, temp: temperature)

        for await token in channel {
            if tokens.isEmpty {
                // arrival of the first token marks the end of prompt processing
                let now = Date.timeIntervalSinceReferenceDate
                promptTime = now - start
                start = now
            }

            if token == tokenizer.unknownTokenId {
                break
            }
            tokens.append(token)

            // print any new parts of the string
            let fullOutput = tokenizer.decode(tokens: tokens)
            let emitLength = fullOutput.count - printed
            let suffix = fullOutput.suffix(emitLength)
            print(suffix, terminator: "")
            fflush(stdout)

            printed = fullOutput.count

            if tokens.count == maxTokens {
                break
            }
        }

        // tell the task to stop
        task.cancel()

        print()
        print("------")
        let now = Date.timeIntervalSinceReferenceDate
        let generateTime = now - start

        // tokens.count - 1: presumably the first generated token is charged to
        // promptTime (the clock restarts after it) — TODO confirm
        print(
            """
            Prompt Tokens per second: \((Double(prompt.size) / promptTime).formatted())
            Generation tokens per second: \((Double(tokens.count - 1) / generateTime).formatted())
            """)

        // wait for the task to complete -- since it is running async, it might
        // be in the middle of running the model
        try? await Task.sleep(for: .milliseconds(500))
    }
}