prepare for lora branch (#47)

- remove async llm generation -- this is just doubling our work
	- and does not match the style used in the example applications
- package generation parameters into a struct
- refactor command line arguments into distinct pieces based on their use
	- this will be reusable in the lora commands
David Koski, 2024-04-10 10:56:18 -07:00 (committed by GitHub)
parent cedf73421f
commit 96b94b0df6
5 changed files with 227 additions and 274 deletions

@@ -4,6 +4,7 @@ import AsyncAlgorithms
import Foundation
import MLX
import MLXRandom
import Tokenizers
private func topPSampling(logits: MLXArray, topP: Float, temp: Float) -> MLXArray {
var logits = logits
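
The topPSampling helper above implements nucleus sampling: keep only the smallest set of tokens whose probabilities sum to at least topP, then sample from that set. A plain-Swift illustration of the idea (hypothetical helper, not the MLX implementation):

    func topPCandidates(probs: [Float], topP: Float) -> [Int] {
        // rank token ids by probability, highest first
        let ranked = probs.indices.sorted { probs[$0] > probs[$1] }
        var kept: [Int] = []
        var mass: Float = 0
        for id in ranked {
            kept.append(id)
            mass += probs[id]
            // stop once the nucleus covers at least topP of the probability mass
            if mass >= topP { break }
        }
        return kept  // sample from these ids after renormalizing their probabilities
    }
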
@@ -28,8 +29,6 @@ private func topPSampling(logits: MLXArray, topP: Float, temp: Float) -> MLXArray {
private func applyRepetitionPenalty(
logits: MLXArray, repetitionContext: MLXArray, penalty: Float
) -> MLXArray {
var logits = logits
if repetitionContext.shape[0] > 0 {
let indices = repetitionContext
var selectedLogits = take(logits, indices, axis: -1).squeezed(axis: 0)
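
applyRepetitionPenalty follows the usual CTRL-style rule: for tokens in the recent context, positive logits are divided by the penalty and negative logits are multiplied by it, so both move toward "less likely". A plain-Swift sketch of that rule (hypothetical helper, not the MLX version):

    func applyPenalty(logits: inout [Float], recentTokens: [Int], penalty: Float) {
        for token in Set(recentTokens) {
            let value = logits[token]
            // divide positive logits, multiply negative ones
            logits[token] = value > 0 ? value / penalty : value * penalty
        }
    }
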
@@ -55,37 +54,53 @@ private func sample(logits: MLXArray, temp: Float, topP: Float = 1.0) -> MLXArra
}
}
/// Parameters for text generation, see ``TokenIterator``
public struct GenerateParameters {
/// sampling temperature
public var temperature: Float = 0.6
/// top p sampling
public var topP: Float = 0.9
/// penalty factor for repeating tokens
public var repetitionPenalty: Float = 1.0
/// number of tokens to consider for repetition penalty
public var repetitionContextSize: Int = 20
public init(
temperature: Float = 0.6, topP: Float = 0.9, repetitionPenalty: Float = 1.0,
repetitionContextSize: Int = 20
) {
self.temperature = temperature
self.topP = topP
self.repetitionPenalty = repetitionPenalty
self.repetitionContextSize = repetitionContextSize
}
}
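
With the parameters packaged into a struct, call sites can rely on the defaults and override selectively. A hypothetical call site:

    // mostly defaults, but disable top-p filtering and penalize repeats a bit harder
    let params = GenerateParameters(topP: 1.0, repetitionPenalty: 1.2)
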
/// Synchronous generator of tokens.
///
/// Port of `generate_step()` from https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py
public struct TokenIterator: Sequence, IteratorProtocol {
let model: LLMModel
let temp: Float
let topP: Float
let repetitionPenalty: Float
let repetitionContextSize: Int
let parameters: GenerateParameters
var repetitionContext: MLXArray
var y: MLXArray
var cache: [(MLXArray, MLXArray)]
var first = true
public init(
prompt: MLXArray, model: LLMModel, temp: Float = 0.0, topP: Float = 1.0,
repetitionPenalty: Float = 1.0, repetitionContextSize: Int = 20
) {
public init(prompt: MLXArray, model: LLMModel, parameters: GenerateParameters) {
self.model = model
self.temp = temp
self.topP = topP
self.parameters = parameters
self.y = prompt
self.cache = []
self.repetitionPenalty = repetitionPenalty
self.repetitionContextSize = repetitionContextSize
if repetitionContextSize > 1 {
if prompt.shape[0] <= repetitionContextSize {
if parameters.repetitionContextSize > 1 {
if prompt.shape[0] <= parameters.repetitionContextSize {
self.repetitionContext = prompt
} else {
self.repetitionContext = prompt[-repetitionContextSize ... -1]
self.repetitionContext = prompt[-parameters.repetitionContextSize ... -1]
}
} else {
self.repetitionContext = []
@@ -96,16 +111,17 @@ public struct TokenIterator: Sequence, IteratorProtocol {
var logits: MLXArray
(logits, cache) = model(expandedDimensions(y, axis: 0), cache: cache.isEmpty ? nil : cache)
logits = logits[0..., -1, 0...]
if repetitionPenalty > 1.0 {
if parameters.repetitionPenalty > 1.0 {
// apply repetition penalty
logits = applyRepetitionPenalty(
logits: logits, repetitionContext: repetitionContext, penalty: repetitionPenalty)
logits: logits, repetitionContext: repetitionContext,
penalty: parameters.repetitionPenalty)
}
y = sample(logits: logits, temp: temp, topP: topP)
y = sample(logits: logits, temp: parameters.temperature, topP: parameters.topP)
// append the current token to the repetition context and drop the oldest token if the context exceeds repetitionContextSize
if repetitionContextSize > 1 {
if parameters.repetitionContextSize > 1 {
repetitionContext = concatenated([repetitionContext, y], axis: 0)
if repetitionContext.shape[0] > repetitionContextSize {
if repetitionContext.shape[0] > parameters.repetitionContextSize {
repetitionContext = repetitionContext[1...]
}
}
@@ -114,61 +130,88 @@ public struct TokenIterator: Sequence, IteratorProtocol {
}
}
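
Because TokenIterator is a Sequence, generation is an ordinary for loop. A minimal sketch, assuming a loaded `model` and tokenized `promptTokens` from the surrounding LLM module:

    let iterator = TokenIterator(
        prompt: MLXArray(promptTokens), model: model, parameters: GenerateParameters())
    for token in iterator.prefix(128) {
        // each element is a one-token MLXArray; item() reads the token id
        print(token.item(Int.self))
    }
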
/// Async generator of tokens.
public struct GenerateResult {
/// input tokens
public let promptTokens: [Int]
/// output tokens
public let tokens: [Int]
/// output text
public let output: String
/// time to process the prompt / generate the first token
public let promptTime: TimeInterval
/// time to generate the remaining tokens
public let generateTime: TimeInterval
public var promptTokensPerSecond: Double {
Double(promptTokens.count) / promptTime
}
public var tokensPerSecond: Double {
Double(tokens.count - 1) / generateTime
}
public func summary() -> String {
"""
Prompt tokens per second: \(promptTokensPerSecond.formatted())
Generation tokens per second: \(tokensPerSecond.formatted())
"""
}
}
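
For example, 64 prompt tokens processed in 0.5 s gives promptTokensPerSecond = 128. Note that tokensPerSecond divides by tokens.count - 1: as the timing code below shows, the first generated token is evaluated while the prompt is processed, so its cost is attributed to promptTime rather than generateTime.
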
public enum GenerateDisposition {
case more
case stop
}
/// Given prompt tokens, generate text using the given model and parameters.
///
/// Port of `generate_step()` from https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py.
///
/// Note that because MLXArray is not thread safe, this evals the result and sends the token id back
/// to the caller.
/// - Parameters:
/// - promptTokens: tokenized prompt
/// - parameters: generation parameters
/// - model: model to evaluate
/// - tokenizer: tokenizer to convert tokens back into strings and recognize special tokens
/// - didGenerate: visitor for the tokens as they are generated
public func generate(
prompt: MLXArray, model: LLMModel, temp: Float = 0.0, topP: Float = 1.0,
repetitionPenalty: Float = 1.0, repetitionContextSize: Int = 20
) -> (
Task<Void, Never>, AsyncBufferSequence<AsyncChannel<Int>>
) {
let channel = AsyncChannel<Int>()
let buffer = channel.buffer(policy: .bounded(10))
promptTokens: [Int], parameters: GenerateParameters, model: LLMModel, tokenizer: Tokenizer,
didGenerate: ([Int]) async -> GenerateDisposition
) async -> GenerateResult {
var start = Date.timeIntervalSinceReferenceDate
var promptTime: TimeInterval = 0
let task = Task {
var y = prompt
var cache = [(MLXArray, MLXArray)]()
var repetitionContext: MLXArray
var tokens = [Int]()
if repetitionContextSize > 1 {
if prompt.shape[0] <= repetitionContextSize {
repetitionContext = prompt
} else {
repetitionContext = prompt[-repetitionContextSize ... -1]
}
} else {
repetitionContext = []
for token in TokenIterator(
prompt: MLXArray(promptTokens), model: model, parameters: parameters)
{
// compute the timing for the prompt
if tokens.isEmpty {
eval(token)
let now = Date.timeIntervalSinceReferenceDate
promptTime = now - start
start = now
}
while !Task.isCancelled {
var logits: MLXArray
(logits, cache) = model(
expandedDimensions(y, axis: 0), cache: cache.isEmpty ? nil : cache)
logits = logits[0..., -1, 0...]
if repetitionPenalty > 1.0 {
// apply repetition penalty
logits = applyRepetitionPenalty(
logits: logits, repetitionContext: repetitionContext, penalty: repetitionPenalty
)
}
y = sample(logits: logits, temp: temp, topP: topP)
// append the current token to the repetition context and drop the oldest token if the context exceeds repetitionContextSize
if repetitionContextSize > 1 {
repetitionContext = concatenated([repetitionContext, y], axis: 0)
if repetitionContext.shape[0] > repetitionContextSize {
repetitionContext = repetitionContext[1...]
}
}
let t = token.item(Int.self)
if t == tokenizer.unknownTokenId || t == tokenizer.eosTokenId {
break
}
eval(y)
tokens.append(t)
await channel.send(y.item(Int.self))
if await didGenerate(tokens) == .stop {
break
}
}
return (task, buffer)
let now = Date.timeIntervalSinceReferenceDate
let generateTime = now - start
return GenerateResult(
promptTokens: promptTokens, tokens: tokens,
output: tokenizer.decode(tokens: tokens),
promptTime: promptTime, generateTime: generateTime)
}
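
A hypothetical call site, assuming a loaded `model` and `tokenizer`; the trailing closure is the didGenerate visitor, and its GenerateDisposition return value decides whether generation continues:

    let result = await generate(
        promptTokens: tokenizer.encode(text: "Compare Swift and Python."),
        parameters: GenerateParameters(),
        model: model, tokenizer: tokenizer
    ) { tokens in
        // stop after 256 generated tokens, otherwise ask for more
        tokens.count >= 256 ? .stop : .more
    }
    print(result.summary())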