handle partially quantized models (#76)

* handle partially quantized models

- fixes #53, #71, #69, #74
- in order to test the models:
    - added a default prompt of an appropriate form
    - while working on the model configuration, also added additional stop tokens (#74); see the usage sketch below
- fixed the repetitionPenalty code (#71)
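
For reference, a call through the updated API might look like the sketch below. This is illustrative only: the prompt, sampling values, stop string, and token cap are made up, and it assumes GenerateParameters takes temperature/topP arguments and GenerateDisposition has .stop/.more cases.

// Hypothetical call site; model and tokenizer loading elided.
let result = await generate(
    promptTokens: tokenizer.encode(text: "why is the sky blue?"),
    parameters: GenerateParameters(temperature: 0.6, topP: 0.9),
    model: model,
    tokenizer: tokenizer,
    extraEOSTokens: ["<|end|>"]
) { tokens in
    // stop after 256 tokens; otherwise keep generating
    tokens.count >= 256 ? .stop : .more
}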
Author: David Koski
Date: 2024-05-28 16:35:11 -07:00
Committed by: GitHub
Parent: 65f4968e5f
Commit: 9d74afd119
12 changed files with 139 additions and 67 deletions


@@ -12,7 +12,7 @@ private func topPSampling(logits: MLXArray, topP: Float, temp: Float) -> MLXArray
         logits = logits.asType(.float32)
     }
-    let probs = softMax(logits / temp, axis: -1)
+    let probs = softmax(logits / temp, axis: -1)
     let sortedIndices = argSort(probs, axis: -1)
     // probs shape is [B,V] and after take it will be [1, B, V], so we squeeze it back to [B, V]
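
Rename aside, this is the nucleus (top-p) sampling step: sample only from the smallest set of highest-probability tokens whose cumulative mass reaches topP. A minimal sketch of the idea on plain Swift arrays, not the MLX implementation above:

// probs sorted descending; keep indices until the cumulative mass reaches topP
let probs: [Float] = [0.5, 0.3, 0.15, 0.05]
let topP: Float = 0.9
var cumulative: Float = 0
var kept: [Int] = []
for (index, p) in probs.enumerated() {
    cumulative += p
    kept.append(index)
    if cumulative >= topP { break }
}
// kept == [0, 1, 2]; token 3 is excluded from sampling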
@@ -31,7 +31,7 @@ private func applyRepetitionPenalty(
 ) -> MLXArray {
     if repetitionContext.shape[0] > 0 {
         let indices = repetitionContext
-        var selectedLogits = take(logits, indices, axis: -1).squeezed(axis: 0)
+        var selectedLogits = logits[0..., indices]
         selectedLogits = MLX.where(
             selectedLogits .< 0, selectedLogits * penalty, selectedLogits / penalty)
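
The subscript form replaces the take/squeeze call, and the MLX.where line is the CTRL-style repetition penalty: the logit of an already-seen token is made less likely in both sign cases. The rule on plain Swift floats, for illustration:

// penalty > 1 always pushes a seen token toward "less likely"
func penalize(_ logit: Float, penalty: Float) -> Float {
    logit < 0 ? logit * penalty : logit / penalty
}
let neg = penalize(-2.0, penalty: 1.3)  // -2.6  (more negative)
let pos = penalize(2.0, penalty: 1.3)   // ~1.54 (less positive)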
@@ -100,7 +100,7 @@ public struct TokenIterator: Sequence, IteratorProtocol {
             if prompt.shape[0] <= parameters.repetitionContextSize {
                 self.repetitionContext = prompt
             } else {
-                self.repetitionContext = prompt[-parameters.repetitionContextSize ... -1]
+                self.repetitionContext = prompt[(-parameters.repetitionContextSize)...]
             }
         } else {
             self.repetitionContext = []
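
The new slice uses MLX's negative indexing to keep the trailing window of the prompt. The equivalent trim on a plain Swift array would be a suffix:

// keep the last contextSize tokens, like prompt[(-contextSize)...]
let promptTokens = [11, 12, 13, 14, 15]
let contextSize = 3
let window = Array(promptTokens.suffix(contextSize))  // [13, 14, 15]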
@@ -120,9 +120,8 @@ public struct TokenIterator: Sequence, IteratorProtocol {
         y = sample(logits: logits, temp: parameters.temperature, topP: parameters.topP)
         // append the current token to the context and check repetitionPenalty context see if need to remove the first token
         if parameters.repetitionContextSize > 1 {
             repetitionContext = concatenated([repetitionContext, y], axis: 0)
             if repetitionContext.shape[0] > parameters.repetitionContextSize {
-                repetitionContext = repetitionContext[1...]
+                repetitionContext = repetitionContext[(-parameters.repetitionContextSize)...]
             }
         }
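
This is the #71 fix: the old [1...] slice dropped a single leading token per step, which happens to bound the window only when exactly one token is appended at a time; the new slice trims directly to the last repetitionContextSize tokens, so the bound holds unconditionally. Compared on a plain array:

let context = [1, 2, 3, 4, 5, 6]           // overgrown; limit is 4
let dropOne = Array(context.dropFirst())   // [2, 3, 4, 5, 6] - still too long
let window = Array(context.suffix(4))      // [3, 4, 5, 6]    - bounded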
@@ -174,14 +173,31 @@ public enum GenerateDisposition {
 /// - parameters: generation parameters
 /// - model: model to evaluate
 /// - tokenizer: tokenizer to convert tokens back into strings and recognize special tokens
+/// - configuration: the model configuration
 /// - didGenerate: visitor for the tokens as they are generated
 public func generate(
     promptTokens: [Int], parameters: GenerateParameters, model: LLMModel, tokenizer: Tokenizer,
+    extraEOSTokens: Set<String>? = nil,
     didGenerate: ([Int]) async -> GenerateDisposition
 ) async -> GenerateResult {
     var start = Date.timeIntervalSinceReferenceDate
     var promptTime: TimeInterval = 0
+    // build a set of additional stop tokens
+    let additionalEOSTokenIds = Set(
+        (extraEOSTokens ?? [])
+            .map {
+                tokenizer.encode(text: $0)
+            }
+            .filter {
+                // discard anything that is not a single token. sometimes
+                // the tokenizer will insert a <s> token, so accept that too
+                $0.count == 1 || ($0.count == 2 && $0[0] == 1)
+            }
+            .map {
+                $0.last!
+            })
     var tokens = [Int]()
     for token in TokenIterator(
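
The filter keeps only stop strings that reduce to a single token id: either a one-token encoding, or a two-token encoding whose first element is the <s> token (id 1) that some tokenizers prepend; the final map then takes the usable id. With hypothetical encodings:

// Hypothetical tokenizer output for three candidate stop strings
let encodings = [[32007], [1, 2], [345, 678]]
let ids = Set(
    encodings
        .filter { $0.count == 1 || ($0.count == 2 && $0[0] == 1) }
        .map { $0.last! })
// ids contains 32007 and 2; [345, 678] is discarded because a
// multi-token string can never equal one generated token id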
@@ -196,7 +212,9 @@ public func generate(
         }
         let t = token.item(Int.self)
-        if t == tokenizer.unknownTokenId || t == tokenizer.eosTokenId {
+        if t == tokenizer.unknownTokenId || t == tokenizer.eosTokenId
+            || additionalEOSTokenIds.contains(t)
+        {
             break
         }