@@ -60,16 +60,16 @@ public struct GenerateParameters {
|
||||
public var temperature: Float = 0.6
|
||||
|
||||
/// top p sampling
|
||||
public var topP: Float = 0.9
|
||||
public var topP: Float = 1.0
|
||||
|
||||
/// penalty factor for repeating tokens
|
||||
public var repetitionPenalty: Float = 1.0
|
||||
public var repetitionPenalty: Float?
|
||||
|
||||
/// number of tokens to consider for repetition penalty
|
||||
public var repetitionContextSize: Int = 20
|
||||
|
||||
public init(
|
||||
temperature: Float = 0.6, topP: Float = 0.9, repetitionPenalty: Float = 1.0,
|
||||
temperature: Float = 0.6, topP: Float = 1.0, repetitionPenalty: Float? = nil,
|
||||
repetitionContextSize: Int = 20
|
||||
) {
|
||||
self.temperature = temperature
|
||||
@@ -111,11 +111,11 @@ public struct TokenIterator: Sequence, IteratorProtocol {
|
||||
var logits: MLXArray
|
||||
(logits, cache) = model(expandedDimensions(y, axis: 0), cache: cache.isEmpty ? nil : cache)
|
||||
logits = logits[0..., -1, 0...]
|
||||
if parameters.repetitionPenalty > 1.0 {
|
||||
if let repetitionPenalty = parameters.repetitionPenalty {
|
||||
// apply repetition penalty
|
||||
logits = applyRepetitionPenalty(
|
||||
logits: logits, repetitionContext: repetitionContext,
|
||||
penalty: parameters.repetitionPenalty)
|
||||
penalty: repetitionPenalty)
|
||||
}
|
||||
y = sample(logits: logits, temp: parameters.temperature, topP: parameters.topP)
|
||||
// append the current token to the context and check repetitionPenalty context see if need to remove the first token
|
||||
|
||||
Reference in New Issue
Block a user