llm improvements

- document the tokenizer used (https://github.com/huggingface/swift-transformers)
- provide a hook for tokenizer configuration and prompt augmentation; this isn't as rich as the python equivalents, but it helps a little
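The change routes model loading through a new ModelConfiguration value whose prepare(prompt:) method is the prompt-augmentation hook. As a minimal sketch of what such a type could look like (only configuration(id:) and prepare(prompt:) are visible in the diff below; the preparePrompt closure, the registry, and the [INST] template are illustrative assumptions, not the library's actual internals):

// Minimal sketch of a configuration type with a prompt-augmentation hook.
// Only configuration(id:) and prepare(prompt:) appear in the diff; the
// rest (preparePrompt, registry, the [INST] template) is hypothetical.
struct ModelConfiguration {
    let id: String

    // Hypothetical hook: wraps the raw prompt in whatever template the
    // model expects, e.g. an instruct model adding [INST] ... [/INST].
    var preparePrompt: ((String) -> String)? = nil

    func prepare(prompt: String) -> String {
        preparePrompt?(prompt) ?? prompt
    }

    // Hypothetical registry lookup; unknown ids fall back to a
    // pass-through configuration.
    static func configuration(id: String) -> ModelConfiguration {
        registry[id] ?? ModelConfiguration(id: id)
    }

    private static let registry: [String: ModelConfiguration] = [
        "example/instruct-model": ModelConfiguration(  // hypothetical id
            id: "example/instruct-model",
            preparePrompt: { "[INST] \($0) [/INST]" }
        )
    ]
}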
@@ -25,7 +25,7 @@ struct SyncGenerator: AsyncParsableCommand {
     var model: String = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"
 
     @Option(name: .shortAndLong, help: "The message to be processed by the model")
-    var prompt = "compare swift and python"
+    var prompt = "compare python and swift"
 
     @Option(name: .shortAndLong, help: "Maximum number of tokens to generate")
     var maxTokens = 100
@@ -40,22 +40,24 @@ struct SyncGenerator: AsyncParsableCommand {
     func run() async throws {
         MLXRandom.seed(seed)
 
-        let (model, tokenizer) = try await load(name: model)
+        let modelConfiguration = ModelConfiguration.configuration(id: model)
+        let (model, tokenizer) = try await load(configuration: modelConfiguration)
+
+        let prompt = modelConfiguration.prepare(prompt: self.prompt)
+        let promptTokens = tokenizer.encode(text: prompt)
 
         print("Starting generation ...")
-        print(prompt, terminator: "")
+        print(self.prompt, terminator: "")
 
         var start = Date.timeIntervalSinceReferenceDate
         var promptTime: TimeInterval = 0
 
-        let prompt = MLXArray(tokenizer.encode(text: prompt))
-
         // collect the tokens and keep track of how much of the string
         // we have printed already
         var tokens = [Int]()
         var printed = 0
 
-        for token in TokenIterator(prompt: prompt, model: model, temp: temperature) {
+        for token in TokenIterator(prompt: MLXArray(promptTokens), model: model, temp: temperature) {
             if tokens.isEmpty {
                 eval(token)
                 let now = Date.timeIntervalSinceReferenceDate
@@ -90,7 +92,7 @@ struct SyncGenerator: AsyncParsableCommand {
 
         print(
             """
-            Prompt Tokens per second: \((Double(prompt.size) / promptTime).formatted())
+            Prompt Tokens per second: \((Double(promptTokens.count) / promptTime).formatted())
             Generation tokens per second: \((Double(tokens.count - 1) / generateTime).formatted())
             """)
     }
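Taken together, the SyncGenerator changes reduce the synchronous path to roughly the following flow. This is a condensed sketch assuming only the signatures visible in the diff; the temperature, the token limit, and the item(Int.self) / decode(tokens:) calls are illustrative stand-ins for the detokenization details:

// Sketch of the sync path after this change (values illustrative).
let configuration = ModelConfiguration.configuration(
    id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx")
let (model, tokenizer) = try await load(configuration: configuration)

// Augment the raw prompt via the configuration hook, then tokenize.
let prompt = configuration.prepare(prompt: "compare python and swift")
let promptTokens = tokenizer.encode(text: prompt)

// Generate one token per iteration, up to an assumed limit.
var tokens = [Int]()
for token in TokenIterator(prompt: MLXArray(promptTokens), model: model, temp: 0.6) {
    tokens.append(token.item(Int.self))
    if tokens.count >= 100 { break }
}
print(tokenizer.decode(tokens: tokens))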
@@ -111,7 +113,7 @@ struct AsyncGenerator: AsyncParsableCommand {
     var model: String = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"
 
     @Option(name: .shortAndLong, help: "The message to be processed by the model")
-    var prompt = "compare swift and python"
+    var prompt = "compare python and swift"
 
     @Option(name: .shortAndLong, help: "Maximum number of tokens to generate")
     var maxTokens = 100
@@ -126,22 +128,24 @@ struct AsyncGenerator: AsyncParsableCommand {
     func run() async throws {
         MLXRandom.seed(seed)
 
-        let (model, tokenizer) = try await load(name: model)
+        let modelConfiguration = ModelConfiguration.configuration(id: model)
+        let (model, tokenizer) = try await load(configuration: modelConfiguration)
+
+        let prompt = modelConfiguration.prepare(prompt: self.prompt)
+        let promptTokens = tokenizer.encode(text: prompt)
 
         print("Starting generation ...")
-        print(prompt, terminator: "")
+        print(self.prompt, terminator: "")
 
         var start = Date.timeIntervalSinceReferenceDate
         var promptTime: TimeInterval = 0
 
-        let prompt = MLXArray(tokenizer.encode(text: prompt))
-
        // collect the tokens and keep track of how much of the string
         // we have printed already
         var tokens = [Int]()
         var printed = 0
 
-        let (task, channel) = generate(prompt: prompt, model: model, temp: temperature)
+        let (task, channel) = generate(prompt: MLXArray(promptTokens), model: model, temp: temperature)
 
         for await token in channel {
             if tokens.isEmpty {
@@ -179,7 +183,7 @@ struct AsyncGenerator: AsyncParsableCommand {
 
         print(
             """
-            Prompt Tokens per second: \((Double(prompt.size) / promptTime).formatted())
+            Prompt Tokens per second: \((Double(promptTokens.count) / promptTime).formatted())
             Generation tokens per second: \((Double(tokens.count - 1) / generateTime).formatted())
             """)
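The AsyncGenerator path is identical up to tokenization; generation then arrives over an async channel instead of a synchronous iterator, roughly as follows (same assumptions as the sync sketch above; the task.cancel() call is an assumption about how the producer task is stopped):

// Sketch of the async path after this change (values illustrative).
let (task, channel) = generate(
    prompt: MLXArray(promptTokens), model: model, temp: 0.6)

var tokens = [Int]()
for await token in channel {
    tokens.append(token.item(Int.self))
    if tokens.count >= 100 { break }
}

// Assumed: stop the producer task once we stop consuming the channel.
task.cancel()
print(tokenizer.decode(tokens: tokens))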