implement LoRA / QLoRA (#46)

* implement LoRA / QLoRA

- example of using MLX to fine-tune an LLM with low-rank adaptation (LoRA) for a target task; a minimal conceptual sketch follows below
- see also https://arxiv.org/abs/2106.09685
- based on https://github.com/ml-explore/mlx-examples/tree/main/lora
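
A minimal, framework-free sketch of the LoRA idea (not the MLX implementation in this commit): the frozen base weight W is left untouched and only two small low-rank factors A and B are trained, so the adapted layer computes y = W x + (alpha / r) * B (A x). QLoRA applies the same adapters on top of a quantized (e.g. 4-bit) base model. The type name, fields, and shapes below are illustrative only.

    // Conceptual LoRA layer: frozen base weight plus a trainable low-rank update.
    struct LoRALinear {
        let base: [[Double]]   // frozen W, shape dOut x dIn
        var a: [[Double]]      // trainable A, shape r x dIn (small random init)
        var b: [[Double]]      // trainable B, shape dOut x r (zero init, so the update starts at 0)
        let scale: Double      // alpha / r

        private func matVec(_ m: [[Double]], _ v: [Double]) -> [Double] {
            m.map { row in zip(row, v).reduce(0) { $0 + $1.0 * $1.1 } }
        }

        func callAsFunction(_ x: [Double]) -> [Double] {
            let frozen = matVec(base, x)            // W x
            let update = matVec(b, matVec(a, x))    // B (A x)
            return zip(frozen, update).map { $0 + scale * $1 }
        }
    }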

* add some command line flags that I found useful in practice
- --quiet -- don't print decorator text, just the generated text
- --prompt @/tmp/file.txt -- load the prompt from one or more files (see the sketch below)
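
A sketch of how the @-prefixed prompt argument is expanded (it mirrors the resolvePrompt() logic added in the diff below; the helper name here is just for illustration):

    import Foundation

    // "@/tmp/system.txt,@/tmp/question.txt" -> contents of both files joined with newlines;
    // anything not starting with "@" is returned unchanged.
    func expandPrompt(_ prompt: String) throws -> String {
        guard prompt.hasPrefix("@") else { return prompt }
        let paths = prompt.split(separator: ",").map { String($0.dropFirst()) }
        return try paths.map { try String(contentsOfFile: $0) }.joined(separator: "\n")
    }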

* user can specify either a path to a local model directory OR a model identifier on Hugging Face

* update mlx-swift reference

Co-authored-by: Ashraful Islam <ashraful.meche@gmail.com>
Co-authored-by: JustinMeans <46542161+JustinMeans@users.noreply.github.com>
David Koski committed 2024-04-22 09:30:12 -07:00 (committed by GitHub)
parent 7e85eb8b88, commit 6c0b66f90a
32 changed files with 3483 additions and 64 deletions

@@ -11,18 +11,26 @@ import Tokenizers
 struct LLMTool: AsyncParsableCommand {
     static var configuration = CommandConfiguration(
         abstract: "Command line tool for generating text and manipulating LLMs",
-        subcommands: [EvaluateCommand.self],
+        subcommands: [EvaluateCommand.self, LoRACommand.self],
         defaultSubcommand: EvaluateCommand.self)
 }
 
 /// Command line arguments for loading a model.
 struct ModelArguments: ParsableArguments {
 
-    @Option(name: .long, help: "Name of the huggingface model")
+    @Option(name: .long, help: "Name of the huggingface model or absolute path to directory")
     var model: String = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"
 
     func load() async throws -> (LLMModel, Tokenizer, ModelConfiguration) {
-        let modelConfiguration = ModelConfiguration.configuration(id: model)
+        let modelConfiguration: ModelConfiguration
+
+        if self.model.hasPrefix("/") {
+            // path
+            modelConfiguration = ModelConfiguration(directory: URL(filePath: self.model))
+        } else {
+            // identifier
+            modelConfiguration = ModelConfiguration.configuration(id: model)
+        }
+
         let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration)
         return (model, tokenizer, modelConfiguration)
     }
@@ -31,7 +39,11 @@ struct ModelArguments: ParsableArguments {
 /// Command line arguments for controlling generation of text.
 struct GenerateArguments: ParsableArguments {
 
-    @Option(name: .shortAndLong, help: "The message to be processed by the model")
+    @Option(
+        name: .shortAndLong,
+        help:
+            "The message to be processed by the model. Use @path,@path to load from files, e.g. @/tmp/prompt.txt"
+    )
     var prompt = "compare python and swift"
 
     @Option(name: .shortAndLong, help: "Maximum number of tokens to generate")
@@ -52,18 +64,32 @@ struct GenerateArguments: ParsableArguments {
     @Option(name: .long, help: "The PRNG seed")
     var seed: UInt64 = 0
 
+    @Flag(name: .shortAndLong, help: "If true only print the generated output")
+    var quiet = false
+
     var generateParameters: GenerateParameters {
         GenerateParameters(
             temperature: temperature, topP: topP, repetitionPenalty: repetitionPenalty,
             repetitionContextSize: repetitionContextSize)
     }
 
-    func tokenizePrompt(configuration: ModelConfiguration, tokenizer: Tokenizer) -> (String, [Int])
-    {
+    func resolvePrompt() throws -> String {
+        if prompt.hasPrefix("@") {
+            let names = prompt.split(separator: ",").map { String($0.dropFirst()) }
+            return try names.map { try String(contentsOfFile: $0) }.joined(separator: "\n")
+        } else {
+            return prompt
+        }
+    }
+
+    func tokenizePrompt(configuration: ModelConfiguration, tokenizer: Tokenizer) throws -> (
+        String, [Int]
+    ) {
         MLXRandom.seed(seed)
 
-        let prompt = configuration.prepare(prompt: self.prompt)
-        let promptTokens = tokenizer.encode(text: prompt)
+        let prompt = try resolvePrompt()
+        let preparedPrompt = configuration.prepare(prompt: prompt)
+        let promptTokens = tokenizer.encode(text: preparedPrompt)
 
         return (prompt, promptTokens)
     }
@@ -187,21 +213,27 @@ struct EvaluateCommand: AsyncParsableCommand {
     mutating func run() async throws {
         let (model, tokenizer, modelConfiguration) = try await memory.start(args.load)
 
-        print("Model loaded -> \(modelConfiguration.id)")
+        if !generate.quiet {
+            print("Model loaded -> \(modelConfiguration.id)")
+        }
 
-        let (prompt, promptTokens) = generate.tokenizePrompt(
+        let (prompt, promptTokens) = try generate.tokenizePrompt(
             configuration: modelConfiguration, tokenizer: tokenizer)
 
-        print("Starting generation ...")
-        print(prompt, terminator: "")
+        if !generate.quiet {
+            print("Starting generation ...")
+            print(prompt, terminator: "")
+        }
 
         let result = await generate.generate(
             promptTokens: promptTokens, model: model, tokenizer: tokenizer)
         print()
 
-        print("------")
-        print(result.summary())
-
-        memory.reportMemoryStatistics()
+        if !generate.quiet {
+            print("------")
+            print(result.summary())
+
+            memory.reportMemoryStatistics()
+        }
     }
 }