Files
mlx-swift-examples/Applications/LLMEval/ContentView.swift
2024-03-01 16:10:34 -08:00

124 lines
3.0 KiB
Swift

// Copyright © 2024 Apple Inc.
import LLM
import MLX
import Metal
import SwiftUI
import Tokenizers
/// Minimal UI for `LLMEvaluator`: a prompt field, a generate button, and the streamed output.
///
/// The controls are disabled and a progress indicator is shown while `llm.running` is true.
struct ContentView: View {
    @State var prompt = "compare python and swift"
    @State var llm = LLMEvaluator()

    var body: some View {
        VStack {
            ScrollView(.vertical) {
                if llm.running {
                    ProgressView()
                }
                Text(llm.output)
            }
            HStack {
                TextField("prompt", text: $prompt)
                    .onSubmit(generate)
                    .disabled(llm.running)
                Button("generate", action: generate)
                    .disabled(llm.running)
            }
        }
        .padding()
        .task {
            // Pre-load the model when the view appears. `load()` caches the loaded model,
            // so the first generation does not have to wait for it.
            // Surface a failure instead of discarding it with `try?`, which left the UI blank.
            do {
                _ = try await llm.load()
            } catch {
                await MainActor.run {
                    llm.output = "Failed: \(error)"
                }
            }
        }
    }

    /// Starts generation for the current prompt.
    ///
    /// `onSubmit` and `Button` actions are synchronous, so the async call is wrapped in a `Task`.
    private func generate() {
        Task {
            await llm.generate(prompt: prompt)
        }
    }
}
/// Loads an LLM on demand and streams generated text into `output` for display.
///
/// `load()` and `generate(prompt:)` are not main-actor isolated. They hop to the main actor
/// only to publish UI state (`running`, `output`).
@Observable
class LLMEvaluator {
    /// True while a generation, including any model load it triggers, is in progress.
    @MainActor
    var running = false

    /// Text shown to the user: download progress, generated text, or an error message.
    /// Only written from the main actor.
    var output = ""

    /// Model configuration to download and run.
    let modelConfiguration = ModelConfiguration.phi4bit

    /// Sampling temperature passed to `TokenIterator`.
    let temperature: Float = 0.0

    /// Upper bound on the number of tokens generated per request.
    let maxTokens = 100

    /// Whether the model weights and tokenizer have been loaded yet.
    enum LoadState {
        case idle
        case loaded(LLMModel, LLM.Tokenizer)
    }

    var loadState = LoadState.idle

    /// Returns the model and tokenizer, loading them on first use and caching them afterwards.
    ///
    /// Download progress is reported through `output`.
    /// - Returns: The loaded model and its tokenizer.
    /// - Throws: Any error thrown by `LLM.load(configuration:)`.
    /// - Note: NOTE(review): two overlapping callers that both observe `.idle` will each start
    ///   a load. Callers should avoid overlapping calls.
    func load() async throws -> (LLMModel, LLM.Tokenizer) {
        switch loadState {
        case .idle:
            let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration) {
                [modelConfiguration] progress in
                // Publish asynchronously. `DispatchQueue.main.sync` blocks the calling thread
                // until the main thread runs, and deadlocks if this handler runs on main.
                let percent = Int(progress.fractionCompleted * 100)
                Task { @MainActor in
                    self.output = "Downloading \(modelConfiguration.id): \(percent)%"
                }
            }
            loadState = .loaded(model, tokenizer)
            return (model, tokenizer)
        case .loaded(let model, let tokenizer):
            return (model, tokenizer)
        }
    }

    /// Generates a completion for `prompt` and streams the decoded text into `output`.
    ///
    /// Returns immediately if a generation is already running. Generation stops at the
    /// tokenizer's unknown or end-of-sequence token, or after `maxTokens` tokens.
    /// Errors are reported through `output` rather than thrown.
    /// - Parameter prompt: Raw user prompt. It is formatted with `modelConfiguration.prepare(prompt:)`.
    func generate(prompt: String) async {
        // Claim the run on the main actor *before* loading. This keeps the UI disabled during
        // a possibly long model download and stops a second tap from starting an overlapping run.
        let didStart = await MainActor.run { () -> Bool in
            guard !running else { return false }
            running = true
            return true
        }
        guard didStart else { return }

        do {
            let (model, tokenizer) = try await load()
            await MainActor.run {
                self.output = ""
            }

            let prompt = modelConfiguration.prepare(prompt: prompt)
            let promptTokens = MLXArray(tokenizer.encode(text: prompt))

            var outputTokens = [Int]()
            for token in TokenIterator(prompt: promptTokens, model: model, temp: temperature) {
                let tokenId = token.item(Int.self)

                // Stop at end-of-sequence as well as unknown tokens. Otherwise the model keeps
                // sampling past the end of its answer until `maxTokens` is reached.
                if tokenId == tokenizer.unknownTokenId || tokenId == tokenizer.eosTokenId {
                    break
                }

                outputTokens.append(tokenId)
                // Re-decode the whole sequence each step rather than appending per-token text.
                // This is presumably so that text spanning several tokens renders correctly.
                let text = tokenizer.decode(tokens: outputTokens)
                await MainActor.run {
                    self.output = text
                }

                if outputTokens.count >= maxTokens {
                    break
                }
            }

            await MainActor.run {
                running = false
            }
        } catch {
            await MainActor.run {
                running = false
                output = "Failed: \(error)"
            }
        }
    }
}