implement LoRA / QLoRA (#46)
* implement LoRA / QLoRA
  - example of using MLX to fine-tune an LLM with low rank adaptation (LoRA) for a target task
  - see also https://arxiv.org/abs/2106.09685
  - based on https://github.com/ml-explore/mlx-examples/tree/main/lora
* add some command line flags I found useful during use
  - `--quiet` -- don't print decorator text, just the generated text
  - `--prompt @/tmp/file.txt` -- load prompt from file
* user can specify path to model OR model identifier in huggingface
* update mlx-swift reference

Co-authored-by: Ashraful Islam <ashraful.meche@gmail.com>
Co-authored-by: JustinMeans <46542161+JustinMeans@users.noreply.github.com>
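For illustration, a typical invocation exercising the new flags might look like this (the prompt file path is only an example):

```
# example only: load the prompt from a file and print just the generated text
./mlx-run llm-tool eval \
    --model mlx-community/Mistral-7B-v0.1-hf-4bit-mlx \
    --quiet \
    --prompt @/tmp/prompt.txt
```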
Tools/llm-tool/Arguments.swift (new file, 15 lines)
@@ -0,0 +1,15 @@
// Copyright © 2024 Apple Inc.

import ArgumentParser
import Foundation

/// Extension to allow URL command line arguments.
extension URL: ExpressibleByArgument {
    public init?(argument: String) {
        if argument.contains("://") {
            self.init(string: argument)
        } else {
            self.init(filePath: argument)
        }
    }
}
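With this conformance in place, URL-typed options elsewhere in the tool can be declared directly from command line strings; for example, the LoRA commands below use it for the adapter path (declaration reproduced from LoraCommands.swift):

```swift
@Option(name: .long, help: "Save/load path for the trained adapter weights")
public var adapter: URL = URL(filePath: "adapters.safetensors")
```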
@@ -11,18 +11,26 @@ import Tokenizers
 struct LLMTool: AsyncParsableCommand {
     static var configuration = CommandConfiguration(
         abstract: "Command line tool for generating text and manipulating LLMs",
-        subcommands: [EvaluateCommand.self],
+        subcommands: [EvaluateCommand.self, LoRACommand.self],
         defaultSubcommand: EvaluateCommand.self)
 }

 /// Command line arguments for loading a model.
 struct ModelArguments: ParsableArguments {

-    @Option(name: .long, help: "Name of the huggingface model")
+    @Option(name: .long, help: "Name of the huggingface model or absolute path to directory")
     var model: String = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"

     func load() async throws -> (LLMModel, Tokenizer, ModelConfiguration) {
-        let modelConfiguration = ModelConfiguration.configuration(id: model)
+        let modelConfiguration: ModelConfiguration
+
+        if self.model.hasPrefix("/") {
+            // path
+            modelConfiguration = ModelConfiguration(directory: URL(filePath: self.model))
+        } else {
+            // identifier
+            modelConfiguration = ModelConfiguration.configuration(id: model)
+        }
         let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration)
         return (model, tokenizer, modelConfiguration)
     }
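With this change `--model` accepts either a Hugging Face identifier or an absolute path to a local model directory; for example (the local directory shown here is hypothetical):

```
# hypothetical local model directory, e.g. one produced by `lora fuse`
./mlx-run llm-tool eval --model /path/to/Mistral-7B-v0.1-hf-4bit-mlx
```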
@@ -31,7 +39,11 @@ struct ModelArguments: ParsableArguments {
 /// Command line arguments for controlling generation of text.
 struct GenerateArguments: ParsableArguments {

-    @Option(name: .shortAndLong, help: "The message to be processed by the model")
+    @Option(
+        name: .shortAndLong,
+        help:
+            "The message to be processed by the model. Use @path,@path to load from files, e.g. @/tmp/prompt.txt"
+    )
     var prompt = "compare python and swift"

     @Option(name: .shortAndLong, help: "Maximum number of tokens to generate")
@@ -52,18 +64,32 @@ struct GenerateArguments: ParsableArguments {
     @Option(name: .long, help: "The PRNG seed")
     var seed: UInt64 = 0

+    @Flag(name: .shortAndLong, help: "If true only print the generated output")
+    var quiet = false
+
     var generateParameters: GenerateParameters {
         GenerateParameters(
             temperature: temperature, topP: topP, repetitionPenalty: repetitionPenalty,
             repetitionContextSize: repetitionContextSize)
     }

-    func tokenizePrompt(configuration: ModelConfiguration, tokenizer: Tokenizer) -> (String, [Int])
-    {
+    func resolvePrompt() throws -> String {
+        if prompt.hasPrefix("@") {
+            let names = prompt.split(separator: ",").map { String($0.dropFirst()) }
+            return try names.map { try String(contentsOfFile: $0) }.joined(separator: "\n")
+        } else {
+            return prompt
+        }
+    }
+
+    func tokenizePrompt(configuration: ModelConfiguration, tokenizer: Tokenizer) throws -> (
+        String, [Int]
+    ) {
         MLXRandom.seed(seed)

-        let prompt = configuration.prepare(prompt: self.prompt)
-        let promptTokens = tokenizer.encode(text: prompt)
+        let prompt = try resolvePrompt()
+        let preparedPrompt = configuration.prepare(prompt: prompt)
+        let promptTokens = tokenizer.encode(text: preparedPrompt)

         return (prompt, promptTokens)
     }
@@ -187,21 +213,27 @@ struct EvaluateCommand: AsyncParsableCommand {
     mutating func run() async throws {
         let (model, tokenizer, modelConfiguration) = try await memory.start(args.load)

-        print("Model loaded -> \(modelConfiguration.id)")
+        if !generate.quiet {
+            print("Model loaded -> \(modelConfiguration.id)")
+        }

-        let (prompt, promptTokens) = generate.tokenizePrompt(
+        let (prompt, promptTokens) = try generate.tokenizePrompt(
             configuration: modelConfiguration, tokenizer: tokenizer)

-        print("Starting generation ...")
-        print(prompt, terminator: "")
+        if !generate.quiet {
+            print("Starting generation ...")
+            print(prompt, terminator: "")
+        }

         let result = await generate.generate(
             promptTokens: promptTokens, model: model, tokenizer: tokenizer)

         print()
-        print("------")
-        print(result.summary())
-
-        memory.reportMemoryStatistics()
+        if !generate.quiet {
+            print("------")
+            print(result.summary())
+
+            memory.reportMemoryStatistics()
+        }
     }
 }
Tools/llm-tool/LoraCommands.swift (new file, 281 lines)
@@ -0,0 +1,281 @@
// Copyright © 2024 Apple Inc.

import ArgumentParser
import Foundation
import Hub
import LLM
import MLX
import MLXNN
import MLXOptimizers
import MLXRandom
import Tokenizers

struct LoRACommand: AsyncParsableCommand {

    static var configuration = CommandConfiguration(
        commandName: "lora",
        abstract: "LoRA commands",
        subcommands: [
            LoRATrainCommand.self, LoRAFuseCommand.self, LoRATestCommand.self, LoRAEvalCommand.self,
        ]
    )
}

/// Common arguments for loading a LoRA model with adapter weights
struct LoRAModelArguments: ParsableArguments {

    @OptionGroup var args: ModelArguments

    @Option(name: .long, help: "Save/load path for the trained adapter weights")
    public var adapter: URL = URL(filePath: "adapters.safetensors")

    @Option(name: .long, help: "Number of layers to fine-tune")
    public var loraLayers = 16

    /// Load the model and apply the LoRA adapters.
    ///
    /// This does not load the adapter weights as they may not exist yet.
    func load() async throws -> (LLMModel, Tokenizer, ModelConfiguration) {
        let (model, tokenizer, modelConfiguration) = try await args.load()

        // convert some of the Linear layers to LoRALinear
        LoRATrain.convert(model: model, layers: loraLayers(model: model))

        return (model, tokenizer, modelConfiguration)
    }

    func loraLayers(model: Module) -> LoRALinearLayers {
        guard let layerProvider = model as? LoRAModel else {
            // the layerProvider will indicate which Linear layers need to be replaced
            fatalError(
                "Model \(type(of: model)) (\(args.model)) must implement the LoRALayerProvider protocol"
            )
        }

        return Array(layerProvider.loraLinearLayers().suffix(loraLayers))
    }

    func describe(model: Module) {
        let totalParameterCount = model.parameters()
            .flattenedValues().map { $0.size }.reduce(0, +)
        let trainableParameterCount = model.trainableParameters()
            .flattenedValues().map { $0.size }.reduce(0, +)

        print("Model: \(args.model)")
        print("Total parameters: \((totalParameterCount / 1_000_000).formatted())M")
        print(
            "Trainable parameters: \((Float(trainableParameterCount) / 1_000_000).formatted(.number.precision(.significantDigits(1 ..< 4))))M"
        )
    }
}

struct LoRATrainCommand: AsyncParsableCommand {

    static var configuration = CommandConfiguration(
        commandName: "train",
        abstract: "LoRA training"
    )

    @OptionGroup var args: LoRAModelArguments
    @OptionGroup var memory: MemoryArguments

    @Flag(help: "Resume training with the given adapter file")
    public var resume = false

    @Option(name: .long, help: "Directory with {train, valid, test}.{jsonl,txt} files")
    public var data: URL = URL(filePath: "data")

    @Option(name: .long, help: "Learning rate for the optimizer")
    public var learningRate: Float = 1e-5

    @Option(name: .long, help: "Number of dataset items to evaluate per iteration (batch)")
    public var batchSize = 4

    @Option(name: .long, help: "Number of iterations to train for")
    public var iterations = 1000

    @Option(name: .long, help: "Number of iterations between loss reporting")
    public var stepsPerReport = 10

    @Option(name: .long, help: "Number of iterations between validations")
    public var stepsPerEval = 100

    @Option(name: .long, help: "Number of validation batches, 0 uses the entire set")
    public var validationBatches = 10

    @Option(name: .long, help: "Number of iterations between checkpointing the adapter weights")
    public var saveEvery = 100

    var parameters: LoRATrain.Parameters {
        var p = LoRATrain.Parameters()
        p.batchSize = self.batchSize
        p.iterations = self.iterations
        p.stepsPerReport = self.stepsPerReport
        p.stepsPerEval = self.stepsPerEval
        p.validationBatches = self.validationBatches
        p.saveEvery = self.saveEvery
        p.adapterURL = args.adapter
        return p
    }

    @MainActor
    mutating func run() async throws {
        let (model, tokenizer, _) = try await args.load()
        args.describe(model: model)

        memory.start()

        if resume {
            print("Loading pretrained adapters from \(args.adapter.path())")
            try LoRATrain.loadLoRAWeights(model: model, url: args.adapter)
        }

        // load the train/validation data
        let train = try loadLoRAData(directory: data, name: "train")
        let valid = try loadLoRAData(directory: data, name: "valid")

        if train.isEmpty {
            fatalError("Training set is empty: \(data.path()))")
        }
        if valid.isEmpty {
            fatalError("Validation set is empty: \(data.path()))")
        }

        // train
        let optimizer = Adam(learningRate: learningRate)
        try await LoRATrain.train(
            model: model, train: train, validate: valid, optimizer: optimizer, tokenizer: tokenizer,
            parameters: parameters
        ) { progress in
            print(progress)
            return .more
        }
        try LoRATrain.saveLoRAWeights(model: model, url: args.adapter)
    }
}

struct LoRAFuseCommand: AsyncParsableCommand {

    static var configuration = CommandConfiguration(
        commandName: "fuse",
        abstract: "Fuse lora adapter weights back in to original model"
    )

    @OptionGroup var args: LoRAModelArguments

    @Flag(name: .long, help: "De-quantize QuantizedLinear layers back into Linear")
    var deQuantize = false

    @Option(name: .long, help: "Hub ID (mlx-community/mistral-lora) or path (/tmp/mistral-lora)")
    var output: String

    @MainActor
    mutating func run() async throws {
        let outputURL: URL
        if output.hasPrefix("/") {
            outputURL = URL(filePath: output)
        } else {
            let repo = HubApi.Repo(id: output)
            outputURL = HubApi().localRepoLocation(repo)
        }

        let (model, _, modelConfiguration) = try await args.load()

        // load the prepared weights
        try LoRATrain.loadLoRAWeights(model: model, url: args.adapter)

        // fuse them back into Linear/QuantizedLinear
        LoRATrain.fuse(model: model, layers: args.loraLayers(model: model), deQuantize: deQuantize)

        // make the new directory and copy files from source model
        try FileManager.default.createDirectory(at: outputURL, withIntermediateDirectories: true)
        let inputURL = modelConfiguration.modelDirectory()
        let enumerator = FileManager.default.enumerator(
            at: inputURL, includingPropertiesForKeys: nil)!
        for case let url as URL in enumerator {
            // copy everything except the model weights -- we will write out the fused one below
            if url.pathExtension == "safetensors" {
                continue
            }

            try FileManager.default.copyItem(
                at: url, to: outputURL.appending(component: url.lastPathComponent))
        }

        // write them back out
        let weights = Dictionary(uniqueKeysWithValues: model.parameters().flattened())
        try save(arrays: weights, url: outputURL.appending(component: "weights.safetensors"))

        print("Fused weights written to \(outputURL.path())")
        print("Use with:\n\tllm-tool eval --model \(output)")
    }

}

struct LoRATestCommand: AsyncParsableCommand {

    static var configuration = CommandConfiguration(
        commandName: "test",
        abstract: "LoRA testing"
    )

    @OptionGroup var args: LoRAModelArguments
    @OptionGroup var memory: MemoryArguments

    @Option(name: .long, help: "Directory with {train, valid, test}.{jsonl,txt} files")
    public var data: URL = URL(filePath: "data")

    @Option(name: .long, help: "Minibatch size")
    public var batchSize = 4

    @MainActor
    mutating func run() async throws {
        let (model, tokenizer, _) = try await args.load()
        args.describe(model: model)
        try LoRATrain.loadLoRAWeights(model: model, url: args.adapter)

        memory.start()

        let test = try loadLoRAData(directory: data, name: "test")
        let loss = LoRATrain.evaluate(
            model: model, dataset: test, tokenizer: tokenizer, batchSize: batchSize, batchCount: 0)

        print("Test loss \(loss.formatted()), ppl \(exp(loss).formatted())")
    }

}

struct LoRAEvalCommand: AsyncParsableCommand {

    static var configuration = CommandConfiguration(
        commandName: "eval",
        abstract: "LoRA evaluation"
    )

    @OptionGroup var args: LoRAModelArguments
    @OptionGroup var memory: MemoryArguments
    @OptionGroup var generate: GenerateArguments

    @MainActor
    mutating func run() async throws {
        let (model, tokenizer, modelConfiguration) = try await args.load()
        args.describe(model: model)
        try LoRATrain.loadLoRAWeights(model: model, url: args.adapter)

        memory.start()

        let (prompt, promptTokens) = try generate.tokenizePrompt(
            configuration: modelConfiguration, tokenizer: tokenizer)

        if !generate.quiet {
            print("Starting generation ...")
            print(prompt, terminator: "")
        }

        // generate and print the result
        let _ = await generate.generate(
            promptTokens: promptTokens, model: model, tokenizer: tokenizer)
        print()
    }
}
@@ -8,7 +8,7 @@ See various READMEs:

 Build the `llm-tool` scheme in Xcode.

-### Running (Xcode)
+### Running: Xcode

 To run this in Xcode simply press cmd-opt-r to set the scheme arguments. For example:

@@ -30,7 +30,7 @@ The model should be a path in the Hugging Face repository, e.g.:

 See [LLM](../../Libraries/LLM/README.md) for more info.

-### Running (Command Line)
+### Running: Command Line

 Use the `mlx-run` script to run the command line tools:

@@ -60,3 +60,184 @@ Building in Release / optimizations will remove a lot of tail calls in the C++
layer. These lead to the stack overflows.

See discussion here: https://github.com/ml-explore/mlx-swift-examples/issues/3

## LoRA

`llm-tool` provides an example LoRA driver based on:

- https://github.com/ml-explore/mlx-examples/blob/main/lora/README.md

This is an example of using MLX to fine-tune an LLM with low rank adaptation
(LoRA) for a target task.[^lora] The example also supports quantized LoRA
(QLoRA).[^qlora] The example works with Llama and Mistral style models
available on Hugging Face.

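As a reminder of the underlying idea (see the LoRA paper[^lora]), the adapters leave the original weight matrices frozen and learn a small low rank update, so an adapted layer computes roughly:

```
h = W x + (alpha / r) * B A x    # W frozen; A: r x k, B: d x r, with r << min(d, k)
```

Only `A` and `B` are trained, which is why the trainable parameter count reported in the output below is a tiny fraction of the total.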
In this example we'll use the WikiSQL[^wikisql] dataset to train the LLM to
generate SQL queries from natural language. However, the example is intended to
be general, should you wish to use a custom dataset.

> Note: Some of the prompts have newlines in them, which is difficult to achieve when running in Xcode.

Running `llm-tool lora` will produce help:

```
SUBCOMMANDS:
  train                   LoRA training
  fuse                    Fuse lora adapter weights back in to original model
  test                    LoRA testing
  eval                    LoRA evaluation
```

### Training

The first step will be training the LoRA adapter. Example training data
is available in `$SRCROOT/Data/lora`. You can use your
own data in either `jsonl` or `txt` format with one entry per line.

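Each line of a `jsonl` file is a JSON object whose `text` field holds one complete training example; the WikiSQL-style data used here looks like this (the same format appears again in the Evaluate section below):

```
{"text": "table: 1-10015132-1\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: What school did player number 6 come from?\nA: SELECT School/Club Team FROM 1-10015132-1 WHERE No. = '6'"}
```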
We need to specify a number of parameters:

- `--model` -- which model to use. This can be quantized [^qlora] or not [^lora]
- `--data` -- directory with the test, train and valid files. These can be either `jsonl` or `txt` files
- `--adapter` -- path to a safetensors file to write the fine-tuned parameters into

Additionally, the performance of the fine-tuning can be controlled with:

- `--batch-size` -- size of the minibatches to run in the training loop, e.g. how many prompts to process per iteration
- `--lora-layers` -- the number of layers in the Attention section of the model to adapt and train
- `--iterations` -- the number of iterations to train for

If desired, the amount of memory used can be adjusted with:

- `--cache-size` -- the number shown below limits the cache size to 1024M

Here is an example run using adapters on the last 4 layers of the model:

```
./mlx-run llm-tool lora train \
    --model mlx-community/Mistral-7B-v0.1-hf-4bit-mlx \
    --data Data/lora \
    --adapter /tmp/lora-layers-4.safetensors \
    --batch-size 1 --lora-layers 4 \
    --cache-size 1024
```

giving output like this:

```
Model: mlx-community/Mistral-7B-v0.1-hf-4bit-mlx
Total parameters: 1,242M
Trainable parameters: 0.426M
Iteration 1: validation loss 2.443872, validation time 3.330629s
Iteration 10: training loss 2.356848, iterations/sec 2.640604, Tokens/sec 260.363581
Iteration 20: training loss 2.063395, iterations/sec 2.294999, Tokens/sec 232.483365
Iteration 30: training loss 1.63846, iterations/sec 2.279401, Tokens/sec 225.204788
Iteration 40: training loss 1.66366, iterations/sec 2.493669, Tokens/sec 218.196057
Iteration 50: training loss 1.470927, iterations/sec 2.301153, Tokens/sec 231.72614
Iteration 60: training loss 1.396581, iterations/sec 2.400012, Tokens/sec 230.401195
Iteration 70: training loss 1.587023, iterations/sec 2.422193, Tokens/sec 218.966258
Iteration 80: training loss 1.376895, iterations/sec 2.111973, Tokens/sec 216.477187
Iteration 90: training loss 1.245127, iterations/sec 2.383802, Tokens/sec 214.065436
Iteration 100: training loss 1.344523, iterations/sec 2.424746, Tokens/sec 223.076649
Iteration 100: validation loss 1.400582, validation time 3.489797s
Iteration 100: saved weights to /tmp/lora.safetensors
...
Iteration 910: training loss 1.181306, iterations/sec 2.355085, Tokens/sec 212.428628
Iteration 920: training loss 1.042286, iterations/sec 2.374377, Tokens/sec 222.479127
Iteration 930: training loss 0.920768, iterations/sec 2.475088, Tokens/sec 220.035347
Iteration 940: training loss 1.140762, iterations/sec 2.119886, Tokens/sec 227.039828
Iteration 950: training loss 1.068073, iterations/sec 2.523047, Tokens/sec 218.495903
Iteration 960: training loss 1.106662, iterations/sec 2.339293, Tokens/sec 221.063186
Iteration 970: training loss 0.833658, iterations/sec 2.474683, Tokens/sec 213.56517
Iteration 980: training loss 0.844026, iterations/sec 2.441064, Tokens/sec 210.663791
Iteration 990: training loss 0.903735, iterations/sec 2.253876, Tokens/sec 218.175162
Iteration 1000: training loss 0.872615, iterations/sec 2.343899, Tokens/sec 219.62336
Iteration 1000: validation loss 0.714194, validation time 3.470462s
Iteration 1000: saved weights to /tmp/lora-layers-4.safetensors
```

### Testing

You can test the LoRA adapted model against the `test` dataset using this command:

```
./mlx-run llm-tool lora test \
    --model mlx-community/Mistral-7B-v0.1-hf-4bit-mlx \
    --data Data/lora \
    --adapter /tmp/lora-layers-4.safetensors \
    --batch-size 1 --lora-layers 4 \
    --cache-size 1024
```

This will run all the items (100 in the example data we are using) in the test set and compute the loss:

```
Model: mlx-community/Mistral-7B-v0.1-hf-4bit-mlx
Total parameters: 1,242M
Trainable parameters: 0.426M
Test loss 1.327623, ppl 3.772065
```

### Evaluate

Next you can evaluate your own prompts with the fine-tuned LoRA adapters. It is important to
follow the prompt example from the training data to match the format:

```
{"text": "table: 1-10015132-1\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: What school did player number 6 come from?\nA: SELECT School/Club Team FROM 1-10015132-1 WHERE No. = '6'"}
```

Given that format you might issue a command like this:

```
./mlx-run llm-tool lora eval \
    --model mlx-community/Mistral-7B-v0.1-hf-4bit-mlx \
    --adapter /tmp/lora-layers-4.safetensors \
    --lora-layers 4 \
    --prompt "table: 1-10015132-16
columns: Player, No., Nationality, Position, Years in Toronto, School/Club Team
Q: What is terrence ross' nationality
A: "
```

> Note: the prompt has newlines in it to match the format of the fine-tuned prompts -- this may be easier to do from the command line than from Xcode.

You might be treated to a response like this:

```
Model: mlx-community/Mistral-7B-v0.1-hf-4bit-mlx
Total parameters: 1,242M
Trainable parameters: 0.426M
Starting generation ...
table: 1-10015132-16
columns: Player, No., Nationality, Position, Years in Toronto, School/Club Team
Q: What is terrence ross' nationality
A: SELECT Nationality FROM 1-10015132-16 WHERE Player = 'Terrence Ross' AND No. = 1
```

### Fusing

Once the adapter weights are trained you can produce new weights with the original architecture that
have the adapter weights merged in:

```
./mlx-run llm-tool lora fuse \
    --model mlx-community/Mistral-7B-v0.1-hf-4bit-mlx \
    --adapter /tmp/lora-layers-4.safetensors \
    --output mlx-community/mistral-lora
```

which outputs:

```
Total parameters: 1,244M
Trainable parameters: 0.426M
Use with:
	llm-tool eval --model mlx-community/mistral-lora
```

As noted in the output, these new weights can be used with the original model architecture.

[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
[^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314).
[^wikisql]: Refer to the [GitHub repo](https://github.com/salesforce/WikiSQL/tree/master) for more information about WikiSQL.