chore: add repetition_penalty example (#45)

This commit is contained in:
Anchen
2024-04-05 09:15:50 +11:00
committed by GitHub
parent 2d0fdfe3a9
commit c27208812d
4 changed files with 102 additions and 10 deletions

View File

@@ -32,6 +32,12 @@ struct LLMArguments: ParsableArguments {
@Option(name: .shortAndLong, help: "The top p sampling")
var topP: Float = 0.9
@Option(name: .shortAndLong, help: "The penalty factor for repeating tokens")
var repetitionPenalty: Float = 1.0
@Option(name: .shortAndLong, help: "The number of tokens to consider for repetition penalty")
var repetitionContextSize: Int = 20
@Option(name: .long, help: "The PRNG seed")
var seed: UInt64 = 0
@@ -130,7 +136,9 @@ struct SyncGenerator: AsyncParsableCommand {
var printed = 0
for token in TokenIterator(
prompt: MLXArray(promptTokens), model: model, temp: args.temperature, topP: args.topP)
prompt: MLXArray(promptTokens), model: model, temp: args.temperature, topP: args.topP,
repetitionPenalty: args.repetitionPenalty,
repetitionContextSize: args.repetitionContextSize)
{
if tokens.isEmpty {
eval(token)
@@ -208,7 +216,9 @@ struct AsyncGenerator: AsyncParsableCommand {
var printed = 0
let (task, channel) = generate(
prompt: MLXArray(promptTokens), model: model, temp: args.temperature, topP: args.topP)
prompt: MLXArray(promptTokens), model: model, temp: args.temperature, topP: args.topP,
repetitionPenalty: args.repetitionPenalty,
repetitionContextSize: args.repetitionContextSize)
for await token in channel {
if tokens.isEmpty {