implement LoRA / QLoRA (#46)

* implement LoRA / QLoRA

- example of using MLX to fine-tune an LLM with low rank adaptation (LoRA) for a target task
- see also https://arxiv.org/abs/2106.09685
- based on https://github.com/ml-explore/mlx-examples/tree/main/lora

* add some command line flags I found useful during use
- --quiet -- don't print decorator text, just the generated text
- --prompt @/tmp/file.txt -- load prompt from file

* user can specify path to model OR model identifier in huggingface

* update mlx-swift reference

Co-authored-by: Ashraful Islam <ashraful.meche@gmail.com>
Co-authored-by: JustinMeans <46542161+JustinMeans@users.noreply.github.com>
This commit is contained in:
David Koski
2024-04-22 09:30:12 -07:00
committed by GitHub
parent 7e85eb8b88
commit 6c0b66f90a
32 changed files with 3483 additions and 64 deletions

View File

@@ -0,0 +1,61 @@
// Copyright © 2024 Apple Inc.
import Foundation
enum LoRADataError: Error {
case fileNotFound(URL, String)
}
/// Load a LoRA data file.
///
/// Given a directory and a base name, e.g. `train`, this will load a `.jsonl` or `.txt` file
/// if possible.
public func loadLoRAData(directory: URL, name: String) throws -> [String] {
let extensions = ["jsonl", "txt"]
for ext in extensions {
let url = directory.appending(component: "\(name).\(ext)")
if FileManager.default.fileExists(atPath: url.path()) {
return try loadLoRAData(url: url)
}
}
throw LoRADataError.fileNotFound(directory, name)
}
/// Load a .txt or .jsonl file and return the contents
public func loadLoRAData(url: URL) throws -> [String] {
switch url.pathExtension {
case "jsonl":
return try loadJSONL(url: url)
case "txt":
return try loadLines(url: url)
default:
fatalError("Unable to load data file, unknown type: \(url)")
}
}
func loadJSONL(url: URL) throws -> [String] {
struct Line: Codable {
let text: String?
}
return try String(contentsOf: url)
.components(separatedBy: .newlines)
.filter {
$0.first == "{"
}
.compactMap {
try JSONDecoder().decode(Line.self, from: $0.data(using: .utf8)!).text
}
}
func loadLines(url: URL) throws -> [String] {
try String(contentsOf: url)
.components(separatedBy: .newlines)
.filter { !$0.isEmpty }
}