diff --git a/Libraries/LLM/Evaluate.swift b/Libraries/LLM/Evaluate.swift index 7420b26..bfc04b4 100644 --- a/Libraries/LLM/Evaluate.swift +++ b/Libraries/LLM/Evaluate.swift @@ -71,7 +71,7 @@ public struct TokenIterator: Sequence, IteratorProtocol { /// /// Note that because MLXArray is not thread safe this eval's the result and sends the TokenId back /// to the caller. -public func generate(prompt: MLXArray, model: LLMModel, temp: Float = 0.0) -> ( +public func generate(prompt: MLXArray, model: LLMModel, temp: Float = 0.0, topP: Float = 1.0) -> ( Task, AsyncBufferSequence> ) { let channel = AsyncChannel() @@ -85,7 +85,7 @@ public func generate(prompt: MLXArray, model: LLMModel, temp: Float = 0.0) -> ( var logits: MLXArray (logits, cache) = model( expandedDimensions(y, axis: 0), cache: cache.isEmpty ? nil : cache) - y = sample(logits: logits[-1, axis: 1], temp: temp) + y = sample(logits: logits[-1, axis: 1], temp: temp, topP: topP) eval(y) await channel.send(y.item(Int.self)) diff --git a/Tools/llm-tool/LLMTool.swift b/Tools/llm-tool/LLMTool.swift index 98be194..e2d2133 100644 --- a/Tools/llm-tool/LLMTool.swift +++ b/Tools/llm-tool/LLMTool.swift @@ -29,6 +29,9 @@ struct LLMArguments: ParsableArguments { @Option(name: .shortAndLong, help: "The sampling temperature") var temperature: Float = 0.6 + @Option(name: .shortAndLong, help: "The top p sampling") + var topP: Float = 0.9 + @Option(name: .long, help: "The PRNG seed") var seed: UInt64 = 0 @@ -127,7 +130,7 @@ struct SyncGenerator: AsyncParsableCommand { var printed = 0 for token in TokenIterator( - prompt: MLXArray(promptTokens), model: model, temp: args.temperature) + prompt: MLXArray(promptTokens), model: model, temp: args.temperature, topP: args.topP) { if tokens.isEmpty { eval(token) @@ -205,7 +208,7 @@ struct AsyncGenerator: AsyncParsableCommand { var printed = 0 let (task, channel) = generate( - prompt: MLXArray(promptTokens), model: model, temp: args.temperature) + prompt: MLXArray(promptTokens), model: model, temp: args.temperature, topP: args.topP) for await token in channel { if tokens.isEmpty { diff --git a/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved index fcc58d5..3628bc8 100644 --- a/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved +++ b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved @@ -1,5 +1,4 @@ { - "originHash" : "da53546673b6d05016b6e5640c18814c7dba5b5af8db34715afe6d633037c758", "pins" : [ { "identity" : "gzipswift", @@ -79,9 +78,9 @@ "location" : "https://github.com/huggingface/swift-transformers", "state" : { "branch" : "main", - "revision" : "3bd02269b7797ade67c15679a575cd5c6f203ce6" + "revision" : "9d82e00af680253499f1a9372abb2552a73527fb" } } ], - "version" : 3 + "version" : 2 } diff --git a/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme b/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme index 5dcb893..402aa35 100644 --- a/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme +++ b/mlx-swift-examples.xcodeproj/xcshareddata/xcschemes/llm-tool.xcscheme @@ -55,6 +55,10 @@ argument = "--model mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX" isEnabled = "NO"> + +