* phi3

Co-authored-by: David Koski <dkoski@apple.com>
This commit is contained in:
Awni Hannun
2024-04-24 09:31:01 -07:00
committed by GitHub
parent 6c0b66f90a
commit b951b78eb2
7 changed files with 284 additions and 8 deletions

View File

@@ -60,16 +60,16 @@ public struct GenerateParameters {
public var temperature: Float = 0.6
/// top p sampling
public var topP: Float = 0.9
public var topP: Float = 1.0
/// penalty factor for repeating tokens
public var repetitionPenalty: Float = 1.0
public var repetitionPenalty: Float?
/// number of tokens to consider for repetition penalty
public var repetitionContextSize: Int = 20
public init(
temperature: Float = 0.6, topP: Float = 0.9, repetitionPenalty: Float = 1.0,
temperature: Float = 0.6, topP: Float = 1.0, repetitionPenalty: Float? = nil,
repetitionContextSize: Int = 20
) {
self.temperature = temperature
@@ -111,11 +111,11 @@ public struct TokenIterator: Sequence, IteratorProtocol {
var logits: MLXArray
(logits, cache) = model(expandedDimensions(y, axis: 0), cache: cache.isEmpty ? nil : cache)
logits = logits[0..., -1, 0...]
if parameters.repetitionPenalty > 1.0 {
if let repetitionPenalty = parameters.repetitionPenalty {
// apply repetition penalty
logits = applyRepetitionPenalty(
logits: logits, repetitionContext: repetitionContext,
penalty: parameters.repetitionPenalty)
penalty: repetitionPenalty)
}
y = sample(logits: logits, temp: parameters.temperature, topP: parameters.topP)
// append the current token to the context and check repetitionPenalty context see if need to remove the first token