From 3314e20a2446d0b6a79dc169cdd16bf51493e819 Mon Sep 17 00:00:00 2001 From: Anchen Date: Wed, 27 Mar 2024 06:44:13 +1100 Subject: [PATCH] chore: add top_p sampling example (#34) --- Libraries/LLM/Evaluate.swift | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/Libraries/LLM/Evaluate.swift b/Libraries/LLM/Evaluate.swift index be2c51a..7420b26 100644 --- a/Libraries/LLM/Evaluate.swift +++ b/Libraries/LLM/Evaluate.swift @@ -5,10 +5,33 @@ import Foundation import MLX import MLXRandom -private func sample(logits: MLXArray, temp: Float) -> MLXArray { +private func topPSampling(logits: MLXArray, topP: Float, temp: Float) -> MLXArray { + var logits = logits + if logits.dtype == .bfloat16 { + logits = logits.asType(.float32) + } + + let probs = softMax(logits / temp, axis: -1) + let sortedIndices = argSort(probs, axis: -1) + + // probs shape is [B,V] and after take it will be [1, B, V], so we squeeze it back to [B, V] + let sortedProbs = take(probs, sortedIndices, axis: -1).squeezed(axis: 0) + + let cumulativeProbs = cumsum(sortedProbs, axis: -1) + + let topProbs = MLX.where(cumulativeProbs .> (1 - topP), sortedProbs, zeros(like: sortedProbs)) + + let sortedToken = categorical(log(topProbs)) + return sortedIndices.squeezed(axis: 0)[sortedToken] +} + +private func sample(logits: MLXArray, temp: Float, topP: Float = 1.0) -> MLXArray { if temp == 0 { return argMax(logits, axis: -1) } else { + if topP > 0 && topP < 1 { + return topPSampling(logits: logits, topP: topP, temp: temp) + } return categorical(logits * (1 / temp)) } } @@ -19,15 +42,16 @@ private func sample(logits: MLXArray, temp: Float) -> MLXArray { public struct TokenIterator: Sequence, IteratorProtocol { let model: LLMModel let temp: Float - + let topP: Float var y: MLXArray var cache: [(MLXArray, MLXArray)] var first = true - public init(prompt: MLXArray, model: LLMModel, temp: Float = 0.0) { + public init(prompt: MLXArray, model: LLMModel, temp: Float = 0.0, topP: Float = 1.0) { self.model = model self.temp = temp + self.topP = topP self.y = prompt self.cache = [] } @@ -35,7 +59,7 @@ public struct TokenIterator: Sequence, IteratorProtocol { mutating public func next() -> MLXArray? { var logits: MLXArray (logits, cache) = model(expandedDimensions(y, axis: 0), cache: cache.isEmpty ? nil : cache) - y = sample(logits: logits[-1, axis: 1], temp: temp) + y = sample(logits: logits[-1, axis: 1], temp: temp, topP: topP) return y }