initial commit
190
Tools/llm-tool/LLMTool.swift
Normal file
@@ -0,0 +1,190 @@
// Copyright © 2024 Apple Inc.

import ArgumentParser
import Foundation
import LLM
import MLX
import MLXRandom

@main
struct LLMTool: AsyncParsableCommand {
    static var configuration = CommandConfiguration(
        abstract: "Command line tool for generating text using Llama models",
        subcommands: [SyncGenerator.self, AsyncGenerator.self],
        defaultSubcommand: SyncGenerator.self)
}

struct SyncGenerator: AsyncParsableCommand {

    static var configuration = CommandConfiguration(
        commandName: "sync",
        abstract: "Synchronous generator"
    )

    @Option(name: .long, help: "Name of the huggingface model")
    var model: String = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"

    @Option(name: .shortAndLong, help: "The message to be processed by the model")
    var prompt = "compare swift and python"

    @Option(name: .shortAndLong, help: "Maximum number of tokens to generate")
    var maxTokens = 100

    @Option(name: .shortAndLong, help: "The sampling temperature")
    var temperature: Float = 0.0

    @Option(name: .long, help: "The PRNG seed")
    var seed: UInt64 = 0

    @MainActor
    func run() async throws {
        MLXRandom.seed(seed)

        let (model, tokenizer) = try await load(name: model)

        print("Starting generation ...")
        print(prompt, terminator: "")

        var start = Date.timeIntervalSinceReferenceDate
        var promptTime: TimeInterval = 0
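
        // tokenize the prompt -- this shadows the String option with its token ids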
        let prompt = MLXArray(tokenizer.encode(text: prompt))

        // collect the tokens and keep track of how much of the string
        // we have printed already
        var tokens = [Int]()
        var printed = 0

        for token in TokenIterator(prompt: prompt, model: model, temp: temperature) {
            if tokens.isEmpty {
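                // MLX is lazy -- forcing evaluation of the first token marks the
                // end of prompt processing, so the timing below is accurate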
                eval(token)
                let now = Date.timeIntervalSinceReferenceDate
                promptTime = now - start
                start = now
            }

            let t = token.item(Int.self)
            if t == tokenizer.unknownTokenId {
                break
            }
            tokens.append(t)

            // print any new parts of the string
            let fullOutput = tokenizer.decode(tokens: tokens)
            let emitLength = fullOutput.count - printed
            let suffix = fullOutput.suffix(emitLength)
            print(suffix, terminator: "")
            fflush(stdout)

            printed = fullOutput.count

            if tokens.count == maxTokens {
                break
            }
        }

        print()
        print("------")
        let now = Date.timeIntervalSinceReferenceDate
        let generateTime = now - start
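
        // the first token is attributed to prompt processing time,
        // hence tokens.count - 1 for the generation rate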
        print(
            """
            Prompt tokens per second: \((Double(prompt.size) / promptTime).formatted())
            Generation tokens per second: \((Double(tokens.count - 1) / generateTime).formatted())
            """)
    }
}

/// Example of an async generator.
///
/// Note that all of the computation is done on another thread and TokenIds (Int32)
/// are sent rather than MLXArrays.
struct AsyncGenerator: AsyncParsableCommand {

    static var configuration = CommandConfiguration(
        commandName: "async",
        abstract: "Asynchronous generator"
    )

    @Option(name: .long, help: "Name of the huggingface model")
    var model: String = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"

    @Option(name: .shortAndLong, help: "The message to be processed by the model")
    var prompt = "compare swift and python"

    @Option(name: .shortAndLong, help: "Maximum number of tokens to generate")
    var maxTokens = 100

    @Option(name: .shortAndLong, help: "The sampling temperature")
    var temperature: Float = 0.0

    @Option(name: .long, help: "The PRNG seed")
    var seed: UInt64 = 0

    @MainActor
    func run() async throws {
        MLXRandom.seed(seed)

        let (model, tokenizer) = try await load(name: model)

        print("Starting generation ...")
        print(prompt, terminator: "")

        var start = Date.timeIntervalSinceReferenceDate
        var promptTime: TimeInterval = 0

        let prompt = MLXArray(tokenizer.encode(text: prompt))

        // collect the tokens and keep track of how much of the string
        // we have printed already
        var tokens = [Int]()
        var printed = 0
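
        // the model runs on a background task; token ids (not MLXArrays)
        // are streamed back over the channel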
        let (task, channel) = generate(prompt: prompt, model: model, temp: temperature)

        for await token in channel {
            if tokens.isEmpty {
                let now = Date.timeIntervalSinceReferenceDate
                promptTime = now - start
                start = now
            }

            if token == tokenizer.unknownTokenId {
                break
            }
            tokens.append(token)

            // print any new parts of the string
            let fullOutput = tokenizer.decode(tokens: tokens)
            let emitLength = fullOutput.count - printed
            let suffix = fullOutput.suffix(emitLength)
            print(suffix, terminator: "")
            fflush(stdout)

            printed = fullOutput.count

            if tokens.count == maxTokens {
                break
            }
        }

        // tell the task to stop
        task.cancel()

        print()
        print("------")
        let now = Date.timeIntervalSinceReferenceDate
        let generateTime = now - start

        print(
            """
            Prompt tokens per second: \((Double(prompt.size) / promptTime).formatted())
            Generation tokens per second: \((Double(tokens.count - 1) / generateTime).formatted())
            """)

        // wait for the task to complete -- since it is running async, it might
        // be in the middle of running the model
        try? await Task.sleep(for: .milliseconds(500))
    }
}
38
Tools/llm-tool/README.md
Normal file
@@ -0,0 +1,38 @@

# llm-tool

See various READMEs:

- [Llama](../../Libraries/Llama/README.md)

### Building

Build the `llm-tool` scheme in Xcode.
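
The scheme can also be built from Terminal with `xcodebuild` -- a sketch, assuming
it is run from the project directory (the exact flags depend on your setup):

```
xcodebuild -scheme llm-tool -configuration Debug build
```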

### Running (Xcode)

To run this in Xcode, press cmd-opt-r to edit the scheme and set the run
arguments. For example:

```
--model mlx-community/Mistral-7B-v0.1-hf-4bit-mlx
--prompt "swift programming language"
--max-tokens 50
```
Then press cmd-r to run.

> Note: you may be prompted for access to your Documents directory -- this is where
> the huggingface HubApi stores the downloaded files.

### Running (Command Line)

`llm-tool` can also be run from the command line if built from Xcode, but
`DYLD_FRAMEWORK_PATH` must be set so that the frameworks and bundles can be found:

- [MLX troubleshooting](https://ml-explore.github.io/mlx-swift/MLX/documentation/mlx/troubleshooting)

The easiest way to do this is to drag Products/llm-tool from Xcode into Terminal to get the path:
```
DYLD_FRAMEWORK_PATH=~/Library/Developer/Xcode/DerivedData/mlx-examples-swift-ceuohnhzsownvsbbleukxoksddja/Build/Products/Debug ~/Library/Developer/Xcode/DerivedData/mlx-examples-swift-ceuohnhzsownvsbbleukxoksddja/Build/Products/Debug/llm-tool --prompt "swift programming language"
```