Files
mlx-swift-examples/Applications/LoRATrainingExample/ContentView.swift
David Koski 9d74afd119 handle partially quantized models (#76)
* handle partially quantized models

- fix for #53 #71 #69 #74
- in order to test the models
	- I added a default prompt of an appropriate form
	- while working on the model configuration also added additional stop tokens (#74)
- fixed the repetitionPenalty code (#71)
2024-05-28 16:35:11 -07:00

286 lines
8.5 KiB
Swift

// Copyright © 2024 Apple Inc.
import LLM
import MLX
import MLXOptimizers
import MLXRandom
import SwiftUI
import Tokenizers
/// Main screen of the LoRA training example.
///
/// Shows a progress header, a scrolling pane with the evaluator's output, and a
/// set of controls that depends on the evaluator's current state.
struct ContentView: View {

    /// Owns the training / generation workflow and publishes its state.
    @State var evaluator = LoRAEvaluator()

    /// Editable prompt used once the evaluator reaches the `.evaluate` state.
    @State var prompt = """
        table: 1-10015132-16
        columns: Player, No., Nationality, Position, Years in Toronto, School/Club Team
        Q: What is terrence ross' nationality
        A:
        """

    var body: some View {
        VStack {
            progressHeader

            VStack {
                outputPane
                controls
            }
        }
        .padding()
    }

    /// Progress bar row. The row keeps a minimum height even when there is no
    /// progress to show, so the layout below it does not jump.
    private var progressHeader: some View {
        HStack {
            if let report = evaluator.progress {
                switch (report.current, report.limit) {
                case let (done?, total?):
                    // both values known: determinate progress
                    ProgressView(report.title, value: done, total: total)
                default:
                    // either value missing: indeterminate progress
                    ProgressView(report.title)
                }
            }
        }
        .frame(maxWidth: .infinity, minHeight: 25)
    }

    /// Scrollable, selectable view of the evaluator output that scrolls to the
    /// "bottom" marker whenever the output changes.
    private var outputPane: some View {
        ScrollView(.vertical) {
            ScrollViewReader { proxy in
                Group {
                    Text(evaluator.output)
                        .textSelection(.enabled)
                        .frame(maxWidth: .infinity)
                }
                .onChange(of: evaluator.output) { _, _ in
                    proxy.scrollTo("bottom")
                }
                .padding()

                // 1x1 anchor that serves as the scroll target
                Spacer()
                    .frame(width: 1, height: 1)
                    .id("bottom")
            }
        }
    }

    /// Controls for each of the different evaluator states.
    private var controls: some View {
        VStack {
            switch evaluator.state {
            case .idle:
                Button("Start") { start() }

            case .training:
                EmptyView()

            case .evaluate:
                Group {
                    TextEditor(text: $prompt)
                        .frame(minHeight: 60)
                    Button("Evaluate") { evaluate() }
                }
                // no editing or re-submitting while a run is in progress
                .disabled(evaluator.progress != nil)

            case .failed(let message):
                Text("Failed: \(message)")
                    .bold()
                    .foregroundStyle(.red)
            }
        }
        .padding()
        .frame(maxWidth: .infinity)
    }

    /// Starts the evaluator's `start()` workflow in a new task.
    func start() {
        Task { await evaluator.start() }
    }

    /// Runs the evaluator on the current prompt in a new task.
    func evaluate() {
        Task { await evaluator.evaluate(prompt: prompt) }
    }
}
/// A titled progress report.
///
/// `current` and `limit` are both optional; a report that lacks either one
/// carries only its title.
struct Progress: Equatable {
    /// Label describing the activity being reported.
    let title: String

    /// Amount of work completed so far, if known.
    let current: Double?

    /// Total amount of work, if known.
    let limit: Double?
}
/// Drives the example's workflow: load a model, attach LoRA adapters, train,
/// measure test loss, and then generate text for user-supplied prompts.
///
/// `state`, `progress` and `output` are read by the UI. Updates to them are
/// made on the main actor.
@Observable
class LoRAEvaluator {

    /// Phase of the workflow; the view uses it to pick which controls to show.
    enum State {
        case idle
        case training
        case evaluate
        /// A step failed. The associated value is the bare reason, without a
        /// "Failed:" prefix, because the view adds that label when rendering.
        case failed(String)
    }

    /// Whether the model and tokenizer have been loaded yet.
    enum ModelState {
        case idle
        case loaded(LLMModel, Tokenizer)
    }

    var state = State.idle
    var progress: Progress?
    var output = ""

    private let modelConfiguration = ModelConfiguration.mistral7B4bit
    private var model: ModelState = .idle

    /// Number of trailing LoRA-capable layers that receive adapters.
    private let loraLayers = 4
    private let learningRate: Float = 1e-5
    private let parameters = LoRATrain.Parameters(batchSize: 1, iterations: 200)

    private let generateParameters = GenerateParameters(temperature: 0.6, topP: 0.9)
    /// The visible output is refreshed once every this many generated tokens.
    private let evaluateShowEvery = 8
    /// Generation is stopped once this many tokens have been produced.
    private let maxTokens = 200

    /// Records a failure for the UI.
    ///
    /// Clears any in-flight progress indicator, so a stale "Loading" or
    /// "Testing" bar does not linger, and switches to the `.failed` state.
    /// Both writes happen on the main actor, like the other UI updates.
    ///
    /// - Parameter message: The reason, without a "Failed:" prefix.
    private func fail(_ message: String) async {
        await MainActor.run {
            self.progress = nil
            self.state = .failed(message)
        }
    }

    /// Returns the model and tokenizer, loading them on first use and reusing
    /// the stored pair afterwards.
    ///
    /// - Throws: Whatever `LLM.load(configuration:)` throws.
    private func loadModel() async throws -> (LLMModel, Tokenizer) {
        switch self.model {
        case .idle:
            let name = modelConfiguration.name
            await MainActor.run {
                self.progress = .init(title: "Loading \(name)", current: 0, limit: 1)
            }

            let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration) {
                progress in
                // Only intermediate updates are forwarded; a completed
                // fraction (>= 1.0) is not reported from here.
                if progress.fractionCompleted < 1.0 {
                    // The handler is synchronous, so hop to the main queue
                    // directly instead of awaiting the main actor.
                    DispatchQueue.main.sync {
                        self.progress = .init(
                            title: "Download \(name)", current: progress.fractionCompleted,
                            limit: 1.0)
                    }
                }
            }
            // NOTE(review): presumably forces evaluation of the model's
            // parameters before first use -- confirm against MLX `eval`.
            eval(model)

            self.model = .loaded(model, tokenizer)
            return (model, tokenizer)

        case .loaded(let model, let tokenizer):
            return (model, tokenizer)
        }
    }

    /// Loads `<name>.jsonl` from the main bundle.
    ///
    /// - Returns: The loaded data, or `nil` when the bundle has no such resource.
    /// - Throws: Whatever `LLM.loadLoRAData(url:)` throws.
    private func loadLoRAData(name: String) throws -> [String]? {
        guard let url = Bundle.main.url(forResource: name, withExtension: "jsonl") else {
            return nil
        }
        return try LLM.loadLoRAData(url: url)
    }

    /// Runs the full training workflow, turning any thrown error into the
    /// `.failed` state.
    func start() async {
        do {
            try await startInner()
        } catch {
            // No "Failed:" prefix here: the view already adds one.
            await fail("\(error)")
        }
    }

    /// Loads the model, applies LoRA adapters, trains, and reports test loss.
    ///
    /// Expected problems (wrong model type, missing data files) are reported
    /// through the `.failed` state; other errors are thrown to the caller.
    private func startInner() async throws {
        // setup
        // 32 * 1024 * 1024 = 32 Mi; NOTE(review): unit presumably bytes -- confirm.
        GPU.set(cacheLimit: 32 * 1024 * 1024)
        await MainActor.run {
            self.output = ""
            self.state = .training
        }

        // load the model
        let (model, tokenizer) = try await loadModel()

        // apply LoRA adapters and train
        guard let layerProvider = model as? LoRAModel else {
            await fail("Model must implement the LoRAModel protocol")
            return
        }
        // Only the last `loraLayers` layers are converted.
        LoRATrain.convert(
            model: model, layers: Array(layerProvider.loraLinearLayers().suffix(loraLayers)))

        let train = try loadLoRAData(name: "train")
        let valid = try loadLoRAData(name: "valid")
        guard let train, let valid else {
            await fail("Failed to load train/validation data")
            return
        }

        let optimizer = Adam(learningRate: learningRate)
        try await LoRATrain.train(
            model: model, train: train, validate: valid, optimizer: optimizer, tokenizer: tokenizer,
            parameters: parameters
        ) { progress in
            await MainActor.run {
                switch progress {
                case .train(let i, _, _, _):
                    self.progress = .init(
                        title: "Train", current: Double(i),
                        limit: Double(self.parameters.iterations))
                case .validation:
                    // blank line before each validation report
                    self.output += "\n"
                default:
                    break
                }
                self.output += progress.description + "\n"
            }
            return .more
        }

        // done training, test
        await MainActor.run {
            self.progress = .init(title: "Testing", current: nil, limit: nil)
        }
        guard let test = try loadLoRAData(name: "test") else {
            await fail("Failed to load test data")
            return
        }

        let loss = LoRATrain.evaluate(
            model: model, dataset: test, tokenizer: tokenizer, batchSize: 1, batchCount: 0)

        await MainActor.run {
            self.progress = nil
            self.output += "\n"
            self.output += "Test loss \(loss.formatted()), ppl \(exp(loss).formatted())\n"
            self.state = .evaluate
        }
    }

    /// Generates text for `prompt`, turning any thrown error into the
    /// `.failed` state.
    func evaluate(prompt: String) async {
        do {
            try await evaluateInner(prompt: prompt)
        } catch {
            // No "Failed:" prefix here: the view already adds one.
            await fail("\(error)")
        }
    }

    /// Generates text for `prompt`, updating `output` as generation proceeds
    /// and again with the final result.
    ///
    /// - Throws: Whatever `loadModel()` throws.
    func evaluateInner(prompt: String) async throws {
        await MainActor.run {
            self.progress = .init(title: "Evaluating", current: nil, limit: nil)
            self.output = ""
        }

        // Seed from the current time (in milliseconds) so runs differ.
        MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))

        let (model, tokenizer) = try await loadModel()

        // prepare the prompt
        let preparedPrompt = modelConfiguration.prepare(prompt: prompt)
        let promptTokens = tokenizer.encode(text: preparedPrompt)

        // evaluate
        let result = await LLM.generate(
            promptTokens: promptTokens, parameters: generateParameters, model: model,
            tokenizer: tokenizer,
            extraEOSTokens: modelConfiguration.extraEOSTokens,
            didGenerate: { tokens in
                // Refresh the visible text only every `evaluateShowEvery` tokens.
                if tokens.count % self.evaluateShowEvery == 0 {
                    let fullOutput = tokenizer.decode(tokens: tokens)
                    await MainActor.run {
                        self.output = fullOutput
                    }
                }
                return tokens.count >= self.maxTokens ? .stop : .more
            })

        await MainActor.run {
            self.output = result.output
            self.progress = nil
        }
    }
}