chore: add repetition_penalty example (#45)
This commit is contained in:
@@ -32,6 +32,12 @@ struct LLMArguments: ParsableArguments {
|
||||
@Option(name: .shortAndLong, help: "The top p sampling")
|
||||
var topP: Float = 0.9
|
||||
|
||||
@Option(name: .shortAndLong, help: "The penalty factor for repeating tokens")
|
||||
var repetitionPenalty: Float = 1.0
|
||||
|
||||
@Option(name: .shortAndLong, help: "The number of tokens to consider for repetition penalty")
|
||||
var repetitionContextSize: Int = 20
|
||||
|
||||
@Option(name: .long, help: "The PRNG seed")
|
||||
var seed: UInt64 = 0
|
||||
|
||||
@@ -130,7 +136,9 @@ struct SyncGenerator: AsyncParsableCommand {
|
||||
var printed = 0
|
||||
|
||||
for token in TokenIterator(
|
||||
prompt: MLXArray(promptTokens), model: model, temp: args.temperature, topP: args.topP)
|
||||
prompt: MLXArray(promptTokens), model: model, temp: args.temperature, topP: args.topP,
|
||||
repetitionPenalty: args.repetitionPenalty,
|
||||
repetitionContextSize: args.repetitionContextSize)
|
||||
{
|
||||
if tokens.isEmpty {
|
||||
eval(token)
|
||||
@@ -208,7 +216,9 @@ struct AsyncGenerator: AsyncParsableCommand {
|
||||
var printed = 0
|
||||
|
||||
let (task, channel) = generate(
|
||||
prompt: MLXArray(promptTokens), model: model, temp: args.temperature, topP: args.topP)
|
||||
prompt: MLXArray(promptTokens), model: model, temp: args.temperature, topP: args.topP,
|
||||
repetitionPenalty: args.repetitionPenalty,
|
||||
repetitionContextSize: args.repetitionContextSize)
|
||||
|
||||
for await token in channel {
|
||||
if tokens.isEmpty {
|
||||
|
||||
Reference in New Issue
Block a user