implement LoRA / QLoRA (#46)

* implement LoRA / QLoRA

- example of using MLX to fine-tune an LLM with low-rank adaptation (LoRA) for a target task (see the sketch below)
- see also https://arxiv.org/abs/2106.09685
- based on https://github.com/ml-explore/mlx-examples/tree/main/lora
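
A minimal sketch of the technique (illustrative only, not the code from this PR): LoRA keeps the pretrained weights frozen and trains a low-rank update, y = base(x) + (alpha/rank) * B(A(x)); QLoRA applies the same adapters on top of a quantized base layer. The class name, the composition-around-Linear approach, and the rank/alpha defaults below are assumptions.

import MLX
import MLXNN
import MLXRandom

// Illustrative LoRA adapter wrapped around a pretrained Linear layer.
// Forward pass: y = base(x) + (alpha / rank) * (x A^T) B^T. During
// fine-tuning the base model is frozen (e.g. via Module.freeze()) so
// only the small loraA / loraB matrices receive gradients.
class LoRALinear: Module, UnaryLayer {
    let base: Linear     // frozen pretrained layer (quantized in the QLoRA case)
    let scale: Float
    let loraA: MLXArray  // [rank, inputDim], small random init
    let loraB: MLXArray  // [outputDim, rank], zeros so training starts at the base model

    init(base: Linear, rank: Int = 8, alpha: Float = 16) {
        self.base = base
        self.scale = alpha / Float(rank)
        let outputDim = base.weight.dim(0)
        let inputDim = base.weight.dim(1)
        self.loraA = MLXRandom.normal([rank, inputDim]) * (1.0 / Float(inputDim).squareRoot())
        self.loraB = zeros([outputDim, rank])
        super.init()
    }

    func callAsFunction(_ x: MLXArray) -> MLXArray {
        base(x) + matmul(matmul(x, loraA.T), loraB.T) * scale
    }
}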

* add some command line flags I found useful (example invocation below)
- --quiet -- don't print decorator text, just the generated text
- --prompt @/tmp/file.txt -- load the prompt from a file
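
For example, a hypothetical invocation of the repo's llm-tool example (the runner script, tool name, and model id here are illustrative):

./mlx-run llm-tool --model mlx-community/Mistral-7B-v0.1-hf-4bit-mlx \
    --prompt @/tmp/file.txt --quiet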

* user can specify a path to a local model OR a model identifier on Hugging Face
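
A hypothetical sketch of the two forms; the .id / .directory cases come from the diff below, while the ModelConfiguration(id:) initializer and the model id are assumptions:

import Foundation

// a hub identifier: the weights are downloaded on demand
var configuration = ModelConfiguration(id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx")

// or a directory that already contains config.json and *.safetensors
configuration.id = .directory(URL(fileURLWithPath: "/path/to/local/model"))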

* update mlx-swift reference

Co-authored-by: Ashraful Islam <ashraful.meche@gmail.com>
Co-authored-by: JustinMeans <46542161+JustinMeans@users.noreply.github.com>
David Koski, 2024-04-22 09:30:12 -07:00 (committed by GitHub)
commit 6c0b66f90a, parent 7e85eb8b88
32 changed files with 3483 additions and 64 deletions


@@ -17,47 +17,64 @@ public func load(
    hub: HubApi = HubApi(), configuration: ModelConfiguration,
    progressHandler: @escaping (Progress) -> Void = { _ in }
 ) async throws -> (LLMModel, Tokenizer) {
-    // note: this doesn't have a way to pass the HubApi
-    let tokenizer = try await loadTokenizer(configuration: configuration)
-
-    // download the model weights and config
-    let repo = Hub.Repo(id: configuration.id)
-    let modelFiles = ["config.json", "*.safetensors"]
-    let modelDirectory = try await hub.snapshot(
-        from: repo, matching: modelFiles, progressHandler: progressHandler)
-
-    // create the model (no weights loaded)
-    let configurationURL = modelDirectory.appending(component: "config.json")
-    let baseConfig = try JSONDecoder().decode(
-        BaseConfiguration.self, from: Data(contentsOf: configurationURL))
-
-    let model = try baseConfig.modelType.createModel(configuration: configurationURL)
-
-    // load the weights
-    var weights = [String: MLXArray]()
-    let enumerator = FileManager.default.enumerator(
-        at: modelDirectory, includingPropertiesForKeys: nil)!
-    for case let url as URL in enumerator {
-        if url.pathExtension == "safetensors" {
-            let w = try loadArrays(url: url)
-            for (key, value) in w {
-                weights[key] = value
-            }
-        }
-    }
-
-    // quantize if needed
-    if let quantization = baseConfig.quantization {
-        quantizeIfNeeded(model: model, weights: weights, quantization: quantization)
-    }
-
-    // apply the loaded weights
-    let parameters = ModuleParameters.unflattened(weights)
-    try model.update(parameters: parameters, verify: [.all])
-
-    eval(model)
-
-    return (model, tokenizer)
+    do {
+        let tokenizer = try await loadTokenizer(configuration: configuration, hub: hub)
+
+        let modelDirectory: URL
+
+        switch configuration.id {
+        case .id(let id):
+            // download the model weights and config
+            let repo = Hub.Repo(id: id)
+            let modelFiles = ["config.json", "*.safetensors"]
+            modelDirectory = try await hub.snapshot(
+                from: repo, matching: modelFiles, progressHandler: progressHandler)
+
+        case .directory(let directory):
+            modelDirectory = directory
+        }
+
+        // create the model (no weights loaded)
+        let configurationURL = modelDirectory.appending(component: "config.json")
+        let baseConfig = try JSONDecoder().decode(
+            BaseConfiguration.self, from: Data(contentsOf: configurationURL))
+
+        let model = try baseConfig.modelType.createModel(configuration: configurationURL)
+
+        // load the weights
+        var weights = [String: MLXArray]()
+        let enumerator = FileManager.default.enumerator(
+            at: modelDirectory, includingPropertiesForKeys: nil)!
+        for case let url as URL in enumerator {
+            if url.pathExtension == "safetensors" {
+                let w = try loadArrays(url: url)
+                for (key, value) in w {
+                    weights[key] = value
+                }
+            }
+        }
+
+        // quantize if needed
+        if let quantization = baseConfig.quantization {
+            quantizeIfNeeded(model: model, weights: weights, quantization: quantization)
+        }
+
+        // apply the loaded weights
+        let parameters = ModuleParameters.unflattened(weights)
+        try model.update(parameters: parameters, verify: [.all])
+
+        eval(model)
+
+        return (model, tokenizer)
+
+    } catch Hub.HubClientError.authorizationRequired {
+        // an authorizationRequired error typically means the named repo doesn't
+        // exist on the server, so retry with a local-only configuration
+        var newConfiguration = configuration
+        newConfiguration.id = .directory(configuration.modelDirectory(hub: hub))
+        return try await load(
+            hub: hub, configuration: newConfiguration, progressHandler: progressHandler)
+    }
 }
// MARK: - Quantization
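
A hypothetical call site for the load(...) function above; ModelConfiguration(id:) and the model id are assumptions, and the trailing closure is the progressHandler parameter from the signature:

// inside an async throws context
let configuration = ModelConfiguration(id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx")
let (model, tokenizer) = try await load(configuration: configuration) { progress in
    print("download \(Int(progress.fractionCompleted * 100))% complete")
}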