implement LoRA / QLoRA (#46)
* implement LoRA / QLoRA
  - example of using MLX to fine-tune an LLM with low rank adaptation (LoRA) for a target task
  - see also https://arxiv.org/abs/2106.09685
  - based on https://github.com/ml-explore/mlx-examples/tree/main/lora
* add some command line flags I found useful during use
  - --quiet -- don't print decorator text, just the generated text
  - --prompt @/tmp/file.txt -- load prompt from file
* user can specify path to model OR model identifier in huggingface
* update mlx-swift reference

Co-authored-by: Ashraful Islam <ashraful.meche@gmail.com>
Co-authored-by: JustinMeans <46542161+JustinMeans@users.noreply.github.com>
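For context, low rank adaptation (the linked paper, arXiv:2106.09685) freezes the pretrained weight W and learns the update as a product of two small matrices, y = xW + (alpha/r)(xA)B, so only A and B are trained. Below is a minimal, dependency-free Swift sketch of that forward pass; all names are illustrative and none of this code is from the PR. QLoRA applies the same update on top of a quantized base model, which is why a 4-bit checkpoint like the default Mistral model works as a starting point.

```swift
// Minimal sketch of the LoRA update from https://arxiv.org/abs/2106.09685
// (illustrative only -- not code from this PR). The frozen weight W keeps
// its pretrained values; only the small matrices A and B are trained, and
// their contribution is scaled by alpha / r before being added.

func matmul(_ a: [[Double]], _ b: [[Double]]) -> [[Double]] {
    var out = Array(repeating: Array(repeating: 0.0, count: b[0].count), count: a.count)
    for i in 0..<a.count {
        for p in 0..<b.count {
            for j in 0..<b[0].count {
                out[i][j] += a[i][p] * b[p][j]
            }
        }
    }
    return out
}

/// y = x W + (alpha / r) * (x A) B
/// x: (batch, dIn), w: (dIn, dOut), a: (dIn, r), b: (r, dOut)
func loraForward(
    x: [[Double]], w: [[Double]], a: [[Double]], b: [[Double]], alpha: Double
) -> [[Double]] {
    let scale = alpha / Double(b.count)   // b.count == rank r
    let base = matmul(x, w)               // frozen pretrained path
    let update = matmul(matmul(x, a), b)  // trainable low-rank path
    return zip(base, update).map { baseRow, updateRow in
        zip(baseRow, updateRow).map { $0 + scale * $1 }
    }
}
```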
@@ -11,18 +11,26 @@ import Tokenizers
 struct LLMTool: AsyncParsableCommand {
     static var configuration = CommandConfiguration(
         abstract: "Command line tool for generating text and manipulating LLMs",
-        subcommands: [EvaluateCommand.self],
+        subcommands: [EvaluateCommand.self, LoRACommand.self],
         defaultSubcommand: EvaluateCommand.self)
 }
 
 /// Command line arguments for loading a model.
 struct ModelArguments: ParsableArguments {
 
-    @Option(name: .long, help: "Name of the huggingface model")
+    @Option(name: .long, help: "Name of the huggingface model or absolute path to directory")
     var model: String = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"
 
     func load() async throws -> (LLMModel, Tokenizer, ModelConfiguration) {
-        let modelConfiguration = ModelConfiguration.configuration(id: model)
+        let modelConfiguration: ModelConfiguration
+
+        if self.model.hasPrefix("/") {
+            // path
+            modelConfiguration = ModelConfiguration(directory: URL(filePath: self.model))
+        } else {
+            // identifier
+            modelConfiguration = ModelConfiguration.configuration(id: model)
+        }
         let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration)
         return (model, tokenizer, modelConfiguration)
     }
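A note on the load() change above: the argument is treated as a local model directory when it begins with "/", and as a Hugging Face identifier otherwise. The same rule restated as a self-contained sketch (the ModelSource enum is hypothetical, not a type in this repo):

```swift
import Foundation

// Illustrative restatement of the resolution rule in load(); ModelSource
// is a hypothetical name used only for this sketch.
enum ModelSource {
    case directory(URL)        // e.g. --model /tmp/fine-tuned-model
    case huggingFace(String)   // e.g. --model mlx-community/Mistral-7B-v0.1-hf-4bit-mlx
}

func resolveModelArgument(_ model: String) -> ModelSource {
    model.hasPrefix("/")
        ? .directory(URL(filePath: model))
        : .huggingFace(model)
}
```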
@@ -31,7 +39,11 @@ struct ModelArguments: ParsableArguments {
 /// Command line arguments for controlling generation of text.
 struct GenerateArguments: ParsableArguments {
 
-    @Option(name: .shortAndLong, help: "The message to be processed by the model")
+    @Option(
+        name: .shortAndLong,
+        help:
+            "The message to be processed by the model. Use @path,@path to load from files, e.g. @/tmp/prompt.txt"
+    )
     var prompt = "compare python and swift"
 
     @Option(name: .shortAndLong, help: "Maximum number of tokens to generate")
@@ -52,18 +64,32 @@ struct GenerateArguments: ParsableArguments {
     @Option(name: .long, help: "The PRNG seed")
     var seed: UInt64 = 0
 
+    @Flag(name: .shortAndLong, help: "If true only print the generated output")
+    var quiet = false
+
     var generateParameters: GenerateParameters {
         GenerateParameters(
             temperature: temperature, topP: topP, repetitionPenalty: repetitionPenalty,
             repetitionContextSize: repetitionContextSize)
     }
 
-    func tokenizePrompt(configuration: ModelConfiguration, tokenizer: Tokenizer) -> (String, [Int])
-    {
+    func resolvePrompt() throws -> String {
+        if prompt.hasPrefix("@") {
+            let names = prompt.split(separator: ",").map { String($0.dropFirst()) }
+            return try names.map { try String(contentsOfFile: $0) }.joined(separator: "\n")
+        } else {
+            return prompt
+        }
+    }
+
+    func tokenizePrompt(configuration: ModelConfiguration, tokenizer: Tokenizer) throws -> (
+        String, [Int]
+    ) {
         MLXRandom.seed(seed)
 
-        let prompt = configuration.prepare(prompt: self.prompt)
-        let promptTokens = tokenizer.encode(text: prompt)
+        let prompt = try resolvePrompt()
+        let preparedPrompt = configuration.prepare(prompt: prompt)
+        let promptTokens = tokenizer.encode(text: preparedPrompt)
 
         return (prompt, promptTokens)
     }
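The new resolvePrompt() splits the option value on commas and drops the first character of each piece, so every file reference must carry its own leading "@" (e.g. @/tmp/a.txt,@/tmp/b.txt), and file contents are joined with newlines. A standalone copy of the same body for quick experimentation (the free-function wrapper is illustrative):

```swift
import Foundation

// Standalone mirror of GenerateArguments.resolvePrompt() for experimentation.
// dropFirst() runs on every comma-separated piece, so each entry must begin
// with "@", e.g. "@/tmp/system.txt,@/tmp/question.txt".
func resolvePrompt(_ prompt: String) throws -> String {
    if prompt.hasPrefix("@") {
        let names = prompt.split(separator: ",").map { String($0.dropFirst()) }
        return try names.map { try String(contentsOfFile: $0) }.joined(separator: "\n")
    } else {
        return prompt
    }
}
```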
@@ -187,21 +213,27 @@ struct EvaluateCommand: AsyncParsableCommand {
     mutating func run() async throws {
         let (model, tokenizer, modelConfiguration) = try await memory.start(args.load)
 
-        print("Model loaded -> \(modelConfiguration.id)")
+        if !generate.quiet {
+            print("Model loaded -> \(modelConfiguration.id)")
+        }
 
-        let (prompt, promptTokens) = generate.tokenizePrompt(
+        let (prompt, promptTokens) = try generate.tokenizePrompt(
             configuration: modelConfiguration, tokenizer: tokenizer)
 
-        print("Starting generation ...")
-        print(prompt, terminator: "")
+        if !generate.quiet {
+            print("Starting generation ...")
+            print(prompt, terminator: "")
+        }
 
         let result = await generate.generate(
             promptTokens: promptTokens, model: model, tokenizer: tokenizer)
 
         print()
-        print("------")
-        print(result.summary())
-
-        memory.reportMemoryStatistics()
+        if !generate.quiet {
+            print("------")
+            print(result.summary())
+
+            memory.reportMemoryStatistics()
+        }
     }
 }