implement LoRA / QLoRA (#46)
* implement LoRA / QLoRA - example of using MLX to fine-tune an LLM with low rank adaptation (LoRA) for a target task - see also https://arxiv.org/abs/2106.09685 - based on https://github.com/ml-explore/mlx-examples/tree/main/lora * add some command line flags I found useful during use - --quiet -- don't print decorator text, just the generated text - --prompt @/tmp/file.txt -- load prompt from file * user can specify path to model OR model identifier in huggingface * update mlx-swift reference Co-authored-by: Ashraful Islam <ashraful.meche@gmail.com> Co-authored-by: JustinMeans <46542161+JustinMeans@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,11 @@
|
||||
{
|
||||
"colors" : [
|
||||
{
|
||||
"idiom" : "universal"
|
||||
}
|
||||
],
|
||||
"info" : {
|
||||
"author" : "xcode",
|
||||
"version" : 1
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
{
|
||||
"images" : [
|
||||
{
|
||||
"idiom" : "universal",
|
||||
"platform" : "ios",
|
||||
"size" : "1024x1024"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "1x",
|
||||
"size" : "16x16"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "2x",
|
||||
"size" : "16x16"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "1x",
|
||||
"size" : "32x32"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "2x",
|
||||
"size" : "32x32"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "1x",
|
||||
"size" : "128x128"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "2x",
|
||||
"size" : "128x128"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "1x",
|
||||
"size" : "256x256"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "2x",
|
||||
"size" : "256x256"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "1x",
|
||||
"size" : "512x512"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "2x",
|
||||
"size" : "512x512"
|
||||
}
|
||||
],
|
||||
"info" : {
|
||||
"author" : "xcode",
|
||||
"version" : 1
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"info" : {
|
||||
"author" : "xcode",
|
||||
"version" : 1
|
||||
}
|
||||
}
|
||||
284
Applications/LoRATrainingExample/ContentView.swift
Normal file
284
Applications/LoRATrainingExample/ContentView.swift
Normal file
@@ -0,0 +1,284 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
import LLM
|
||||
import MLX
|
||||
import MLXOptimizers
|
||||
import MLXRandom
|
||||
import SwiftUI
|
||||
import Tokenizers
|
||||
|
||||
/// Top-level UI for the LoRA example: a progress header, a scrolling
/// transcript of training/generation output, and state-dependent controls
/// driven by `LoRAEvaluator`.
struct ContentView: View {

    @State var evaluator = LoRAEvaluator()

    // Default prompt in the table/columns/Q/A format used by the training data.
    @State var prompt = """
        table: 1-10015132-16
        columns: Player, No., Nationality, Position, Years in Toronto, School/Club Team
        Q: What is terrence ross' nationality
        A:
        """

    var body: some View {
        VStack {
            statusBar

            VStack {
                transcriptView

                // controls for each of the different states
                actionArea
            }
            .padding()
        }
        .padding()
    }

    /// Shows a determinate progress bar when both a current value and a
    /// limit are available, otherwise an indeterminate one.
    private var statusBar: some View {
        HStack {
            if let progress = evaluator.progress {
                if let current = progress.current, let limit = progress.limit {
                    ProgressView(progress.title, value: current, total: limit)
                } else {
                    ProgressView(progress.title)
                }
            }
        }
        .frame(maxWidth: .infinity, minHeight: 25)
    }

    /// Auto-scrolling view of `evaluator.output`.
    private var transcriptView: some View {
        ScrollView(.vertical) {
            ScrollViewReader { scrollProxy in
                Group {
                    Text(evaluator.output)
                        .textSelection(.enabled)
                        .frame(maxWidth: .infinity)
                }
                .onChange(of: evaluator.output) { _, _ in
                    scrollProxy.scrollTo("bottom")
                }
                .padding()

                // tiny invisible anchor that `scrollTo` targets
                Spacer()
                    .frame(width: 1, height: 1)
                    .id("bottom")
            }
        }
    }

    /// Controls appropriate to the evaluator's current state.
    private var actionArea: some View {
        VStack {
            switch evaluator.state {
            case .idle:
                Button("Start", action: start)

            case .training:
                EmptyView()

            case .evaluate:
                Group {
                    TextEditor(text: $prompt)
                        .frame(minHeight: 60)
                    Button("Evaluate", action: evaluate)
                }
                .disabled(evaluator.progress != nil)

            case .failed(let message):
                Text("Failed: \(message)")
                    .bold()
                    .foregroundStyle(.red)
            }
        }
        .padding()
        .frame(maxWidth: .infinity)
    }

    /// Begin model loading + LoRA training in a background task.
    func start() {
        Task {
            await evaluator.start()
        }
    }

    /// Generate text for the current prompt in a background task.
    func evaluate() {
        Task {
            await evaluator.evaluate(prompt: prompt)
        }
    }
}
|
||||
|
||||
/// Progress reporting with a title.
|
||||
struct Progress: Equatable {
|
||||
let title: String
|
||||
let current: Double?
|
||||
let limit: Double?
|
||||
}
|
||||
|
||||
/// Orchestrates the LoRA workflow: loads (downloading on first use) the base
/// model, applies LoRA adapters, trains on the bundled jsonl datasets,
/// reports a test loss, and then generates text for user prompts.
@Observable
class LoRAEvaluator {

    /// Overall phase of the workflow; drives which controls the UI shows.
    enum State {
        case idle
        case training
        case evaluate
        case failed(String)
    }

    /// Caches the model after the first successful load.
    enum ModelState {
        case idle
        case loaded(LLMModel, Tokenizer)
    }

    var state = State.idle
    var progress: Progress?

    /// Transcript shown by the UI — training log lines or generated text.
    var output = ""

    private let modelConfiguration = ModelConfiguration.mistral7B4bit
    private var model: ModelState = .idle

    /// Number of trailing linear layers to replace with LoRA adapters.
    private let loraLayers = 4
    private let learningRate: Float = 1e-5
    private let parameters = LoRATrain.Parameters(batchSize: 1, iterations: 200)

    private let generateParameters = GenerateParameters(temperature: 0.6, topP: 0.9)
    /// Refresh the visible output once per this many generated tokens.
    private let evaluateShowEvery = 8
    private let maxTokens = 200

    /// Records a failure. Runs on the main actor so the observable `state`
    /// is always mutated from the same executor as the other UI updates
    /// (fix: failure paths previously wrote `state` off the main actor while
    /// every successful path used `MainActor.run`).
    @MainActor
    private func fail(_ message: String) {
        self.state = .failed(message)
    }

    /// Returns the (model, tokenizer) pair, loading — and downloading, on
    /// first run — while publishing progress. Subsequent calls return the
    /// cached pair.
    private func loadModel() async throws -> (LLMModel, Tokenizer) {
        switch self.model {
        case .idle:
            let name = modelConfiguration.name
            await MainActor.run {
                progress = .init(title: "Loading \(name)", current: 0, limit: 1)
            }

            let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration) {
                progress in
                if progress.fractionCompleted < 1.0 {
                    // NOTE(review): synchronous hop onto the main queue from the
                    // download callback. Assumes the callback never runs on the
                    // main queue already (sync from main would deadlock) —
                    // confirm, or migrate this to MainActor.
                    DispatchQueue.main.sync {
                        self.progress = .init(
                            title: "Download \(name)", current: progress.fractionCompleted,
                            limit: 1.0)
                    }
                }
            }
            // materialize the weights now rather than lazily mid-training
            eval(model)
            self.model = .loaded(model, tokenizer)
            return (model, tokenizer)

        case .loaded(let model, let tokenizer):
            return (model, tokenizer)
        }
    }

    /// Loads one of the bundled jsonl datasets (train/valid/test), or nil
    /// when the resource is missing from the bundle.
    private func loadLoRAData(name: String) throws -> [String]? {
        if let url = Bundle.main.url(forResource: name, withExtension: "jsonl") {
            return try LLM.loadLoRAData(url: url)
        }
        return nil
    }

    /// Entry point for the Start button: runs the full training pipeline,
    /// converting any thrown error into the `.failed` state.
    func start() async {
        do {
            try await startInner()
        } catch {
            await fail("Failed: \(error)")
        }
    }

    /// Loads the model, converts layers to LoRA, trains on train/valid, and
    /// finishes by evaluating loss on the test set.
    private func startInner() async throws {
        // setup
        GPU.set(cacheLimit: 32 * 1024 * 1024)
        await MainActor.run {
            output = ""
            state = .training
        }

        // load the model
        let (model, tokenizer) = try await loadModel()

        // apply LoRA adapters and train
        guard let layerProvider = model as? LoRAModel else {
            // fix: the message previously named a non-existent
            // "LoRALayerProvider" protocol; the guard casts to LoRAModel
            await fail("Model must implement the LoRAModel protocol")
            return
        }
        LoRATrain.convert(
            model: model, layers: Array(layerProvider.loraLinearLayers().suffix(loraLayers)))

        let train = try loadLoRAData(name: "train")
        let valid = try loadLoRAData(name: "valid")
        guard let train, let valid else {
            await fail("Failed to load train/validation data")
            return
        }

        let optimizer = Adam(learningRate: learningRate)
        try await LoRATrain.train(
            model: model, train: train, validate: valid, optimizer: optimizer, tokenizer: tokenizer,
            parameters: parameters
        ) { progress in
            await MainActor.run {
                switch progress {
                case .train(let i, _, _, _):
                    self.progress = .init(
                        title: "Train", current: Double(i), limit: Double(parameters.iterations))
                case .validation:
                    // blank line separates validation reports in the transcript
                    output += "\n"
                default:
                    break
                }

                output += progress.description + "\n"
            }

            return .more
        }

        // done training, test
        await MainActor.run {
            self.progress = .init(title: "Testing", current: nil, limit: nil)
        }
        guard let test = try loadLoRAData(name: "test") else {
            await fail("Failed to load test data")
            return
        }

        let loss = LoRATrain.evaluate(
            model: model, dataset: test, tokenizer: tokenizer, batchSize: 1, batchCount: 0)
        await MainActor.run {
            self.progress = nil
            self.output += "\n"
            self.output += "Test loss \(loss.formatted()), ppl \(exp(loss).formatted())\n"
            self.state = .evaluate
        }
    }

    /// Entry point for the Evaluate button: generates text for `prompt`,
    /// converting any thrown error into the `.failed` state.
    func evaluate(prompt: String) async {
        do {
            try await evaluateInner(prompt: prompt)
        } catch {
            await fail("Failed: \(error)")
        }
    }

    // NOTE(review): consider making this private for parity with
    // `startInner`; callers should use `evaluate(prompt:)`.
    func evaluateInner(prompt: String) async throws {
        await MainActor.run {
            self.progress = .init(title: "Evaluating", current: nil, limit: nil)
            self.output = ""
        }

        // time-based seed so repeated evaluations produce varied output
        MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))

        let (model, tokenizer) = try await loadModel()

        // prepare the prompt
        let preparedPrompt = modelConfiguration.prepare(prompt: prompt)
        let promptTokens = tokenizer.encode(text: preparedPrompt)

        // evaluate
        let result = await LLM.generate(
            promptTokens: promptTokens, parameters: generateParameters, model: model,
            tokenizer: tokenizer,
            didGenerate: { tokens in
                // throttle UI updates to every `evaluateShowEvery` tokens
                if tokens.count % evaluateShowEvery == 0 {
                    let fullOutput = tokenizer.decode(tokens: tokens)
                    await MainActor.run {
                        self.output = fullOutput
                    }
                }
                return tokens.count >= maxTokens ? .stop : .more
            })

        await MainActor.run {
            self.output = result.output
            self.progress = nil
        }
    }
}
|
||||
@@ -0,0 +1,14 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>com.apple.developer.kernel.increased-memory-limit</key>
|
||||
<true/>
|
||||
<key>com.apple.security.app-sandbox</key>
|
||||
<true/>
|
||||
<key>com.apple.security.files.user-selected.read-only</key>
|
||||
<true/>
|
||||
<key>com.apple.security.network.client</key>
|
||||
<true/>
|
||||
</dict>
|
||||
</plist>
|
||||
@@ -0,0 +1,12 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
import SwiftUI
|
||||
|
||||
/// App entry point: presents a single window hosting `ContentView`.
@main
struct LoRATrainingExampleApp: App {
    var body: some Scene {
        WindowGroup { ContentView() }
    }
}
|
||||
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"info" : {
|
||||
"author" : "xcode",
|
||||
"version" : 1
|
||||
}
|
||||
}
|
||||
21
Applications/LoRATrainingExample/README.md
Normal file
21
Applications/LoRATrainingExample/README.md
Normal file
@@ -0,0 +1,21 @@
|
||||
# LoRATrainingExample
|
||||
|
||||
Example application that:
|
||||
|
||||
- downloads the `mlx-community/Mistral-7B-v0.1-hf-4bit-mlx` model from huggingface
|
||||
- loads the train/valid/test data from `$SRCROOT/Data/lora` (this is copied into the build but you can imagine how it might be downloaded)
|
||||
- adds LoRA adapters and trains the model
|
||||
- lets you evaluate a prompt against the model
|
||||
|
||||
This roughly equates to the command line example in [Tools/llm-tool](../../Tools/llm-tool) and
|
||||
you can read more about LoRA there.
|
||||
|
||||
This evaluates the LoRA adapted model rather than a fused model. This doesn't persist
|
||||
the LoRA weights or the fused model -- it will retrain it each time the program is launched.
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
The `mlx-community/Mistral-7B-v0.1-hf-4bit-mlx` model requires a little over 4G of
|
||||
memory to load and train -- this may require ~6G of physical RAM.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user