implement LoRA / QLoRA (#46)

* implement LoRA / QLoRA

- example of using MLX to fine-tune an LLM with low rank adaptation (LoRA) for a target task (see the sketch below)
- see also https://arxiv.org/abs/2106.09685
- based on https://github.com/ml-explore/mlx-examples/tree/main/lora
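
As background: LoRA keeps the pretrained weight matrix W frozen and trains only a low-rank update, so the adapted layer computes roughly y = W x + (alpha / r) * B A x with A and B of rank r. Below is a minimal MLX Swift sketch of the idea (illustrative only -- the names, initialization, and alpha scaling are assumptions, not the LLM library's actual implementation, and a real layer would be an MLX `Module` so the optimizer can find its parameters):

```swift
import MLX
import MLXRandom

/// Hypothetical LoRA-augmented linear layer, for illustration only.
struct LoRALinearSketch {
    let weight: MLXArray   // frozen base weight, shape [out, in]
    var loraA: MLXArray    // trainable, shape [r, in]
    var loraB: MLXArray    // trainable, shape [out, r]; zero-init so W' == W at start
    let scale: Float       // alpha / r

    init(weight: MLXArray, rank: Int, alpha: Float = 16) {
        self.weight = weight
        self.loraA = MLXRandom.normal([rank, weight.dim(1)]) * (1.0 / Float(rank))
        self.loraB = MLX.zeros([weight.dim(0), rank])
        self.scale = alpha / Float(rank)
    }

    func callAsFunction(_ x: MLXArray) -> MLXArray {
        // y = x W^T + scale * ((x A^T) B^T) -- only loraA/loraB are trained
        let base = matmul(x, weight.T)
        let update = matmul(matmul(x, loraA.T), loraB.T)
        return base + update * scale
    }
}
```

Only the small A and B matrices are optimized, which is why a 7B model can be tuned in a few GB of memory.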

* add some command line flags I found useful during use (example invocation below)
- --quiet -- don't print decorator text, just the generated text
- --prompt @/tmp/file.txt -- load prompt from file

* user can specify path to model OR model identifier in huggingface
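
For example (hypothetical invocation -- the `mlx-run` wrapper and `llm-tool` argument layout may differ from what's shown; the flags themselves are the ones described above):

```sh
# generate from a prompt stored in a file, printing only the generated text
./mlx-run llm-tool --model mlx-community/Mistral-7B-v0.1-hf-4bit-mlx \
    --prompt @/tmp/file.txt \
    --quiet

# --model also accepts a local path instead of a Hugging Face identifier
./mlx-run llm-tool --model /path/to/model --prompt "hello" --quiet
```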

* update mlx-swift reference

Co-authored-by: Ashraful Islam <ashraful.meche@gmail.com>
Co-authored-by: JustinMeans <46542161+JustinMeans@users.noreply.github.com>
Authored by David Koski on 2024-04-22 09:30:12 -07:00, committed by GitHub
parent 7e85eb8b88, commit 6c0b66f90a
32 changed files with 3483 additions and 64 deletions

View File

@@ -0,0 +1,11 @@
{
  "colors" : [
    {
      "idiom" : "universal"
    }
  ],
  "info" : {
    "author" : "xcode",
    "version" : 1
  }
}

View File

@@ -0,0 +1,63 @@
{
  "images" : [
    {
      "idiom" : "universal",
      "platform" : "ios",
      "size" : "1024x1024"
    },
    {
      "idiom" : "mac",
      "scale" : "1x",
      "size" : "16x16"
    },
    {
      "idiom" : "mac",
      "scale" : "2x",
      "size" : "16x16"
    },
    {
      "idiom" : "mac",
      "scale" : "1x",
      "size" : "32x32"
    },
    {
      "idiom" : "mac",
      "scale" : "2x",
      "size" : "32x32"
    },
    {
      "idiom" : "mac",
      "scale" : "1x",
      "size" : "128x128"
    },
    {
      "idiom" : "mac",
      "scale" : "2x",
      "size" : "128x128"
    },
    {
      "idiom" : "mac",
      "scale" : "1x",
      "size" : "256x256"
    },
    {
      "idiom" : "mac",
      "scale" : "2x",
      "size" : "256x256"
    },
    {
      "idiom" : "mac",
      "scale" : "1x",
      "size" : "512x512"
    },
    {
      "idiom" : "mac",
      "scale" : "2x",
      "size" : "512x512"
    }
  ],
  "info" : {
    "author" : "xcode",
    "version" : 1
  }
}

View File

@@ -0,0 +1,6 @@
{
  "info" : {
    "author" : "xcode",
    "version" : 1
  }
}

View File

@@ -0,0 +1,284 @@
// Copyright © 2024 Apple Inc.

import LLM
import MLX
import MLXOptimizers
import MLXRandom
import SwiftUI
import Tokenizers

struct ContentView: View {

    @State var evaluator = LoRAEvaluator()

    @State var prompt = """
        table: 1-10015132-16
        columns: Player, No., Nationality, Position, Years in Toronto, School/Club Team
        Q: What is terrence ross' nationality
        A:
        """

    var body: some View {
        VStack {
            HStack {
                if let progress = evaluator.progress {
                    if let current = progress.current, let limit = progress.limit {
                        ProgressView(progress.title, value: current, total: limit)
                    } else {
                        ProgressView(progress.title)
                    }
                }
            }
            .frame(maxWidth: .infinity, minHeight: 25)

            VStack {
                ScrollView(.vertical) {
                    ScrollViewReader { sp in
                        Group {
                            Text(evaluator.output)
                                .textSelection(.enabled)
                                .frame(maxWidth: .infinity)
                        }
                        .onChange(of: evaluator.output) { _, _ in
                            sp.scrollTo("bottom")
                        }
                        .padding()

                        Spacer()
                            .frame(width: 1, height: 1)
                            .id("bottom")
                    }
                }

                // controls for each of the different states
                VStack {
                    switch evaluator.state {
                    case .idle:
                        Button("Start", action: start)

                    case .training:
                        EmptyView()

                    case .evaluate:
                        Group {
                            TextEditor(text: $prompt)
                                .frame(minHeight: 60)
                            Button("Evaluate", action: evaluate)
                        }
                        .disabled(evaluator.progress != nil)

                    case .failed(let message):
                        Text("Failed: \(message)")
                            .bold()
                            .foregroundStyle(.red)
                    }
                }
                .padding()
                .frame(maxWidth: .infinity)
            }
        }
        .padding()
    }

    func start() {
        Task {
            await evaluator.start()
        }
    }

    func evaluate() {
        Task {
            await evaluator.evaluate(prompt: prompt)
        }
    }
}

/// Progress reporting with a title.
struct Progress: Equatable {
    let title: String
    let current: Double?
    let limit: Double?
}

@Observable
class LoRAEvaluator {

    enum State {
        case idle
        case training
        case evaluate
        case failed(String)
    }

    enum ModelState {
        case idle
        case loaded(LLMModel, Tokenizer)
    }

    var state = State.idle
    var progress: Progress?
    var output = ""

    private let modelConfiguration = ModelConfiguration.mistral7B4bit
    private var model: ModelState = .idle

    private let loraLayers = 4
    private let learningRate: Float = 1e-5
    private let parameters = LoRATrain.Parameters(batchSize: 1, iterations: 200)

    private let generateParameters = GenerateParameters(temperature: 0.6, topP: 0.9)
    private let evaluateShowEvery = 8
    private let maxTokens = 200

    private func loadModel() async throws -> (LLMModel, Tokenizer) {
        switch self.model {
        case .idle:
            let name = modelConfiguration.name
            await MainActor.run {
                progress = .init(title: "Loading \(name)", current: 0, limit: 1)
            }

            let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration) {
                progress in
                if progress.fractionCompleted < 1.0 {
                    DispatchQueue.main.sync {
                        self.progress = .init(
                            title: "Download \(name)", current: progress.fractionCompleted,
                            limit: 1.0)
                    }
                }
            }
            eval(model)
            self.model = .loaded(model, tokenizer)
            return (model, tokenizer)

        case .loaded(let model, let tokenizer):
            return (model, tokenizer)
        }
    }

    private func loadLoRAData(name: String) throws -> [String]? {
        if let url = Bundle.main.url(forResource: name, withExtension: "jsonl") {
            return try LLM.loadLoRAData(url: url)
        }
        return nil
    }

    func start() async {
        do {
            try await startInner()
        } catch {
            self.state = .failed("Failed: \(error)")
        }
    }

    private func startInner() async throws {
        // setup
        GPU.set(cacheLimit: 32 * 1024 * 1024)
        await MainActor.run {
            output = ""
            state = .training
        }

        // load the model
        let (model, tokenizer) = try await loadModel()

        // apply LoRA adapters and train
        guard let layerProvider = model as? LoRAModel else {
            state = .failed("Model must implement the LoRALayerProvider protocol")
            return
        }

        LoRATrain.convert(
            model: model, layers: Array(layerProvider.loraLinearLayers().suffix(loraLayers)))

        let train = try loadLoRAData(name: "train")
        let valid = try loadLoRAData(name: "valid")
        guard let train, let valid else {
            state = .failed("Failed to load train/validation data")
            return
        }

        let optimizer = Adam(learningRate: learningRate)
        try await LoRATrain.train(
            model: model, train: train, validate: valid, optimizer: optimizer, tokenizer: tokenizer,
            parameters: parameters
        ) { progress in
            await MainActor.run {
                switch progress {
                case .train(let i, _, _, _):
                    self.progress = .init(
                        title: "Train", current: Double(i), limit: Double(parameters.iterations))
                case .validation:
                    output += "\n"
                default:
                    break
                }
                output += progress.description + "\n"
            }
            return .more
        }

        // done training, test
        await MainActor.run {
            self.progress = .init(title: "Testing", current: nil, limit: nil)
        }
        guard let test = try loadLoRAData(name: "test") else {
            state = .failed("Failed to load test data")
            return
        }

        let loss = LoRATrain.evaluate(
            model: model, dataset: test, tokenizer: tokenizer, batchSize: 1, batchCount: 0)

        await MainActor.run {
            self.progress = nil
            self.output += "\n"
            self.output += "Test loss \(loss.formatted()), ppl \(exp(loss).formatted())\n"
            self.state = .evaluate
        }
    }

    func evaluate(prompt: String) async {
        do {
            try await evaluateInner(prompt: prompt)
        } catch {
            self.state = .failed("Failed: \(error)")
        }
    }

    func evaluateInner(prompt: String) async throws {
        await MainActor.run {
            self.progress = .init(title: "Evaluating", current: nil, limit: nil)
            self.output = ""
        }

        MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))

        let (model, tokenizer) = try await loadModel()

        // prepare the prompt
        let preparedPrompt = modelConfiguration.prepare(prompt: prompt)
        let promptTokens = tokenizer.encode(text: preparedPrompt)

        // evaluate
        let result = await LLM.generate(
            promptTokens: promptTokens, parameters: generateParameters, model: model,
            tokenizer: tokenizer,
            didGenerate: { tokens in
                if tokens.count % evaluateShowEvery == 0 {
                    let fullOutput = tokenizer.decode(tokens: tokens)
                    await MainActor.run {
                        self.output = fullOutput
                    }
                }
                return tokens.count >= maxTokens ? .stop : .more
            })

        await MainActor.run {
            self.output = result.output
            self.progress = nil
        }
    }
}

View File

@@ -0,0 +1,14 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
    <key>com.apple.developer.kernel.increased-memory-limit</key>
    <true/>
    <key>com.apple.security.app-sandbox</key>
    <true/>
    <key>com.apple.security.files.user-selected.read-only</key>
    <true/>
    <key>com.apple.security.network.client</key>
    <true/>
</dict>
</plist>

View File

@@ -0,0 +1,12 @@
// Copyright © 2024 Apple Inc.

import SwiftUI

@main
struct LoRATrainingExampleApp: App {
    var body: some Scene {
        WindowGroup {
            ContentView()
        }
    }
}

View File

@@ -0,0 +1,6 @@
{
  "info" : {
    "author" : "xcode",
    "version" : 1
  }
}

View File

@@ -0,0 +1,21 @@
# LoRATrainingExample

Example application that:

- downloads the `mlx-community/Mistral-7B-v0.1-hf-4bit-mlx` model from Hugging Face
- loads the train/valid/test data from `$SRCROOT/Data/lora` (this is copied into the build but you can imagine how it might be downloaded)
- adds LoRA adapters and trains the model
- lets you evaluate a prompt against the tuned model

This roughly equates to the command line example in [Tools/llm-tool](../../Tools/llm-tool) and
you can read more about LoRA there.

This evaluates the LoRA-adapted model rather than a fused model. It doesn't persist
the LoRA weights or the fused model -- the adapters are retrained each time the program is launched.

### Troubleshooting

The `mlx-community/Mistral-7B-v0.1-hf-4bit-mlx` model requires a little over 4G of
memory to load and train -- this may require ~6G of physical RAM.
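
One mitigation, already used in `ContentView.swift`, is to cap the MLX GPU buffer cache before training; a smaller cache trades some speed for a lower peak footprint:

```swift
import MLX

// From the example's startInner(): keep the GPU buffer cache small so the
// training pass fits alongside the 4-bit model weights.
GPU.set(cacheLimit: 32 * 1024 * 1024)  // 32 MB
```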