handle partially quantized models (#76)
* handle partially quantized models - fix for #53 #71 #69 #74
* added a default prompt of an appropriate form so the models can be exercised in testing
* while working on the model configuration, also added additional stop tokens (#74)
* fixed the repetitionPenalty code (#71)
@@ -12,7 +12,7 @@ private func topPSampling(logits: MLXArray, topP: Float, temp: Float) -> MLXArray {
         logits = logits.asType(.float32)
     }

-    let probs = softMax(logits / temp, axis: -1)
+    let probs = softmax(logits / temp, axis: -1)
     let sortedIndices = argSort(probs, axis: -1)

     // probs shape is [B,V] and after take it will be [1, B, V], so we squeeze it back to [B, V]
@@ -31,7 +31,7 @@ private func applyRepetitionPenalty(
 ) -> MLXArray {
     if repetitionContext.shape[0] > 0 {
         let indices = repetitionContext
-        var selectedLogits = take(logits, indices, axis: -1).squeezed(axis: 0)
+        var selectedLogits = logits[0..., indices]

         selectedLogits = MLX.where(
             selectedLogits .< 0, selectedLogits * penalty, selectedLogits / penalty)
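The change above replaces a take/squeeze round trip with direct advanced indexing when applying the repetition penalty (#71). Below is a minimal sketch of the corrected function, assuming the parameter names visible in the diff (logits, repetitionContext, penalty) and paraphrasing the parts of the body the hunk does not show:

import MLX

private func applyRepetitionPenalty(
    logits: MLXArray, repetitionContext: MLXArray, penalty: Float
) -> MLXArray {
    if repetitionContext.shape[0] > 0 {
        let indices = repetitionContext
        // advanced indexing yields the selected logits directly, no squeeze needed
        var selectedLogits = logits[0..., indices]

        // CTRL-style repetition penalty: push negative logits further down and
        // shrink positive ones, making recently seen tokens less likely
        selectedLogits = MLX.where(
            selectedLogits .< 0, selectedLogits * penalty, selectedLogits / penalty)

        // write the penalized values back at the same positions
        logits[0..., indices] = selectedLogits
    }
    return logits
}

Since MLXArray has reference semantics, the subscript assignment mutates the array that was passed in; returning it keeps the call site uniform.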
@@ -100,7 +100,7 @@ public struct TokenIterator: Sequence, IteratorProtocol {
             if prompt.shape[0] <= parameters.repetitionContextSize {
                 self.repetitionContext = prompt
             } else {
-                self.repetitionContext = prompt[-parameters.repetitionContextSize ... -1]
+                self.repetitionContext = prompt[(-parameters.repetitionContextSize)...]
            }
        } else {
            self.repetitionContext = []
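MLX arrays accept Python-style negative indices, so an open-ended range starting at a negative offset keeps exactly the trailing elements. A small illustration with hypothetical values (the same idiom appears again in the next hunk):

import MLX

let repetitionContextSize = 4
let prompt = MLXArray([10, 11, 12, 13, 14, 15])

// (-n)... selects the last n elements: [12, 13, 14, 15]
let context = prompt[(-repetitionContextSize)...]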
@@ -120,9 +120,8 @@ public struct TokenIterator: Sequence, IteratorProtocol {
             y = sample(logits: logits, temp: parameters.temperature, topP: parameters.topP)
             // append the current token to the context and check repetitionPenalty context see if need to remove the first token
             if parameters.repetitionContextSize > 1 {
                 repetitionContext = concatenated([repetitionContext, y], axis: 0)
                 if repetitionContext.shape[0] > parameters.repetitionContextSize {
-                    repetitionContext = repetitionContext[1...]
+                    repetitionContext = repetitionContext[(-parameters.repetitionContextSize)...]
                 }
             }
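In the generation loop the same trailing-slice idiom maintains the sliding window: after appending the newly sampled token, re-slicing from the end restores the size invariant however far the window has overflowed, whereas the old [1...] dropped exactly one token from the front. A sketch with hypothetical values:

import MLX

var repetitionContext = MLXArray([7, 8, 9, 10])  // window already at capacity
let repetitionContextSize = 4
let y = MLXArray([11])  // newly sampled token

repetitionContext = concatenated([repetitionContext, y], axis: 0)
if repetitionContext.shape[0] > repetitionContextSize {
    // keep only the trailing repetitionContextSize tokens
    repetitionContext = repetitionContext[(-repetitionContextSize)...]
}
// repetitionContext is now [8, 9, 10, 11]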
@@ -174,14 +173,31 @@ public enum GenerateDisposition {
 /// - parameters: generation parameters
 /// - model: model to evaluate
 /// - tokenizer: tokenizer to convert tokens back into strings and recognize special tokens
+/// - configuration: the model configuration
 /// - didGenerate: visitor for the tokens as they are generated
 public func generate(
     promptTokens: [Int], parameters: GenerateParameters, model: LLMModel, tokenizer: Tokenizer,
+    extraEOSTokens: Set<String>? = nil,
     didGenerate: ([Int]) async -> GenerateDisposition
 ) async -> GenerateResult {
     var start = Date.timeIntervalSinceReferenceDate
     var promptTime: TimeInterval = 0

+    // build a set of additional stop tokens
+    let additionalEOSTokenIds = Set(
+        (extraEOSTokens ?? [])
+            .map {
+                tokenizer.encode(text: $0)
+            }
+            .filter {
+                // discard anything that is not a single token. sometimes
+                // the tokenizer will insert a <s> token, so accept that too
+                $0.count == 1 || ($0.count == 2 && $0[0] == 1)
+            }
+            .map {
+                $0.last!
+            })
+
     var tokens = [Int]()

     for token in TokenIterator(
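The new extraEOSTokens parameter lets callers register model-specific stop strings (#74): each string is encoded, anything that does not round-trip to a single token is discarded (allowing for a leading <s>, id 1, that some tokenizers insert), and the resulting ids are checked alongside the built-in EOS token in the loop below. A hypothetical call site, assuming GenerateParameters has a default initializer and that GenerateDisposition offers .more and .stop cases:

let result = await generate(
    promptTokens: promptTokens,
    parameters: GenerateParameters(),
    model: model,
    tokenizer: tokenizer,
    extraEOSTokens: ["<|im_end|>"]  // chat-style end marker; model dependent
) { tokens in
    // stop after a token budget in case the model never emits a stop token
    tokens.count < 256 ? .more : .stop
}

Note that multi-token stop strings are silently dropped by the filter, so only markers that encode to a single token will actually terminate generation.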
@@ -196,7 +212,9 @@ public func generate(
         }

         let t = token.item(Int.self)
-        if t == tokenizer.unknownTokenId || t == tokenizer.eosTokenId {
+        if t == tokenizer.unknownTokenId || t == tokenizer.eosTokenId
+            || additionalEOSTokenIds.contains(t)
+        {
             break
         }