implement LoRA / QLoRA (#46)

* implement LoRA / QLoRA

- example of using MLX to fine-tune an LLM with low-rank adaptation (LoRA) for a target task
- see also https://arxiv.org/abs/2106.09685
- based on https://github.com/ml-explore/mlx-examples/tree/main/lora

* add some command line flags I found useful in practice (see the example below)
- --quiet -- don't print decorator text, just the generated text
- --prompt @/tmp/file.txt -- load prompt from file

* user can specify either a local path to the model OR a model identifier on Hugging Face

* update mlx-swift reference
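
* example combining the new flags (the model value is shown as a placeholder
  for a local path or Hugging Face identifier):

      llm-tool eval --model <path-or-hub-id> --prompt @/tmp/file.txt --quiet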

Co-authored-by: Ashraful Islam <ashraful.meche@gmail.com>
Co-authored-by: JustinMeans <46542161+JustinMeans@users.noreply.github.com>
Commit: 6c0b66f90a (parent: 7e85eb8b88)
Author: David Koski
Date: 2024-04-22 09:30:12 -07:00
Committed by: GitHub

32 changed files with 3483 additions and 64 deletions

@@ -0,0 +1,281 @@
// Copyright © 2024 Apple Inc.

import ArgumentParser
import Foundation
import Hub
import LLM
import MLX
import MLXNN
import MLXOptimizers
import MLXRandom
import Tokenizers

struct LoRACommand: AsyncParsableCommand {
static var configuration = CommandConfiguration(
commandName: "lora",
abstract: "LoRA commands",
subcommands: [
LoRATrainCommand.self, LoRAFuseCommand.self, LoRATestCommand.self, LoRAEvalCommand.self,
]
)
}

/// Common arguments for loading a LoRA model with adapter weights
struct LoRAModelArguments: ParsableArguments {
@OptionGroup var args: ModelArguments
@Option(name: .long, help: "Save/load path for the trained adapter weights")
public var adapter: URL = URL(filePath: "adapters.safetensors")
@Option(name: .long, help: "Number of layers to fine-tune")
public var loraLayers = 16
/// Load the model and apply the LoRA adapters.
///
/// This does not load the adapter weights as they may not exist yet.
func load() async throws -> (LLMModel, Tokenizer, ModelConfiguration) {
let (model, tokenizer, modelConfiguration) = try await args.load()
// convert some of the Linear layers to LoRALinear
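        // (sketch of the idea from https://arxiv.org/abs/2106.09685: the
        // adapted layer computes y = W x + B A x with the base weight W kept
        // frozen; only the small low-rank matrices A and B are trained)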
LoRATrain.convert(model: model, layers: loraLayers(model: model))
return (model, tokenizer, modelConfiguration)
}
func loraLayers(model: Module) -> LoRALinearLayers {
        // a LoRAModel conformance indicates which Linear layers should be replaced
        guard let layerProvider = model as? LoRAModel else {
fatalError(
"Model \(type(of: model)) (\(args.model)) must implement the LoRALayerProvider protocol"
)
}
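        // take the suffix: adapting only the last `loraLayers` blocks (those
        // nearest the output) follows the Python mlx-examples LoRA setup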
return Array(layerProvider.loraLinearLayers().suffix(loraLayers))
}
func describe(model: Module) {
let totalParameterCount = model.parameters()
.flattenedValues().map { $0.size }.reduce(0, +)
let trainableParameterCount = model.trainableParameters()
.flattenedValues().map { $0.size }.reduce(0, +)
print("Model: \(args.model)")
print("Total parameters: \((totalParameterCount / 1_000_000).formatted())M")
print(
"Trainable parameters: \((Float(trainableParameterCount) / 1_000_000).formatted(.number.precision(.significantDigits(1 ..< 4))))M"
)
}
}
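
/// Illustrative invocation (flag names mirror the properties below; the
/// model value is a placeholder for a local path or Hub id):
///
///     llm-tool lora train --model <path-or-hub-id> --data /path/to/data \
///         --batch-size 4 --iterations 1000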
struct LoRATrainCommand: AsyncParsableCommand {
static var configuration = CommandConfiguration(
commandName: "train",
abstract: "LoRA training"
)
@OptionGroup var args: LoRAModelArguments
@OptionGroup var memory: MemoryArguments
@Flag(help: "Resume training with the given adapter file")
public var resume = false
@Option(name: .long, help: "Directory with {train, valid, test}.{jsonl,txt} files")
public var data: URL = URL(filePath: "data")
@Option(name: .long, help: "Learning rate for the optimizer")
public var learningRate: Float = 1e-5
@Option(name: .long, help: "Number of dataset items to evaluate per iteration (batch)")
public var batchSize = 4
@Option(name: .long, help: "Number iterations to train for")
public var iterations = 1000
@Option(name: .long, help: "Number of iterations between loss reporting")
public var stepsPerReport = 10
@Option(name: .long, help: "Number of iterations between validations")
public var stepsPerEval = 100
@Option(name: .long, help: "Number of validation batches, 0 uses the entire set")
public var validationBatches = 10
@Option(name: .long, help: "Number of iterations between checkpointing the adapter weights")
public var saveEvery = 100
var parameters: LoRATrain.Parameters {
var p = LoRATrain.Parameters()
p.batchSize = self.batchSize
p.iterations = self.iterations
p.stepsPerReport = self.stepsPerReport
p.stepsPerEval = self.stepsPerEval
p.validationBatches = self.validationBatches
p.saveEvery = self.saveEvery
p.adapterURL = args.adapter
return p
}
@MainActor
mutating func run() async throws {
let (model, tokenizer, _) = try await args.load()
args.describe(model: model)
memory.start()
if resume {
print("Loading pretrained adapters from \(args.adapter.path())")
try LoRATrain.loadLoRAWeights(model: model, url: args.adapter)
}
// load the train/validation data
let train = try loadLoRAData(directory: data, name: "train")
let valid = try loadLoRAData(directory: data, name: "valid")
if train.isEmpty {
fatalError("Training set is empty: \(data.path()))")
}
if valid.isEmpty {
fatalError("Validation set is empty: \(data.path()))")
}
// train
let optimizer = Adam(learningRate: learningRate)
try await LoRATrain.train(
model: model, train: train, validate: valid, optimizer: optimizer, tokenizer: tokenizer,
parameters: parameters
) { progress in
print(progress)
return .more
}
try LoRATrain.saveLoRAWeights(model: model, url: args.adapter)
}
}
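
/// Illustrative invocation (values are placeholders; per the --output help,
/// output may be a hub id or an absolute path):
///
///     llm-tool lora fuse --model <path-or-hub-id> --output /tmp/mistral-lora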
struct LoRAFuseCommand: AsyncParsableCommand {
static var configuration = CommandConfiguration(
commandName: "fuse",
abstract: "Fuse lora adapter weights back in to original model"
)
@OptionGroup var args: LoRAModelArguments
@Flag(name: .long, help: "De-quantize QuantizedLinear layers back into Linear")
var deQuantize = false
@Option(name: .long, help: "Hub ID (mlx-community/mistral-lora) or path (/tmp/mistral-lora)")
var output: String
@MainActor
mutating func run() async throws {
let outputURL: URL
if output.hasPrefix("/") {
outputURL = URL(filePath: output)
} else {
let repo = HubApi.Repo(id: output)
outputURL = HubApi().localRepoLocation(repo)
}
let (model, _, modelConfiguration) = try await args.load()
// load the prepared weights
try LoRATrain.loadLoRAWeights(model: model, url: args.adapter)
// fuse them back into Linear/QuantizedLinear
LoRATrain.fuse(model: model, layers: args.loraLayers(model: model), deQuantize: deQuantize)
// make the new directory and copy files from source model
try FileManager.default.createDirectory(at: outputURL, withIntermediateDirectories: true)
let inputURL = modelConfiguration.modelDirectory()
let enumerator = FileManager.default.enumerator(
at: inputURL, includingPropertiesForKeys: nil)!
for case let url as URL in enumerator {
// copy everything except the model weights -- we will write out the fused one below
if url.pathExtension == "safetensors" {
continue
}
try FileManager.default.copyItem(
at: url, to: outputURL.appending(component: url.lastPathComponent))
}
// write them back out
let weights = Dictionary(uniqueKeysWithValues: model.parameters().flattened())
try save(arrays: weights, url: outputURL.appending(component: "weights.safetensors"))
print("Fused weights written to \(outputURL.path())")
print("Use with:\n\tllm-tool eval --model \(output)")
}
}
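
/// Illustrative invocation (values are placeholders):
///
///     llm-tool lora test --model <path-or-hub-id> --data /path/to/data \
///         --adapter adapters.safetensors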
struct LoRATestCommand: AsyncParsableCommand {
static var configuration = CommandConfiguration(
commandName: "test",
abstract: "LoRA testing"
)
@OptionGroup var args: LoRAModelArguments
@OptionGroup var memory: MemoryArguments
@Option(name: .long, help: "Directory with {train, valid, test}.{jsonl,txt} files")
public var data: URL = URL(filePath: "data")
@Option(name: .long, help: "Minibatch size")
public var batchSize = 4
@MainActor
mutating func run() async throws {
let (model, tokenizer, _) = try await args.load()
args.describe(model: model)
try LoRATrain.loadLoRAWeights(model: model, url: args.adapter)
memory.start()
let test = try loadLoRAData(directory: data, name: "test")
let loss = LoRATrain.evaluate(
model: model, dataset: test, tokenizer: tokenizer, batchSize: batchSize, batchCount: 0)
print("Test loss \(loss.formatted()), ppl \(exp(loss).formatted())")
}
}
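
/// Illustrative invocation (`--prompt` comes from GenerateArguments and, per
/// this change, can also read from a file via `--prompt @/path/to/file`):
///
///     llm-tool lora eval --model <path-or-hub-id> --prompt "<your prompt>"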
struct LoRAEvalCommand: AsyncParsableCommand {
static var configuration = CommandConfiguration(
commandName: "eval",
abstract: "LoRA evaluation"
)
@OptionGroup var args: LoRAModelArguments
@OptionGroup var memory: MemoryArguments
@OptionGroup var generate: GenerateArguments
@MainActor
mutating func run() async throws {
let (model, tokenizer, modelConfiguration) = try await args.load()
args.describe(model: model)
try LoRATrain.loadLoRAWeights(model: model, url: args.adapter)
memory.start()
let (prompt, promptTokens) = try generate.tokenizePrompt(
configuration: modelConfiguration, tokenizer: tokenizer)
if !generate.quiet {
print("Starting generation ...")
print(prompt, terminator: "")
}
// generate and print the result
let _ = await generate.generate(
promptTokens: promptTokens, model: model, tokenizer: tokenizer)
print()
}
}