chore(llm-tool): add the top_p option in the llm-tool (#41)

* chore: add top p option in llm-tool
* chore: wire up the top p with async generate
This commit is contained in:
Anchen
2024-04-04 01:54:54 +11:00
committed by GitHub
parent b3eb428c60
commit 2d0fdfe3a9
4 changed files with 13 additions and 7 deletions

View File

@@ -71,7 +71,7 @@ public struct TokenIterator: Sequence, IteratorProtocol {
///
/// Note that because MLXArray is not thread safe this eval's the result and sends the TokenId back
/// to the caller.
public func generate(prompt: MLXArray, model: LLMModel, temp: Float = 0.0) -> (
public func generate(prompt: MLXArray, model: LLMModel, temp: Float = 0.0, topP: Float = 1.0) -> (
Task<Void, Never>, AsyncBufferSequence<AsyncChannel<Int>>
) {
let channel = AsyncChannel<Int>()
@@ -85,7 +85,7 @@ public func generate(prompt: MLXArray, model: LLMModel, temp: Float = 0.0) -> (
var logits: MLXArray
(logits, cache) = model(
expandedDimensions(y, axis: 0), cache: cache.isEmpty ? nil : cache)
y = sample(logits: logits[-1, axis: 1], temp: temp)
y = sample(logits: logits[-1, axis: 1], temp: temp, topP: topP)
eval(y)
await channel.send(y.item(Int.self))