prepare for lora branch (#47)
- remove async llm generation -- this is just doubling our work and does not match the style used in the example applications
- package generation parameters into a struct
- refactor command line arguments into distinct pieces based on their use -- this will be reusable in the lora commands
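
The refactor below replaces the per-token async channel with a single callback-driven entry point in the LLM library. As a rough usage sketch (not code from this commit; the import LLM module name and the wrapping function are assumptions, everything else mirrors the hunks below):

    import LLM

    func exampleGeneration() async throws {
        // load a model the same way LLMEval does below
        let modelConfiguration = ModelConfiguration.phi4bit
        let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration)

        let prompt = modelConfiguration.prepare(prompt: "compare python and swift")
        let promptTokens = tokenizer.encode(text: prompt)

        // the generation knobs now travel together in one struct
        let parameters = GenerateParameters(temperature: 0.6, topP: 0.9)

        let result = await LLM.generate(
            promptTokens: promptTokens, parameters: parameters, model: model,
            tokenizer: tokenizer
        ) { tokens in
            // the visitor sees every token generated so far and decides whether to continue
            tokens.count >= 240 ? .stop : .more
        }

        print(result.output)
        print(result.summary())
    }

GenerateResult carries the prompt and generation timings, so the app and the command line tool can report tokens per second from the same place.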
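The command line arguments are likewise split into ModelArguments, MemoryArguments, and GenerateArguments so that the upcoming lora commands can compose just the pieces they need via @OptionGroup. A hypothetical follow-on command might look like the sketch below (the command name and the adapter option are illustrative only and not part of this change):

    import ArgumentParser

    struct LoraCommandSketch: AsyncParsableCommand {
        static var configuration = CommandConfiguration(commandName: "lora-sketch")

        @OptionGroup var args: ModelArguments      // which model to load
        @OptionGroup var memory: MemoryArguments   // GPU cache/memory limits and reporting

        @Option(name: .long, help: "Path to adapter weights (illustrative)")
        var adapter = "adapters.safetensors"

        @MainActor
        mutating func run() async throws {
            // memory.start applies the GPU limits before the model weights are allocated
            let (model, tokenizer, modelConfiguration) = try await memory.start(args.load)
            print("Model loaded -> \(modelConfiguration.id)")

            _ = (model, tokenizer, adapter)  // the actual lora training/evaluation would go here

            memory.reportMemoryStatistics()
        }
    }
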
@@ -160,7 +160,7 @@ class LLMEvaluator {
     let modelConfiguration = ModelConfiguration.phi4bit
 
     /// parameters controlling the output
-    let temperature: Float = 0.6
+    let generateParameters = GenerateParameters(temperature: 0.6)
     let maxTokens = 240
 
     /// update the display every N tokens -- 4 looks like it updates continuously
@@ -201,7 +201,6 @@ class LLMEvaluator {
     }
 
     func generate(prompt: String) async {
-        let startTime = Date()
        do {
            let (model, tokenizer) = try await load()
 
@@ -212,59 +211,37 @@ class LLMEvaluator {
 
             // augment the prompt as needed
             let prompt = modelConfiguration.prepare(prompt: prompt)
-            let promptTokens = MLXArray(tokenizer.encode(text: prompt))
+            let promptTokens = tokenizer.encode(text: prompt)
 
-            var initTime = Date()
-            let initDuration = initTime.timeIntervalSince(startTime)
-            await MainActor.run {
-                self.stat = "Init: \(String(format: "%.3f", initDuration))s"
-            }
-
             // each time you generate you will get something new
             MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))
 
-            var outputTokens = [Int]()
-
-            for token in TokenIterator(prompt: promptTokens, model: model, temp: temperature) {
-                let tokenId = token.item(Int.self)
-
-                // to match the measurement from the command line we reset the start time
-                // after the first token is generated (called the prompt time)
-                if outputTokens.isEmpty {
-                    initTime = Date()
-                }
-
-                if tokenId == tokenizer.unknownTokenId || tokenId == tokenizer.eosTokenId {
-                    break
-                }
-
-                outputTokens.append(tokenId)
-                let text = tokenizer.decode(tokens: outputTokens)
-
+            let result = await LLM.generate(
+                promptTokens: promptTokens, parameters: generateParameters, model: model,
+                tokenizer: tokenizer
+            ) { tokens in
                 // update the output -- this will make the view show the text as it generates
-                if outputTokens.count % displayEveryNTokens == 0 {
+                if tokens.count % displayEveryNTokens == 0 {
+                    let text = tokenizer.decode(tokens: tokens)
                     await MainActor.run {
                         self.output = text
                     }
                 }
 
-                if outputTokens.count == maxTokens {
-                    break
+                if tokens.count >= maxTokens {
+                    return .stop
+                } else {
+                    return .more
                 }
             }
 
-            let tokenDuration = Date().timeIntervalSince(initTime)
-            let tokensPerSecond = Double(outputTokens.count) / tokenDuration
-
             // update the text if needed, e.g. we haven't displayed because of displayEveryNTokens
-            let finalText = tokenizer.decode(tokens: outputTokens)
-
             await MainActor.run {
-                if finalText != self.output {
-                    self.output = finalText
+                if result.output != self.output {
+                    self.output = result.output
                 }
                 running = false
-                self.stat += " Tokens/second: \(String(format: "%.3f", tokensPerSecond))"
+                self.stat = " Tokens/second: \(String(format: "%.3f", result.tokensPerSecond))"
             }
 
         } catch {
@@ -4,6 +4,7 @@ import AsyncAlgorithms
 import Foundation
 import MLX
 import MLXRandom
+import Tokenizers
 
 private func topPSampling(logits: MLXArray, topP: Float, temp: Float) -> MLXArray {
     var logits = logits
@@ -28,8 +29,6 @@ private func topPSampling(logits: MLXArray, topP: Float, temp: Float) -> MLXArra
 private func applyRepetitionPenalty(
     logits: MLXArray, repetitionContext: MLXArray, penalty: Float
 ) -> MLXArray {
-    var logits = logits
-
     if repetitionContext.shape[0] > 0 {
         let indices = repetitionContext
         var selectedLogits = take(logits, indices, axis: -1).squeezed(axis: 0)
@@ -55,37 +54,53 @@ private func sample(logits: MLXArray, temp: Float, topP: Float = 1.0) -> MLXArra
     }
 }
 
+/// Parameters for text generation, see ``TokenIterator``
+public struct GenerateParameters {
+    /// sampling temperature
+    public var temperature: Float = 0.6
+
+    /// top p sampling
+    public var topP: Float = 0.9
+
+    /// penalty factor for repeating tokens
+    public var repetitionPenalty: Float = 1.0
+
+    /// number of tokens to consider for repetition penalty
+    public var repetitionContextSize: Int = 20
+
+    public init(
+        temperature: Float = 0.6, topP: Float = 0.9, repetitionPenalty: Float = 1.0,
+        repetitionContextSize: Int = 20
+    ) {
+        self.temperature = temperature
+        self.topP = topP
+        self.repetitionPenalty = repetitionPenalty
+        self.repetitionContextSize = repetitionContextSize
+    }
+}
+
 /// Synchronous generator of tokens.
 ///
 /// Port of `generate_step()` from https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py
 public struct TokenIterator: Sequence, IteratorProtocol {
     let model: LLMModel
-    let temp: Float
-    let topP: Float
-    let repetitionPenalty: Float
-    let repetitionContextSize: Int
+    let parameters: GenerateParameters
     var repetitionContext: MLXArray
     var y: MLXArray
     var cache: [(MLXArray, MLXArray)]
 
     var first = true
 
-    public init(
-        prompt: MLXArray, model: LLMModel, temp: Float = 0.0, topP: Float = 1.0,
-        repetitionPenalty: Float = 1.0, repetitionContextSize: Int = 20
-    ) {
+    public init(prompt: MLXArray, model: LLMModel, parameters: GenerateParameters) {
         self.model = model
-        self.temp = temp
-        self.topP = topP
+        self.parameters = parameters
        self.y = prompt
        self.cache = []
-        self.repetitionPenalty = repetitionPenalty
-        self.repetitionContextSize = repetitionContextSize
-        if repetitionContextSize > 1 {
-            if prompt.shape[0] <= repetitionContextSize {
+        if parameters.repetitionContextSize > 1 {
+            if prompt.shape[0] <= parameters.repetitionContextSize {
                 self.repetitionContext = prompt
             } else {
-                self.repetitionContext = prompt[-repetitionContextSize ... -1]
+                self.repetitionContext = prompt[-parameters.repetitionContextSize ... -1]
             }
         } else {
             self.repetitionContext = []
@@ -96,16 +111,17 @@ public struct TokenIterator: Sequence, IteratorProtocol {
         var logits: MLXArray
         (logits, cache) = model(expandedDimensions(y, axis: 0), cache: cache.isEmpty ? nil : cache)
         logits = logits[0..., -1, 0...]
-        if repetitionPenalty > 1.0 {
+        if parameters.repetitionPenalty > 1.0 {
             // apply repetition penalty
             logits = applyRepetitionPenalty(
-                logits: logits, repetitionContext: repetitionContext, penalty: repetitionPenalty)
+                logits: logits, repetitionContext: repetitionContext,
+                penalty: parameters.repetitionPenalty)
         }
-        y = sample(logits: logits, temp: temp, topP: topP)
+        y = sample(logits: logits, temp: parameters.temperature, topP: parameters.topP)
         // append the current token to the context and check repetitionPenalty context see if need to remove the first token
-        if repetitionContextSize > 1 {
+        if parameters.repetitionContextSize > 1 {
             repetitionContext = concatenated([repetitionContext, y], axis: 0)
-            if repetitionContext.shape[0] > repetitionContextSize {
+            if repetitionContext.shape[0] > parameters.repetitionContextSize {
                 repetitionContext = repetitionContext[1...]
             }
         }
@@ -114,61 +130,88 @@ public struct TokenIterator: Sequence, IteratorProtocol {
     }
 }
 
-/// Async generator of tokens.
+public struct GenerateResult {
+    /// input tokens
+    public let promptTokens: [Int]
+
+    /// output tokens
+    public let tokens: [Int]
+
+    /// output text
+    public let output: String
+
+    /// time to process the prompt / generate the first token
+    public let promptTime: TimeInterval
+
+    /// time to generate the remaining tokens
+    public let generateTime: TimeInterval
+
+    public var promptTokensPerSecond: Double {
+        Double(promptTokens.count) / promptTime
+    }
+
+    public var tokensPerSecond: Double {
+        Double(tokens.count - 1) / generateTime
+    }
+
+    public func summary() -> String {
+        """
+        Prompt Tokens per second: \(promptTokensPerSecond.formatted())
+        Generation tokens per second: \(tokensPerSecond.formatted())
+        """
+    }
+}
+
+public enum GenerateDisposition {
+    case more
+    case stop
+}
+
+/// Given prompt tokens generate text using the given model and parameters.
 ///
-/// Port of `generate_step()` from https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py.
-///
-/// Note that because MLXArray is not thread safe this eval's the result and sends the TokenId back
-/// to the caller.
+/// - Parameters:
+/// - promptTokens: tokenized prompt
+/// - parameters: generation parameters
+/// - model: model to evaluate
+/// - tokenizer: tokenizer to convert tokens back into strings and recognizer special tokens
+/// - didGenerate: visitor for the tokens as they are generated
 public func generate(
-    prompt: MLXArray, model: LLMModel, temp: Float = 0.0, topP: Float = 1.0,
-    repetitionPenalty: Float = 1.0, repetitionContextSize: Int = 20
-) -> (
-    Task<Void, Never>, AsyncBufferSequence<AsyncChannel<Int>>
-) {
-    let channel = AsyncChannel<Int>()
-    let buffer = channel.buffer(policy: .bounded(10))
+    promptTokens: [Int], parameters: GenerateParameters, model: LLMModel, tokenizer: Tokenizer,
+    didGenerate: ([Int]) async -> GenerateDisposition
+) async -> GenerateResult {
+    var start = Date.timeIntervalSinceReferenceDate
+    var promptTime: TimeInterval = 0
 
-    let task = Task {
-        var y = prompt
-        var cache = [(MLXArray, MLXArray)]()
-        var repetitionContext: MLXArray
+    var tokens = [Int]()
 
-        if repetitionContextSize > 1 {
-            if prompt.shape[0] <= repetitionContextSize {
-                repetitionContext = prompt
-            } else {
-                repetitionContext = prompt[-repetitionContextSize ... -1]
-            }
-        } else {
-            repetitionContext = []
+    for token in TokenIterator(
+        prompt: MLXArray(promptTokens), model: model, parameters: parameters)
+    {
+        // compute the timing for the prompt
+        if tokens.isEmpty {
+            eval(token)
+            let now = Date.timeIntervalSinceReferenceDate
+            promptTime = now - start
+            start = now
         }
-        while !Task.isCancelled {
-            var logits: MLXArray
-            (logits, cache) = model(
-                expandedDimensions(y, axis: 0), cache: cache.isEmpty ? nil : cache)
 
-            logits = logits[0..., -1, 0...]
-            if repetitionPenalty > 1.0 {
-                // apply repetition penalty
-                logits = applyRepetitionPenalty(
-                    logits: logits, repetitionContext: repetitionContext, penalty: repetitionPenalty
-                )
-            }
-            y = sample(logits: logits, temp: temp, topP: topP)
-            // append the current token to the context and check repetitionPenalty context see if need to remove the first token
-            if repetitionContextSize > 1 {
-                repetitionContext = concatenated([repetitionContext, y], axis: 0)
-                if repetitionContext.shape[0] > repetitionContextSize {
-                    repetitionContext = repetitionContext[1...]
-                }
-            }
+        let t = token.item(Int.self)
+        if t == tokenizer.unknownTokenId || t == tokenizer.eosTokenId {
+            break
+        }
 
-            eval(y)
+        tokens.append(t)
 
-            await channel.send(y.item(Int.self))
+        if await didGenerate(tokens) == .stop {
+            break
         }
     }
 
-    return (task, buffer)
+    let now = Date.timeIntervalSinceReferenceDate
+    let generateTime = now - start
+
+    return GenerateResult(
+        promptTokens: promptTokens, tokens: tokens,
+        output: tokenizer.decode(tokens: tokens),
+        promptTime: promptTime, generateTime: generateTime)
 }
@@ -10,16 +10,27 @@ import Tokenizers
 @main
 struct LLMTool: AsyncParsableCommand {
     static var configuration = CommandConfiguration(
-        abstract: "Command line tool for generating text using Llama models",
-        subcommands: [SyncGenerator.self, AsyncGenerator.self],
-        defaultSubcommand: SyncGenerator.self)
+        abstract: "Command line tool for generating text and manipulating LLMs",
+        subcommands: [EvaluateCommand.self],
+        defaultSubcommand: EvaluateCommand.self)
 }
 
-struct LLMArguments: ParsableArguments {
+/// Command line arguments for loading a model.
+struct ModelArguments: ParsableArguments {
 
     @Option(name: .long, help: "Name of the huggingface model")
     var model: String = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"
 
+    func load() async throws -> (LLMModel, Tokenizer, ModelConfiguration) {
+        let modelConfiguration = ModelConfiguration.configuration(id: model)
+        let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration)
+        return (model, tokenizer, modelConfiguration)
+    }
+}
+
+/// Command line arguments for controlling generation of text.
+struct GenerateArguments: ParsableArguments {
+
     @Option(name: .shortAndLong, help: "The message to be processed by the model")
     var prompt = "compare python and swift"
 
@@ -29,19 +40,67 @@ struct LLMArguments: ParsableArguments {
     @Option(name: .shortAndLong, help: "The sampling temperature")
     var temperature: Float = 0.6
 
-    @Option(name: .shortAndLong, help: "The top p sampling")
+    @Option(name: .long, help: "The top p sampling")
     var topP: Float = 0.9
 
-    @Option(name: .shortAndLong, help: "The penalty factor for repeating tokens")
+    @Option(name: .long, help: "The penalty factor for repeating tokens")
     var repetitionPenalty: Float = 1.0
 
-    @Option(name: .shortAndLong, help: "The number of tokens to consider for repetition penalty")
+    @Option(name: .long, help: "The number of tokens to consider for repetition penalty")
     var repetitionContextSize: Int = 20
 
     @Option(name: .long, help: "The PRNG seed")
     var seed: UInt64 = 0
 
-    @Flag(help: "Show memory stats")
+    var generateParameters: GenerateParameters {
+        GenerateParameters(
+            temperature: temperature, topP: topP, repetitionPenalty: repetitionPenalty,
+            repetitionContextSize: repetitionContextSize)
+    }
+
+    func tokenizePrompt(configuration: ModelConfiguration, tokenizer: Tokenizer) -> (String, [Int])
+    {
+        MLXRandom.seed(seed)
+
+        let prompt = configuration.prepare(prompt: self.prompt)
+        let promptTokens = tokenizer.encode(text: prompt)
+
+        return (prompt, promptTokens)
+    }
+
+    func generate(promptTokens: [Int], model: LLMModel, tokenizer: Tokenizer) async
+        -> GenerateResult
+    {
+        // track how much we have printed
+        var printed = 0
+
+        return await LLM.generate(
+            promptTokens: promptTokens, parameters: generateParameters,
+            model: model, tokenizer: tokenizer
+        ) { tokens in
+
+            // print any new parts of the string
+            let fullOutput = tokenizer.decode(tokens: tokens)
+            let emitLength = fullOutput.count - printed
+            let suffix = fullOutput.suffix(emitLength)
+            print(suffix, terminator: "")
+            fflush(stdout)
+
+            printed = fullOutput.count
+
+            if tokens.count >= maxTokens {
+                return .stop
+            } else {
+                return .more
+            }
+        }
+    }
+}
+
+/// Argument package for adjusting and reporting memory use.
+struct MemoryArguments: ParsableArguments {
+
+    @Flag(name: .long, help: "Show memory stats")
     var memoryStats = false
 
     @Option(name: .long, help: "Maximum cache size in M")
@@ -52,9 +111,7 @@ struct LLMArguments: ParsableArguments {
 
     var startMemory: GPU.Snapshot?
 
-    mutating func load() async throws -> (LLMModel, Tokenizer, ModelConfiguration) {
-        MLXRandom.seed(seed)
-
+    mutating func start<L>(_ load: () async throws -> L) async throws -> L {
         if let cacheSize {
             GPU.set(cacheLimit: cacheSize * 1024 * 1024)
         }
@@ -63,20 +120,29 @@ struct LLMArguments: ParsableArguments {
             GPU.set(memoryLimit: memorySize * 1024 * 1024)
         }
 
-        let modelConfiguration = ModelConfiguration.configuration(id: model)
-        let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration)
-
+        let result = try await load()
 
         startMemory = GPU.snapshot()
 
-        return (model, tokenizer, modelConfiguration)
+        return result
     }
 
-    func tokenizePropmpt(configuration: ModelConfiguration, tokenizer: Tokenizer) -> (String, [Int])
-    {
-        let prompt = configuration.prepare(prompt: self.prompt)
-        let promptTokens = tokenizer.encode(text: prompt)
+    mutating func start() {
+        if let cacheSize {
+            GPU.set(cacheLimit: cacheSize * 1024 * 1024)
+        }
 
-        return (prompt, promptTokens)
+        if let memorySize {
+            GPU.set(memoryLimit: memorySize * 1024 * 1024)
+        }
+
+        startMemory = GPU.snapshot()
+    }
+
+    func reportCurrent() {
+        if memoryStats {
+            let memory = GPU.snapshot()
+            print(memory.description)
+        }
     }
 
     func reportMemoryStatistics() {
@@ -106,164 +172,36 @@ struct LLMArguments: ParsableArguments {
     }
 }
 
-struct SyncGenerator: AsyncParsableCommand {
+struct EvaluateCommand: AsyncParsableCommand {
 
     static var configuration = CommandConfiguration(
-        commandName: "sync",
-        abstract: "Synchronous generator"
+        commandName: "eval",
+        abstract: "evaluate prompt and generate text"
     )
 
-    @OptionGroup var args: LLMArguments
+    @OptionGroup var args: ModelArguments
+    @OptionGroup var memory: MemoryArguments
+    @OptionGroup var generate: GenerateArguments
 
     @MainActor
     mutating func run() async throws {
-        let (model, tokenizer, modelConfiguration) = try await args.load()
+        let (model, tokenizer, modelConfiguration) = try await memory.start(args.load)
 
         print("Model loaded -> \(modelConfiguration.id)")
 
-        let (prompt, promptTokens) = args.tokenizePropmpt(
+        let (prompt, promptTokens) = generate.tokenizePrompt(
             configuration: modelConfiguration, tokenizer: tokenizer)
 
         print("Starting generation ...")
         print(prompt, terminator: "")
 
-        var start = Date.timeIntervalSinceReferenceDate
-        var promptTime: TimeInterval = 0
+        let result = await generate.generate(
+            promptTokens: promptTokens, model: model, tokenizer: tokenizer)
 
-        // collect the tokens and keep track of how much of the string
-        // we have printed already
-        var tokens = [Int]()
-        var printed = 0
-
-        for token in TokenIterator(
-            prompt: MLXArray(promptTokens), model: model, temp: args.temperature, topP: args.topP,
-            repetitionPenalty: args.repetitionPenalty,
-            repetitionContextSize: args.repetitionContextSize)
-        {
-            if tokens.isEmpty {
-                eval(token)
-                let now = Date.timeIntervalSinceReferenceDate
-                promptTime = now - start
-                start = now
-            }
-
-            let t = token.item(Int.self)
-            if t == tokenizer.unknownTokenId || t == tokenizer.eosTokenId {
-                break
-            }
-            tokens.append(t)
-
-            // print any new parts of the string
-            let fullOutput = tokenizer.decode(tokens: tokens)
-            let emitLength = fullOutput.count - printed
-            let suffix = fullOutput.suffix(emitLength)
-            print(suffix, terminator: "")
-            fflush(stdout)
-
-            printed = fullOutput.count
-
-            if tokens.count == args.maxTokens {
-                break
-            }
-        }
-
         print()
         print("------")
-        let now = Date.timeIntervalSinceReferenceDate
-        let generateTime = now - start
+        print(result.summary())
 
-        print(
-            """
-            Prompt Tokens per second: \((Double(promptTokens.count) / promptTime).formatted())
-            Generation tokens per second: \((Double(tokens.count - 1) / generateTime).formatted())
-            """)
-
-        args.reportMemoryStatistics()
-    }
-}
-
-/// Example of an async generator.
-///
-/// Note that all of the computation is done on another thread and TokenId (Int32) are sent
-/// rather than MLXArray.
-struct AsyncGenerator: AsyncParsableCommand {
-
-    static var configuration = CommandConfiguration(
-        commandName: "async",
-        abstract: "async generator"
-    )
-
-    @OptionGroup var args: LLMArguments
-
-    @MainActor
-    mutating func run() async throws {
-        let (model, tokenizer, modelConfiguration) = try await args.load()
-
-        print("Model loaded -> \(modelConfiguration.id)")
-
-        let (prompt, promptTokens) = args.tokenizePropmpt(
-            configuration: modelConfiguration, tokenizer: tokenizer)
-
-        print("Starting generation ...")
-        print(prompt, terminator: "")
-
-        var start = Date.timeIntervalSinceReferenceDate
-        var promptTime: TimeInterval = 0
-
-        // collect the tokens and keep track of how much of the string
-        // we have printed already
-        var tokens = [Int]()
-        var printed = 0
-
-        let (task, channel) = generate(
-            prompt: MLXArray(promptTokens), model: model, temp: args.temperature, topP: args.topP,
-            repetitionPenalty: args.repetitionPenalty,
-            repetitionContextSize: args.repetitionContextSize)
-
-        for await token in channel {
-            if tokens.isEmpty {
-                let now = Date.timeIntervalSinceReferenceDate
-                promptTime = now - start
-                start = now
-            }
-
-            if token == tokenizer.unknownTokenId || token == tokenizer.eosTokenId {
-                break
-            }
-            tokens.append(token)
-
-            // print any new parts of the string
-            let fullOutput = tokenizer.decode(tokens: tokens)
-            let emitLength = fullOutput.count - printed
-            let suffix = fullOutput.suffix(emitLength)
-            print(suffix, terminator: "")
-            fflush(stdout)
-
-            printed = fullOutput.count
-
-            if tokens.count == args.maxTokens {
-                break
-            }
-        }
-
-        // tell the task to stop
-        task.cancel()
-
-        print()
-        print("------")
-        let now = Date.timeIntervalSinceReferenceDate
-        let generateTime = now - start
-
-        print(
-            """
-            Prompt Tokens per second: \((Double(promptTokens.count) / promptTime).formatted())
-            Generation tokens per second: \((Double(tokens.count - 1) / generateTime).formatted())
-            """)
-
-        args.reportMemoryStatistics()
-
-        // wait for the task to complete -- since it is running async, it might
-        // be in the middle of running the model
-        try? await Task.sleep(for: .milliseconds(500))
+        memory.reportMemoryStatistics()
     }
 }
@@ -42,11 +42,9 @@
 C3932D592B6A0BE400A81055 /* Random.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3932D582B6A0BE400A81055 /* Random.swift */; };
 C397C59C2B62C6D0004B084D /* ArgumentParser in Frameworks */ = {isa = PBXBuildFile; productRef = C397C59B2B62C6D0004B084D /* ArgumentParser */; };
 C3A8B3AC2B9283150002EFB8 /* Models.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3A8B3AB2B9283150002EFB8 /* Models.swift */; };
-C3A8B3CA2B92951E0002EFB8 /* MNISTTrainer-Info.plist in Resources */ = {isa = PBXBuildFile; fileRef = C3A8B3C22B92951E0002EFB8 /* MNISTTrainer-Info.plist */; };
 C3A8B3CB2B92951E0002EFB8 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C3A8B3C32B92951E0002EFB8 /* Assets.xcassets */; };
 C3A8B3CC2B92951E0002EFB8 /* MNISTTrainerApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3A8B3C42B92951E0002EFB8 /* MNISTTrainerApp.swift */; };
 C3A8B3CD2B92951E0002EFB8 /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C3A8B3C62B92951E0002EFB8 /* Preview Assets.xcassets */; };
-C3A8B3CE2B92951E0002EFB8 /* README.md in Resources */ = {isa = PBXBuildFile; fileRef = C3A8B3C82B92951E0002EFB8 /* README.md */; };
 C3A8B3CF2B92951E0002EFB8 /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3A8B3C92B92951E0002EFB8 /* ContentView.swift */; };
 C3A8B3D22B92A0880002EFB8 /* MLXOptimizers in Frameworks */ = {isa = PBXBuildFile; productRef = C3A8B3D12B92A0880002EFB8 /* MLXOptimizers */; };
 C3A8B3D32B92A0880002EFB8 /* MNIST.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C34E490D2B69A92900FCB841 /* MNIST.framework */; };
@@ -54,7 +52,6 @@
 C3A8B3F32B92A2A90002EFB8 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C3A8B3EC2B92A2A90002EFB8 /* Assets.xcassets */; };
 C3A8B3F42B92A2A90002EFB8 /* LLMEvalApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3A8B3ED2B92A2A90002EFB8 /* LLMEvalApp.swift */; };
 C3A8B3F52B92A2A90002EFB8 /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C3A8B3EF2B92A2A90002EFB8 /* Preview Assets.xcassets */; };
-C3A8B3F62B92A2A90002EFB8 /* README.md in Resources */ = {isa = PBXBuildFile; fileRef = C3A8B3F02B92A2A90002EFB8 /* README.md */; };
 C3A8B3F72B92A2A90002EFB8 /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3A8B3F22B92A2A90002EFB8 /* ContentView.swift */; };
 C3A8B3F82B92A3360002EFB8 /* LLM.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C38935C52B869C7A0037B833 /* LLM.framework */; };
 C3A8B3F92B92A3360002EFB8 /* LLM.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = C38935C52B869C7A0037B833 /* LLM.framework */; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; };
@@ -801,8 +798,6 @@
 isa = PBXResourcesBuildPhase;
 buildActionMask = 2147483647;
 files = (
-C3A8B3CE2B92951E0002EFB8 /* README.md in Resources */,
-C3A8B3CA2B92951E0002EFB8 /* MNISTTrainer-Info.plist in Resources */,
 C3A8B3CD2B92951E0002EFB8 /* Preview Assets.xcassets in Resources */,
 C3A8B3CB2B92951E0002EFB8 /* Assets.xcassets in Resources */,
 );
@@ -813,7 +808,6 @@
 buildActionMask = 2147483647;
 files = (
 C3A8B3F52B92A2A90002EFB8 /* Preview Assets.xcassets in Resources */,
-C3A8B3F62B92A2A90002EFB8 /* README.md in Resources */,
 C3A8B3F32B92A2A90002EFB8 /* Assets.xcassets in Resources */,
 );
 runOnlyForDeploymentPostprocessing = 0;
@@ -1,4 +1,5 @@
 {
+  "originHash" : "da53546673b6d05016b6e5640c18814c7dba5b5af8db34715afe6d633037c758",
   "pins" : [
     {
       "identity" : "gzipswift",
@@ -15,7 +16,7 @@
       "location" : "https://github.com/ml-explore/mlx-swift",
       "state" : {
         "branch" : "main",
-        "revision" : "a1c544c817d44cfdfa1a650f521066b565c2ae4f"
+        "revision" : "b4d3e4bbbe41e6dc7c46d5ba075049ae7177961b"
       }
     },
     {
@@ -82,5 +83,5 @@
       }
     }
   ],
-  "version" : 2
+  "version" : 3
 }