prepare for lora branch (#47)
- remove async llm generation -- this is just doubling our work and does not match the style used in the example applications
- package generation parameters into a struct
- refactor command line arguments into distinct pieces based on their use -- this will be reusable in the lora commands
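For orientation, here is a minimal sketch of how the refactored pieces fit together, assembled from the diff below. The wrapping function name, the module import, the prompt text, and the 240 token limit are illustrative assumptions, not code from this commit.

// a minimal sketch; assumes the library module is named LLM, per the project file
import LLM

func evaluateExample() async throws {
    let modelConfiguration = ModelConfiguration.phi4bit
    let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration)

    // generation options now travel together in a single struct
    let parameters = GenerateParameters(
        temperature: 0.6, topP: 0.9, repetitionPenalty: 1.0, repetitionContextSize: 20)

    let prompt = modelConfiguration.prepare(prompt: "compare python and swift")
    let promptTokens = tokenizer.encode(text: prompt)

    // the synchronous generate() drives TokenIterator and reports progress through
    // the didGenerate closure; returning .stop ends generation
    let result = await LLM.generate(
        promptTokens: promptTokens, parameters: parameters,
        model: model, tokenizer: tokenizer
    ) { tokens in
        tokens.count >= 240 ? .stop : .more
    }

    print(result.output)
    print(result.summary())
}

Compared with the removed channel-based async generator further down, the didGenerate disposition replaces both the manual maxTokens break and the Task cancellation.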
@@ -160,7 +160,7 @@ class LLMEvaluator {
let modelConfiguration = ModelConfiguration.phi4bit

/// parameters controlling the output
let temperature: Float = 0.6
let generateParameters = GenerateParameters(temperature: 0.6)
let maxTokens = 240

/// update the display every N tokens -- 4 looks like it updates continuously
@@ -201,7 +201,6 @@ class LLMEvaluator {
}

func generate(prompt: String) async {
let startTime = Date()
do {
let (model, tokenizer) = try await load()

@@ -212,59 +211,37 @@ class LLMEvaluator {

// augment the prompt as needed
let prompt = modelConfiguration.prepare(prompt: prompt)
let promptTokens = MLXArray(tokenizer.encode(text: prompt))

var initTime = Date()
let initDuration = initTime.timeIntervalSince(startTime)
await MainActor.run {
self.stat = "Init: \(String(format: "%.3f", initDuration))s"
}
let promptTokens = tokenizer.encode(text: prompt)

// each time you generate you will get something new
MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))

var outputTokens = [Int]()

for token in TokenIterator(prompt: promptTokens, model: model, temp: temperature) {
let tokenId = token.item(Int.self)

// to match the measurement from the command line we reset the start time
// after the first token is generated (called the prompt time)
if outputTokens.isEmpty {
initTime = Date()
}

if tokenId == tokenizer.unknownTokenId || tokenId == tokenizer.eosTokenId {
break
}

outputTokens.append(tokenId)
let text = tokenizer.decode(tokens: outputTokens)

let result = await LLM.generate(
promptTokens: promptTokens, parameters: generateParameters, model: model,
tokenizer: tokenizer
) { tokens in
// update the output -- this will make the view show the text as it generates
if outputTokens.count % displayEveryNTokens == 0 {
if tokens.count % displayEveryNTokens == 0 {
let text = tokenizer.decode(tokens: tokens)
await MainActor.run {
self.output = text
}
}

if outputTokens.count == maxTokens {
break
if tokens.count >= maxTokens {
return .stop
} else {
return .more
}
}

let tokenDuration = Date().timeIntervalSince(initTime)
let tokensPerSecond = Double(outputTokens.count) / tokenDuration

// update the text if needed, e.g. we haven't displayed because of displayEveryNTokens
let finalText = tokenizer.decode(tokens: outputTokens)

await MainActor.run {
if finalText != self.output {
self.output = finalText
if result.output != self.output {
self.output = result.output
}
running = false
self.stat += " Tokens/second: \(String(format: "%.3f", tokensPerSecond))"
self.stat = " Tokens/second: \(String(format: "%.3f", result.tokensPerSecond))"
}

} catch {

@@ -4,6 +4,7 @@ import AsyncAlgorithms
import Foundation
import MLX
import MLXRandom
import Tokenizers

private func topPSampling(logits: MLXArray, topP: Float, temp: Float) -> MLXArray {
var logits = logits
@@ -28,8 +29,6 @@ private func topPSampling(logits: MLXArray, topP: Float, temp: Float) -> MLXArra
private func applyRepetitionPenalty(
logits: MLXArray, repetitionContext: MLXArray, penalty: Float
) -> MLXArray {
var logits = logits

if repetitionContext.shape[0] > 0 {
let indices = repetitionContext
var selectedLogits = take(logits, indices, axis: -1).squeezed(axis: 0)
@@ -55,37 +54,53 @@ private func sample(logits: MLXArray, temp: Float, topP: Float = 1.0) -> MLXArra
}
}

/// Parameters for text generation, see ``TokenIterator``
public struct GenerateParameters {
/// sampling temperature
public var temperature: Float = 0.6

/// top p sampling
public var topP: Float = 0.9

/// penalty factor for repeating tokens
public var repetitionPenalty: Float = 1.0

/// number of tokens to consider for repetition penalty
public var repetitionContextSize: Int = 20

public init(
temperature: Float = 0.6, topP: Float = 0.9, repetitionPenalty: Float = 1.0,
repetitionContextSize: Int = 20
) {
self.temperature = temperature
self.topP = topP
self.repetitionPenalty = repetitionPenalty
self.repetitionContextSize = repetitionContextSize
}
}

/// Synchronous generator of tokens.
///
/// Port of `generate_step()` from https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py
public struct TokenIterator: Sequence, IteratorProtocol {
let model: LLMModel
let temp: Float
let topP: Float
let repetitionPenalty: Float
let repetitionContextSize: Int
let parameters: GenerateParameters
var repetitionContext: MLXArray
var y: MLXArray
var cache: [(MLXArray, MLXArray)]

var first = true

public init(
prompt: MLXArray, model: LLMModel, temp: Float = 0.0, topP: Float = 1.0,
repetitionPenalty: Float = 1.0, repetitionContextSize: Int = 20
) {
public init(prompt: MLXArray, model: LLMModel, parameters: GenerateParameters) {
self.model = model
self.temp = temp
self.topP = topP
self.parameters = parameters
self.y = prompt
self.cache = []
self.repetitionPenalty = repetitionPenalty
self.repetitionContextSize = repetitionContextSize
if repetitionContextSize > 1 {
if prompt.shape[0] <= repetitionContextSize {
if parameters.repetitionContextSize > 1 {
if prompt.shape[0] <= parameters.repetitionContextSize {
self.repetitionContext = prompt
} else {
self.repetitionContext = prompt[-repetitionContextSize ... -1]
self.repetitionContext = prompt[-parameters.repetitionContextSize ... -1]
}
} else {
self.repetitionContext = []
@@ -96,16 +111,17 @@ public struct TokenIterator: Sequence, IteratorProtocol {
var logits: MLXArray
(logits, cache) = model(expandedDimensions(y, axis: 0), cache: cache.isEmpty ? nil : cache)
logits = logits[0..., -1, 0...]
if repetitionPenalty > 1.0 {
if parameters.repetitionPenalty > 1.0 {
// apply repetition penalty
logits = applyRepetitionPenalty(
logits: logits, repetitionContext: repetitionContext, penalty: repetitionPenalty)
logits: logits, repetitionContext: repetitionContext,
penalty: parameters.repetitionPenalty)
}
y = sample(logits: logits, temp: temp, topP: topP)
y = sample(logits: logits, temp: parameters.temperature, topP: parameters.topP)
// append the current token to the context and check repetitionPenalty context see if need to remove the first token
if repetitionContextSize > 1 {
if parameters.repetitionContextSize > 1 {
repetitionContext = concatenated([repetitionContext, y], axis: 0)
if repetitionContext.shape[0] > repetitionContextSize {
if repetitionContext.shape[0] > parameters.repetitionContextSize {
repetitionContext = repetitionContext[1...]
}
}
@@ -114,61 +130,88 @@ public struct TokenIterator: Sequence, IteratorProtocol {
}
}

/// Async generator of tokens.
///
/// Port of `generate_step()` from https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py.
///
/// Note that because MLXArray is not thread safe this eval's the result and sends the TokenId back
/// to the caller.
public func generate(
prompt: MLXArray, model: LLMModel, temp: Float = 0.0, topP: Float = 1.0,
repetitionPenalty: Float = 1.0, repetitionContextSize: Int = 20
) -> (
Task<Void, Never>, AsyncBufferSequence<AsyncChannel<Int>>
) {
let channel = AsyncChannel<Int>()
let buffer = channel.buffer(policy: .bounded(10))
public struct GenerateResult {
/// input tokens
public let promptTokens: [Int]

let task = Task {
var y = prompt
var cache = [(MLXArray, MLXArray)]()
var repetitionContext: MLXArray
/// output tokens
public let tokens: [Int]

if repetitionContextSize > 1 {
if prompt.shape[0] <= repetitionContextSize {
repetitionContext = prompt
} else {
repetitionContext = prompt[-repetitionContextSize ... -1]
}
} else {
repetitionContext = []
}
while !Task.isCancelled {
var logits: MLXArray
(logits, cache) = model(
expandedDimensions(y, axis: 0), cache: cache.isEmpty ? nil : cache)
/// output text
public let output: String

logits = logits[0..., -1, 0...]
if repetitionPenalty > 1.0 {
// apply repetition penalty
logits = applyRepetitionPenalty(
logits: logits, repetitionContext: repetitionContext, penalty: repetitionPenalty
)
}
y = sample(logits: logits, temp: temp, topP: topP)
// append the current token to the context and check repetitionPenalty context see if need to remove the first token
if repetitionContextSize > 1 {
repetitionContext = concatenated([repetitionContext, y], axis: 0)
if repetitionContext.shape[0] > repetitionContextSize {
repetitionContext = repetitionContext[1...]
}
/// time to process the prompt / generate the first token
public let promptTime: TimeInterval

/// time to generate the remaining tokens
public let generateTime: TimeInterval

public var promptTokensPerSecond: Double {
Double(promptTokens.count) / promptTime
}

eval(y)

await channel.send(y.item(Int.self))
}
public var tokensPerSecond: Double {
Double(tokens.count - 1) / generateTime
}

return (task, buffer)
public func summary() -> String {
"""
Prompt Tokens per second: \(promptTokensPerSecond.formatted())
Generation tokens per second: \(tokensPerSecond.formatted())
"""
}
}

public enum GenerateDisposition {
case more
case stop
}

/// Given prompt tokens generate text using the given model and parameters.
///
/// - Parameters:
/// - promptTokens: tokenized prompt
/// - parameters: generation parameters
/// - model: model to evaluate
/// - tokenizer: tokenizer to convert tokens back into strings and recognizer special tokens
/// - didGenerate: visitor for the tokens as they are generated
public func generate(
promptTokens: [Int], parameters: GenerateParameters, model: LLMModel, tokenizer: Tokenizer,
didGenerate: ([Int]) async -> GenerateDisposition
) async -> GenerateResult {
var start = Date.timeIntervalSinceReferenceDate
var promptTime: TimeInterval = 0

var tokens = [Int]()

for token in TokenIterator(
prompt: MLXArray(promptTokens), model: model, parameters: parameters)
{
// compute the timing for the prompt
if tokens.isEmpty {
eval(token)
let now = Date.timeIntervalSinceReferenceDate
promptTime = now - start
start = now
}

let t = token.item(Int.self)
if t == tokenizer.unknownTokenId || t == tokenizer.eosTokenId {
break
}

tokens.append(t)

if await didGenerate(tokens) == .stop {
break
}
}

let now = Date.timeIntervalSinceReferenceDate
let generateTime = now - start

return GenerateResult(
promptTokens: promptTokens, tokens: tokens,
output: tokenizer.decode(tokens: tokens),
promptTime: promptTime, generateTime: generateTime)
}

@@ -10,16 +10,27 @@ import Tokenizers
@main
struct LLMTool: AsyncParsableCommand {
static var configuration = CommandConfiguration(
abstract: "Command line tool for generating text using Llama models",
subcommands: [SyncGenerator.self, AsyncGenerator.self],
defaultSubcommand: SyncGenerator.self)
abstract: "Command line tool for generating text and manipulating LLMs",
subcommands: [EvaluateCommand.self],
defaultSubcommand: EvaluateCommand.self)
}

struct LLMArguments: ParsableArguments {
/// Command line arguments for loading a model.
struct ModelArguments: ParsableArguments {

@Option(name: .long, help: "Name of the huggingface model")
var model: String = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"

func load() async throws -> (LLMModel, Tokenizer, ModelConfiguration) {
let modelConfiguration = ModelConfiguration.configuration(id: model)
let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration)
return (model, tokenizer, modelConfiguration)
}
}

/// Command line arguments for controlling generation of text.
struct GenerateArguments: ParsableArguments {

@Option(name: .shortAndLong, help: "The message to be processed by the model")
var prompt = "compare python and swift"

@@ -29,19 +40,67 @@ struct LLMArguments: ParsableArguments {
@Option(name: .shortAndLong, help: "The sampling temperature")
var temperature: Float = 0.6

@Option(name: .shortAndLong, help: "The top p sampling")
@Option(name: .long, help: "The top p sampling")
var topP: Float = 0.9

@Option(name: .shortAndLong, help: "The penalty factor for repeating tokens")
@Option(name: .long, help: "The penalty factor for repeating tokens")
var repetitionPenalty: Float = 1.0

@Option(name: .shortAndLong, help: "The number of tokens to consider for repetition penalty")
@Option(name: .long, help: "The number of tokens to consider for repetition penalty")
var repetitionContextSize: Int = 20

@Option(name: .long, help: "The PRNG seed")
var seed: UInt64 = 0

@Flag(help: "Show memory stats")
var generateParameters: GenerateParameters {
GenerateParameters(
temperature: temperature, topP: topP, repetitionPenalty: repetitionPenalty,
repetitionContextSize: repetitionContextSize)
}

func tokenizePrompt(configuration: ModelConfiguration, tokenizer: Tokenizer) -> (String, [Int])
{
MLXRandom.seed(seed)

let prompt = configuration.prepare(prompt: self.prompt)
let promptTokens = tokenizer.encode(text: prompt)

return (prompt, promptTokens)
}

func generate(promptTokens: [Int], model: LLMModel, tokenizer: Tokenizer) async
-> GenerateResult
{
// track how much we have printed
var printed = 0

return await LLM.generate(
promptTokens: promptTokens, parameters: generateParameters,
model: model, tokenizer: tokenizer
) { tokens in

// print any new parts of the string
let fullOutput = tokenizer.decode(tokens: tokens)
let emitLength = fullOutput.count - printed
let suffix = fullOutput.suffix(emitLength)
print(suffix, terminator: "")
fflush(stdout)

printed = fullOutput.count

if tokens.count >= maxTokens {
return .stop
} else {
return .more
}
}
}
}

/// Argument package for adjusting and reporting memory use.
struct MemoryArguments: ParsableArguments {

@Flag(name: .long, help: "Show memory stats")
var memoryStats = false

@Option(name: .long, help: "Maximum cache size in M")
@@ -52,9 +111,7 @@ struct LLMArguments: ParsableArguments {

var startMemory: GPU.Snapshot?

mutating func load() async throws -> (LLMModel, Tokenizer, ModelConfiguration) {
MLXRandom.seed(seed)

mutating func start<L>(_ load: () async throws -> L) async throws -> L {
if let cacheSize {
GPU.set(cacheLimit: cacheSize * 1024 * 1024)
}
@@ -63,20 +120,29 @@ struct LLMArguments: ParsableArguments {
GPU.set(memoryLimit: memorySize * 1024 * 1024)
}

let modelConfiguration = ModelConfiguration.configuration(id: model)
let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration)

let result = try await load()
startMemory = GPU.snapshot()

return (model, tokenizer, modelConfiguration)
return result
}

func tokenizePropmpt(configuration: ModelConfiguration, tokenizer: Tokenizer) -> (String, [Int])
{
let prompt = configuration.prepare(prompt: self.prompt)
let promptTokens = tokenizer.encode(text: prompt)
mutating func start() {
if let cacheSize {
GPU.set(cacheLimit: cacheSize * 1024 * 1024)
}

return (prompt, promptTokens)
if let memorySize {
GPU.set(memoryLimit: memorySize * 1024 * 1024)
}

startMemory = GPU.snapshot()
}

func reportCurrent() {
if memoryStats {
let memory = GPU.snapshot()
print(memory.description)
}
}

func reportMemoryStatistics() {
@@ -106,164 +172,36 @@ struct LLMArguments: ParsableArguments {
}
}

struct SyncGenerator: AsyncParsableCommand {
struct EvaluateCommand: AsyncParsableCommand {

static var configuration = CommandConfiguration(
commandName: "sync",
abstract: "Synchronous generator"
commandName: "eval",
abstract: "evaluate prompt and generate text"
)

@OptionGroup var args: LLMArguments
@OptionGroup var args: ModelArguments
@OptionGroup var memory: MemoryArguments
@OptionGroup var generate: GenerateArguments

@MainActor
mutating func run() async throws {
let (model, tokenizer, modelConfiguration) = try await args.load()
let (model, tokenizer, modelConfiguration) = try await memory.start(args.load)

print("Model loaded -> \(modelConfiguration.id)")

let (prompt, promptTokens) = args.tokenizePropmpt(
let (prompt, promptTokens) = generate.tokenizePrompt(
configuration: modelConfiguration, tokenizer: tokenizer)

print("Starting generation ...")
print(prompt, terminator: "")

var start = Date.timeIntervalSinceReferenceDate
var promptTime: TimeInterval = 0

// collect the tokens and keep track of how much of the string
// we have printed already
var tokens = [Int]()
var printed = 0

for token in TokenIterator(
prompt: MLXArray(promptTokens), model: model, temp: args.temperature, topP: args.topP,
repetitionPenalty: args.repetitionPenalty,
repetitionContextSize: args.repetitionContextSize)
{
if tokens.isEmpty {
eval(token)
let now = Date.timeIntervalSinceReferenceDate
promptTime = now - start
start = now
}

let t = token.item(Int.self)
if t == tokenizer.unknownTokenId || t == tokenizer.eosTokenId {
break
}
tokens.append(t)

// print any new parts of the string
let fullOutput = tokenizer.decode(tokens: tokens)
let emitLength = fullOutput.count - printed
let suffix = fullOutput.suffix(emitLength)
print(suffix, terminator: "")
fflush(stdout)

printed = fullOutput.count

if tokens.count == args.maxTokens {
break
}
}
let result = await generate.generate(
promptTokens: promptTokens, model: model, tokenizer: tokenizer)

print()
print("------")
let now = Date.timeIntervalSinceReferenceDate
let generateTime = now - start
print(result.summary())

print(
"""
Prompt Tokens per second: \((Double(promptTokens.count) / promptTime).formatted())
Generation tokens per second: \((Double(tokens.count - 1) / generateTime).formatted())
""")

args.reportMemoryStatistics()
}
}

/// Example of an async generator.
///
/// Note that all of the computation is done on another thread and TokenId (Int32) are sent
/// rather than MLXArray.
struct AsyncGenerator: AsyncParsableCommand {

static var configuration = CommandConfiguration(
commandName: "async",
abstract: "async generator"
)

@OptionGroup var args: LLMArguments

@MainActor
mutating func run() async throws {
let (model, tokenizer, modelConfiguration) = try await args.load()

print("Model loaded -> \(modelConfiguration.id)")

let (prompt, promptTokens) = args.tokenizePropmpt(
configuration: modelConfiguration, tokenizer: tokenizer)

print("Starting generation ...")
print(prompt, terminator: "")

var start = Date.timeIntervalSinceReferenceDate
var promptTime: TimeInterval = 0

// collect the tokens and keep track of how much of the string
// we have printed already
var tokens = [Int]()
var printed = 0

let (task, channel) = generate(
prompt: MLXArray(promptTokens), model: model, temp: args.temperature, topP: args.topP,
repetitionPenalty: args.repetitionPenalty,
repetitionContextSize: args.repetitionContextSize)

for await token in channel {
if tokens.isEmpty {
let now = Date.timeIntervalSinceReferenceDate
promptTime = now - start
start = now
}

if token == tokenizer.unknownTokenId || token == tokenizer.eosTokenId {
break
}
tokens.append(token)

// print any new parts of the string
let fullOutput = tokenizer.decode(tokens: tokens)
let emitLength = fullOutput.count - printed
let suffix = fullOutput.suffix(emitLength)
print(suffix, terminator: "")
fflush(stdout)

printed = fullOutput.count

if tokens.count == args.maxTokens {
break
}
}

// tell the task to stop
task.cancel()

print()
print("------")
let now = Date.timeIntervalSinceReferenceDate
let generateTime = now - start

print(
"""
Prompt Tokens per second: \((Double(promptTokens.count) / promptTime).formatted())
Generation tokens per second: \((Double(tokens.count - 1) / generateTime).formatted())
""")

args.reportMemoryStatistics()

// wait for the task to complete -- since it is running async, it might
// be in the middle of running the model
try? await Task.sleep(for: .milliseconds(500))
memory.reportMemoryStatistics()
}
}

@@ -42,11 +42,9 @@
C3932D592B6A0BE400A81055 /* Random.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3932D582B6A0BE400A81055 /* Random.swift */; };
C397C59C2B62C6D0004B084D /* ArgumentParser in Frameworks */ = {isa = PBXBuildFile; productRef = C397C59B2B62C6D0004B084D /* ArgumentParser */; };
C3A8B3AC2B9283150002EFB8 /* Models.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3A8B3AB2B9283150002EFB8 /* Models.swift */; };
C3A8B3CA2B92951E0002EFB8 /* MNISTTrainer-Info.plist in Resources */ = {isa = PBXBuildFile; fileRef = C3A8B3C22B92951E0002EFB8 /* MNISTTrainer-Info.plist */; };
C3A8B3CB2B92951E0002EFB8 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C3A8B3C32B92951E0002EFB8 /* Assets.xcassets */; };
C3A8B3CC2B92951E0002EFB8 /* MNISTTrainerApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3A8B3C42B92951E0002EFB8 /* MNISTTrainerApp.swift */; };
C3A8B3CD2B92951E0002EFB8 /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C3A8B3C62B92951E0002EFB8 /* Preview Assets.xcassets */; };
C3A8B3CE2B92951E0002EFB8 /* README.md in Resources */ = {isa = PBXBuildFile; fileRef = C3A8B3C82B92951E0002EFB8 /* README.md */; };
C3A8B3CF2B92951E0002EFB8 /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3A8B3C92B92951E0002EFB8 /* ContentView.swift */; };
C3A8B3D22B92A0880002EFB8 /* MLXOptimizers in Frameworks */ = {isa = PBXBuildFile; productRef = C3A8B3D12B92A0880002EFB8 /* MLXOptimizers */; };
C3A8B3D32B92A0880002EFB8 /* MNIST.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C34E490D2B69A92900FCB841 /* MNIST.framework */; };
@@ -54,7 +52,6 @@
C3A8B3F32B92A2A90002EFB8 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C3A8B3EC2B92A2A90002EFB8 /* Assets.xcassets */; };
C3A8B3F42B92A2A90002EFB8 /* LLMEvalApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3A8B3ED2B92A2A90002EFB8 /* LLMEvalApp.swift */; };
C3A8B3F52B92A2A90002EFB8 /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C3A8B3EF2B92A2A90002EFB8 /* Preview Assets.xcassets */; };
C3A8B3F62B92A2A90002EFB8 /* README.md in Resources */ = {isa = PBXBuildFile; fileRef = C3A8B3F02B92A2A90002EFB8 /* README.md */; };
C3A8B3F72B92A2A90002EFB8 /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3A8B3F22B92A2A90002EFB8 /* ContentView.swift */; };
C3A8B3F82B92A3360002EFB8 /* LLM.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C38935C52B869C7A0037B833 /* LLM.framework */; };
C3A8B3F92B92A3360002EFB8 /* LLM.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = C38935C52B869C7A0037B833 /* LLM.framework */; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; };
@@ -801,8 +798,6 @@
isa = PBXResourcesBuildPhase;
buildActionMask = 2147483647;
files = (
C3A8B3CE2B92951E0002EFB8 /* README.md in Resources */,
C3A8B3CA2B92951E0002EFB8 /* MNISTTrainer-Info.plist in Resources */,
C3A8B3CD2B92951E0002EFB8 /* Preview Assets.xcassets in Resources */,
C3A8B3CB2B92951E0002EFB8 /* Assets.xcassets in Resources */,
);
@@ -813,7 +808,6 @@
buildActionMask = 2147483647;
files = (
C3A8B3F52B92A2A90002EFB8 /* Preview Assets.xcassets in Resources */,
C3A8B3F62B92A2A90002EFB8 /* README.md in Resources */,
C3A8B3F32B92A2A90002EFB8 /* Assets.xcassets in Resources */,
);
runOnlyForDeploymentPostprocessing = 0;
@@ -1,4 +1,5 @@
{
"originHash" : "da53546673b6d05016b6e5640c18814c7dba5b5af8db34715afe6d633037c758",
"pins" : [
{
"identity" : "gzipswift",
@@ -15,7 +16,7 @@
"location" : "https://github.com/ml-explore/mlx-swift",
"state" : {
"branch" : "main",
"revision" : "a1c544c817d44cfdfa1a650f521066b565c2ae4f"
"revision" : "b4d3e4bbbe41e6dc7c46d5ba075049ae7177961b"
}
},
{
@@ -82,5 +83,5 @@
}
}
],
"version" : 2
"version" : 3
}