implement LoRA / QLoRA (#46)

* implement LoRA / QLoRA
  - example of using MLX to fine-tune an LLM with low rank adaptation (LoRA) for a target task
  - see also https://arxiv.org/abs/2106.09685
  - based on https://github.com/ml-explore/mlx-examples/tree/main/lora
* add some command line flags I found useful during use
  - --quiet -- don't print decorator text, just the generated text
  - --prompt @/tmp/file.txt -- load prompt from file
* user can specify path to model OR model identifier in huggingface
* update mlx-swift reference

Co-authored-by: Ashraful Islam <ashraful.meche@gmail.com>
Co-authored-by: JustinMeans <46542161+JustinMeans@users.noreply.github.com>
@@ -236,3 +236,11 @@ public struct CohereConfiguration: Codable {
            Float.self, forKey: CohereConfiguration.CodingKeys.logitScale)
    }
}

// MARK: - LoRA

extension CohereModel: LoRAModel {
    public func loraLinearLayers() -> LoRALinearLayers {
        model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
    }
}

@@ -254,3 +254,11 @@ public struct GemmaConfiguration: Codable {
            Bool.self, forKey: CodingKeys.ropeTraditional) ?? false
    }
}

// MARK: - LoRA

extension GemmaModel: LoRAModel {
    public func loraLinearLayers() -> LoRALinearLayers {
        model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
    }
}

@@ -253,3 +253,11 @@ public struct LlamaConfiguration: Codable {

    }
}

// MARK: - LoRA

extension LlamaModel: LoRAModel {
    public func loraLinearLayers() -> LoRALinearLayers {
        model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
    }
}

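These per-model extensions only name the attention projections that can carry adapters; how many layers actually get converted is up to the caller. A minimal sketch (not part of this diff) using the LoRATrain API added later in this commit, assuming `model` is an already loaded model that provides `loraLinearLayers()`:

```swift
// Sketch: adapt only the q_proj / v_proj layers of the last 4 transformer
// blocks exposed by loraLinearLayers(); everything else stays frozen.
let loraLayers = Array(model.loraLinearLayers().suffix(4))
LoRATrain.convert(model: model, layers: loraLayers)
```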
@@ -17,47 +17,64 @@ public func load(
hub: HubApi = HubApi(), configuration: ModelConfiguration,
progressHandler: @escaping (Progress) -> Void = { _ in }
) async throws -> (LLMModel, Tokenizer) {
// note: this doesn't have a way to pass the HubApi
let tokenizer = try await loadTokenizer(configuration: configuration)
do {
let tokenizer = try await loadTokenizer(configuration: configuration, hub: hub)

// download the model weights and config
let repo = Hub.Repo(id: configuration.id)
let modelFiles = ["config.json", "*.safetensors"]
let modelDirectory = try await hub.snapshot(
from: repo, matching: modelFiles, progressHandler: progressHandler)
let modelDirectory: URL

// create the model (no weights loaded)
let configurationURL = modelDirectory.appending(component: "config.json")
let baseConfig = try JSONDecoder().decode(
BaseConfiguration.self, from: Data(contentsOf: configurationURL))
switch configuration.id {
case .id(let id):
// download the model weights and config
let repo = Hub.Repo(id: id)
let modelFiles = ["config.json", "*.safetensors"]
modelDirectory = try await hub.snapshot(
from: repo, matching: modelFiles, progressHandler: progressHandler)

let model = try baseConfig.modelType.createModel(configuration: configurationURL)
case .directory(let directory):
modelDirectory = directory
}

// load the weights
var weights = [String: MLXArray]()
let enumerator = FileManager.default.enumerator(
at: modelDirectory, includingPropertiesForKeys: nil)!
for case let url as URL in enumerator {
if url.pathExtension == "safetensors" {
let w = try loadArrays(url: url)
for (key, value) in w {
weights[key] = value
// create the model (no weights loaded)
let configurationURL = modelDirectory.appending(component: "config.json")
let baseConfig = try JSONDecoder().decode(
BaseConfiguration.self, from: Data(contentsOf: configurationURL))

let model = try baseConfig.modelType.createModel(configuration: configurationURL)

// load the weights
var weights = [String: MLXArray]()
let enumerator = FileManager.default.enumerator(
at: modelDirectory, includingPropertiesForKeys: nil)!
for case let url as URL in enumerator {
if url.pathExtension == "safetensors" {
let w = try loadArrays(url: url)
for (key, value) in w {
weights[key] = value
}
}
}

// quantize if needed
if let quantization = baseConfig.quantization {
quantizeIfNeeded(model: model, weights: weights, quantization: quantization)
}

// apply the loaded weights
let parameters = ModuleParameters.unflattened(weights)
try model.update(parameters: parameters, verify: [.all])

eval(model)

return (model, tokenizer)

} catch Hub.HubClientError.authorizationRequired {
// an authorizationRequired means (typically) that the named repo doesn't exist on
// on the server so retry with local only configuration
var newConfiguration = configuration
newConfiguration.id = .directory(configuration.modelDirectory(hub: hub))
return try await load(
hub: hub, configuration: newConfiguration, progressHandler: progressHandler)
}

// quantize if needed
if let quantization = baseConfig.quantization {
quantizeIfNeeded(model: model, weights: weights, quantization: quantization)
}

// apply the loaded weights
let parameters = ModuleParameters.unflattened(weights)
try model.update(parameters: parameters, verify: [.all])

eval(model)

return (model, tokenizer)
}

// MARK: - Quantization

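The reworked `load()` above now resolves either a Hugging Face repo id or a local directory to a folder containing `config.json` and `*.safetensors`, and falls back to the local cache when the repo cannot be fetched. A hypothetical call site is sketched below; the repo id and path are placeholders, and the LLM library module is assumed to be imported:

```swift
import Foundation

// Sketch: both configuration forms feed the same load() entry point.
func loadExample(useLocal: Bool) async throws -> (LLMModel, Tokenizer) {
    let configuration: ModelConfiguration
    if useLocal {
        // a directory that already contains config.json and *.safetensors
        configuration = ModelConfiguration(directory: URL(filePath: "/path/to/model"))
    } else {
        // a Hugging Face repo id, downloaded and cached via HubApi
        configuration = ModelConfiguration(id: "some-org/some-model")
    }
    return try await load(configuration: configuration) { progress in
        print("download progress: \(progress.fractionCompleted)")
    }
}
```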
Libraries/LLM/Lora+Data.swift (new file, 61 lines)
@@ -0,0 +1,61 @@
// Copyright © 2024 Apple Inc.

import Foundation

enum LoRADataError: Error {
    case fileNotFound(URL, String)
}

/// Load a LoRA data file.
///
/// Given a directory and a base name, e.g. `train`, this will load a `.jsonl` or `.txt` file
/// if possible.
public func loadLoRAData(directory: URL, name: String) throws -> [String] {
    let extensions = ["jsonl", "txt"]

    for ext in extensions {
        let url = directory.appending(component: "\(name).\(ext)")
        if FileManager.default.fileExists(atPath: url.path()) {
            return try loadLoRAData(url: url)
        }
    }

    throw LoRADataError.fileNotFound(directory, name)
}

/// Load a .txt or .jsonl file and return the contents
public func loadLoRAData(url: URL) throws -> [String] {
    switch url.pathExtension {
    case "jsonl":
        return try loadJSONL(url: url)

    case "txt":
        return try loadLines(url: url)

    default:
        fatalError("Unable to load data file, unknown type: \(url)")

    }
}

func loadJSONL(url: URL) throws -> [String] {

    struct Line: Codable {
        let text: String?
    }

    return try String(contentsOf: url)
        .components(separatedBy: .newlines)
        .filter {
            $0.first == "{"
        }
        .compactMap {
            try JSONDecoder().decode(Line.self, from: $0.data(using: .utf8)!).text
        }
}

func loadLines(url: URL) throws -> [String] {
    try String(contentsOf: url)
        .components(separatedBy: .newlines)
        .filter { !$0.isEmpty }
}
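As the `Line` struct above implies, `loadLoRAData(directory:name:)` expects either `<name>.jsonl` (one JSON object per line with a `text` field) or `<name>.txt` (one example per line). A hypothetical caller, with a placeholder data directory:

```swift
import Foundation

// Sketch: load the train/valid splits the way the LoRA driver would.
let dataDirectory = URL(filePath: "/path/to/lora-data")
let train = try loadLoRAData(directory: dataDirectory, name: "train")
let valid = try loadLoRAData(directory: dataDirectory, name: "valid")
print("loaded \(train.count) training and \(valid.count) validation examples")
```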
Libraries/LLM/Lora.swift (new file, 639 lines)
@@ -0,0 +1,639 @@
// Copyright © 2024 Apple Inc.

import Foundation
import MLX
import MLXNN
import MLXOptimizers
import MLXRandom
import Tokenizers

/// Layers to apply LoRA adapters to.
///
/// This is the value returned by ``LoRAModel/loraLinearLayers()``.
public typealias LoRALinearLayers = [(Module, [String])]

public protocol LoRAModel {
    /// Return the layers and keys to apply LoRA adapters to.
    ///
    /// For example this might apply the adapters to the `q` and `v` projections in the
    /// Attention layers:
    ///
    /// ```swift
    /// model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
    /// ```
    ///
    /// It is not required that a model implement this protocol to have LoRA adapters applied, but
    /// the command line driver example uses this to produce the ``LoRALinearLayers``.
    ///
    /// ### See Also
    /// - ``LoRATrain/convert(model:layers:)``
    func loraLinearLayers() -> LoRALinearLayers
}

/// Protocol for LoRA implementations that provides a method for converting back to a `Linear`
/// (or subtype).
///
/// This is normally called via ``LoRATrain/fuse(model:layers:deQuantize:)``
public protocol LoRAConvertToLinear {
    func toLinear(deQuantize: Bool) -> Linear
}

/// Implementation of LoRA `Linear` replacement layer.
///
/// This layer implements the LoRA capabilities for `Linear` layers, specifically:
///
/// - converting `Linear` or `QuantizedLinear` layers to ``LoRALinear`` / ``QLoRALinear``
/// - converting ``LoRALinear`` back to `Linear` or `QuantizedLinear` (``LoRAConvertToLinear``)
/// - implementing the LoRA evaluation
///
/// ``QLoRALinear`` is the equivalent class for `QuantizedLinear`.
///
/// This is not typically used directly -- ``LoRATrain/convert(model:layers:)`` is used to
/// add the adapter layers to a given model.
///
/// ### See Also
/// - [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685)
/// - [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)
/// - ``QLoRALinear``
/// - ``LoRATrain/convert(model:layers:)``
/// - ``LoRATrain/fuse(model:layers:deQuantize:)``
public class LoRALinear: Linear, LoRAConvertToLinear {

    let scale: Float

    @ParameterInfo(key: "lora_a") var loraA: MLXArray
    @ParameterInfo(key: "lora_b") var loraB: MLXArray

    required public init(
        _ inputDimensions: Int, _ outputDimensions: Int, rank: Int = 8, bias: Bool = false,
        scale: Float = 20.0, linear: Linear
    ) {
        // Scale for low-rank update
        self.scale = scale

        // Low rank lora weights
        let loraScale = 1 / sqrt(Float(inputDimensions))
        self._loraA.wrappedValue = MLXRandom.uniform(
            low: -loraScale, high: loraScale, [inputDimensions, rank])
        self._loraB.wrappedValue = MLXArray.zeros([rank, outputDimensions])

        super.init(weight: linear.weight, bias: linear.bias)

        freeze()
    }

    /// Freeze all parameters except the lora parameters
    public override func freeze(recursive: Bool = true, keys: [String]? = nil, strict: Bool = false)
        throws
    {
        // realize the keys and omit the lora parameters
        let keys =
            (keys ?? self.filterMap(filter: Self.filterLocalParameters).flattened().map { $0.0 })
            .filter {
                $0 != "lora_a" && $0 != "lora_b"
            }
        try super.freeze(recursive: recursive, keys: keys, strict: strict)
    }

    /// Convert a `Linear` or `QuantizedLinear` layer into a new `Linear` layer
    /// that implements the `LoRA` adapter.
    ///
    /// This is typically called via ``LoRATrain/convert(model:layers:)``.
    ///
    /// ### See Also
    /// - ``LoRATrain/convert(model:layers:)``
    /// - ``QLoRALinear/from(linear:rank:)``
    public static func from(linear: Linear, rank: Int = 8) -> Linear {
        if let linear = linear as? QuantizedLinear {
            return QLoRALinear.from(linear: linear, rank: rank)
        }
        let (outputDimensions, inputDimensions) = linear.shape
        return LoRALinear(inputDimensions, outputDimensions, rank: rank, linear: linear)
    }

    /// Convert back into a fused `Linear` layer.
    ///
    /// This is typically called via ``LoRATrain/fuse(model:layers:deQuantize:)``.
    ///
    /// ### See Also
    /// - ``LoRATrain/fuse(model:layers:deQuantize:)``
    /// - ``LoRAConvertToLinear``
    /// - ``QLoRALinear/toLinear(deQuantize:)``
    public func toLinear(deQuantize: Bool = false) -> Linear {
        let dtype = weight.dtype
        let loraB = (scale * loraB.T).asType(dtype)
        let loraA = loraA.T.asType(dtype)
        return Linear(weight: weight + matmul(loraB, loraA), bias: bias)
    }

    public override func callAsFunction(_ x: MLXArray) -> MLXArray {
        let y = super.callAsFunction(x.asType(weight.dtype))
        let z = matmul(matmul(x, self.loraA), self.loraB)
        return y + scale * z
    }
}

/// Implementation of LoRA `QuantizedLinear` replacement layer.
///
/// See ``LoRALinear`` (equivalent class for `Linear` layers) for more information.
public class QLoRALinear: QuantizedLinear, LoRAConvertToLinear {

    let scale: Float

    @ParameterInfo(key: "lora_a") var loraA: MLXArray
    @ParameterInfo(key: "lora_b") var loraB: MLXArray

    required public init(
        _ inputDimensions: Int, _ outputDimensions: Int, rank: Int = 8, bias: Bool = false,
        scale: Float = 20.0, linear: QuantizedLinear
    ) {

        // Scale for low-rank update
        self.scale = scale

        // Low rank lora weights
        let loraScale = 1 / sqrt(Float(inputDimensions))
        self._loraA.wrappedValue = MLXRandom.uniform(
            low: -loraScale, high: loraScale, [inputDimensions, rank])
        self._loraB.wrappedValue = MLXArray.zeros([rank, outputDimensions])

        super.init(
            weight: linear.weight, bias: linear.bias, scales: linear.scales, biases: linear.biases,
            groupSize: linear.groupSize, bits: linear.bits)

        // start frozen except for the lora keys
        freeze()
    }

    /// Freeze all parameters except the lora parameters
    public override func freeze(recursive: Bool = true, keys: [String]? = nil, strict: Bool = false)
        throws
    {
        // realize the keys and omit the lora parameters
        let keys =
            (keys ?? self.filterMap(filter: Self.filterLocalParameters).flattened().map { $0.0 })
            .filter {
                $0 != "lora_a" && $0 != "lora_b"
            }
        try super.freeze(recursive: recursive, keys: keys, strict: strict)
    }

    /// Convert a `QuantizedLinear` layer into a new `Linear` layer
    /// that implements the `LoRA` adapter.
    ///
    /// This is typically called via ``LoRATrain/convert(model:layers:)``.
    ///
    /// ### See Also
    /// - ``LoRATrain/convert(model:layers:)``
    /// - ``LoRALinear/from(linear:rank:)``
    public static func from(linear: QuantizedLinear, rank: Int = 8) -> Linear {
        var (outputDimensions, inputDimensions) = linear.shape
        inputDimensions = inputDimensions * 32 / linear.bits
        return QLoRALinear(inputDimensions, outputDimensions, rank: rank, linear: linear)
    }

    /// Convert back into a fused `QuantizedLinear` layer.
    ///
    /// This is typically called via ``LoRATrain/fuse(model:layers:deQuantize:)``.
    ///
    /// ### See Also
    /// - ``LoRATrain/fuse(model:layers:deQuantize:)``
    public func toLinear(deQuantize: Bool = false) -> Linear {
        // convert back into full weights
        let weight = dequantized(
            weight, scales: scales, biases: biases, groupSize: groupSize, bits: bits)

        let loraB = (scale * loraB.T).asType(.float16)
        let loraA = loraA.T.asType(.float16)

        // convert back into quantized
        return QuantizedLinear(
            weight: weight + matmul(loraB, loraA), bias: bias, groupSize: groupSize, bits: bits)
    }

    public override func callAsFunction(_ x: MLXArray) -> MLXArray {
        let y = super.callAsFunction(x.asType(scales.dtype))
        let z = matmul(matmul(x, self.loraA), self.loraB)
        return y + scale * z
    }
}

/// Equivalent to `lora.py/iterate_batches()`. Used internally by ``LoRATrain``.
struct LoRABatchIterator: Sequence, IteratorProtocol {

    let dataset: [String]
    let batchSize: Int
    let tokenizer: Tokenizer

    let train: Bool

    var indices: [Int]
    var index = 0

    public init(dataset: [String], tokenizer: Tokenizer, batchSize: Int, train: Bool) {
        self.dataset = dataset
        self.batchSize = batchSize
        self.tokenizer = tokenizer
        self.train = train

        self.indices = Array(0 ..< dataset.count)
        if train {
            indices.shuffle()
        }
    }

    mutating public func next() -> (MLXArray, MLXArray, MLXArray)? {
        if index >= indices.count {
            if !train {
                return nil
            }

            indices.shuffle()
            index = 0
        }

        let endIndex = Swift.min(index + batchSize, indices.count)

        let batch = (index ..< endIndex)
            .map { tokenizer.encode(text: dataset[indices[$0]]) }
        let lengths = batch.map { $0.count }
        let maxLength = lengths.max() ?? 0

        if maxLength > 2048 {
            print(
                """
                [WARNING] Some sequences are longer than 2048 tokens.
                Consider pre-splitting your data to save memory.
                """)
        }

        // pad to the max length
        let batchArray = MLXArray.zeros([lengths.count, maxLength], type: Int32.self)
        for (j, (b, l)) in zip(batch, lengths).enumerated() {
            batchArray[j, 0 ..< l] = MLXArray(b)
        }

        index = endIndex

        return (batchArray[0..., .stride(to: -1)], batchArray[0..., 1...], MLXArray(lengths))
    }

}

/// Collection of functions for adding LoRA adapters to an LLM model, training, fusing and saving/loading weights.
///
/// The typical flow for training is:
///
/// ```swift
/// // load the base model and tokenizer
/// let (model, tokenizer) = try await LLM.load(configuration: ModelConfiguration.mistral7B4bit)
///
/// // add LoRALinear adapter layers
/// LoRATrain.convert(model: model, layers: Array(model.loraLinearLayers().suffix(4)))
///
/// // optionally load LoRA weights
/// try LoRATrain.loadLoRAWeights(model: model, url: ...)
///
/// // load the train/validation data
/// let train = try loadLoRAData(directory: data, name: "train")
/// let valid = try loadLoRAData(directory: data, name: "valid")
///
/// // train
/// let optimizer = Adam(learningRate: 1e-5)
/// try await LoRATrain.train(
///     model: model, train: train, validate: valid, optimizer: optimizer, tokenizer: tokenizer,
///     parameters: LoRATrain.Parameters()
/// ) { progress in
///     print(progress)
///     return .more
/// }
/// ```
///
/// At this point the model will be trained and you could do one of the following:
///
/// - ``saveLoRAWeights(model:url:)`` -- write the LoRA weights to a file
/// - ``fuse(model:layers:deQuantize:)`` -- fuse the LoRA weights and convert back into the original model
///     architecture. These weights can be saved and reloaded with normal model handling code.
/// - ``evaluate(model:dataset:loss:tokenizer:batchSize:batchCount:)`` -- compute the test loss
///     against a test dataset
/// - use the in memory model as a normal `LLMModel` and evaluate a prompt
///
public enum LoRATrain {

    public typealias LoraLossFunction = (Module, MLXArray, MLXArray, MLXArray) -> (
        MLXArray, MLXArray
    )

    /// LoRA training parameters
    public struct Parameters {
        /// number of prompts to evaluate per iteration
        public var batchSize = 4

        /// number of iterations to train for
        public var iterations = 1000

        /// number of training steps between loss reporting
        public var stepsPerReport = 10

        /// number of steps between validations
        public var stepsPerEval = 100

        /// number of validation batches, `0` uses the entire validation set
        public var validationBatches = 10

        /// save the model every N iterations
        public var saveEvery = 100

        /// save path for the adapter `.safetensors`
        public var adapterURL: URL?

        public init(
            batchSize: Int = 4, iterations: Int = 1000, stepsPerReport: Int = 10,
            stepsPerEval: Int = 100, validationBatches: Int = 10, saveEvery: Int = 100,
            adapterURL: URL? = nil
        ) {
            self.batchSize = batchSize
            self.iterations = iterations
            self.stepsPerReport = stepsPerReport
            self.stepsPerEval = stepsPerEval
            self.validationBatches = validationBatches
            self.saveEvery = saveEvery
            self.adapterURL = adapterURL
        }
    }

    /// Freeze the model layers and replace the indicated modules (Linear) that should be
    /// converted to ``LoRALinear`` and remain trainable.
    ///
    /// Once a model has had the LoRA adapters applied, adapter weights can be loaded
    /// (if available):
    ///
    /// ```swift
    /// try LoRATrain.loadLoRAWeights(model: model, url: args.adapter)
    /// ```
    ///
    /// At this point the model is ready for one or more of the following:
    ///
    /// - training with ``train(model:train:validate:optimizer:loss:tokenizer:parameters:progress:)``
    /// - loss evaluation with ``evaluate(model:dataset:loss:tokenizer:batchSize:batchCount:)``
    /// - fusing with ``fuse(model:layers:deQuantize:)``
    /// - text generation with ``generate(promptTokens:parameters:model:tokenizer:didGenerate:)``
    /// - note that this is just using normal model text generation
    ///
    /// - Parameters:
    ///   - model: model to convert
    ///   - layers: number of suffix layers to convert
    public static func convert(model: Module, layers: LoRALinearLayers) {
        model.freeze()

        for (layer, keys) in layers {
            var update = ModuleChildren()
            let children = layer.children()
            for key in keys {
                if let item = children[key], case .value(let child) = item {
                    if let linear = child as? Linear {
                        update[key] = .value(LoRALinear.from(linear: linear))
                    } else {
                        print("\(key) on \(layer) is not Linear")
                    }
                } else {
                    print("failed to find key \(key) on \(layer)")
                }
            }
            layer.update(modules: update)
        }
    }

    /// Fuses the LoRA adapters back into the model weights.
    ///
    /// This produces a model in the original format with `Linear` or `QuantizedLinear` layer
    /// weights that incorporate the LoRA adapter.
    ///
    /// - Parameters:
    ///   - model: model to convert
    ///   - deQuantize: if `true` will convert `QuantizedLinear` back into `Linear`
    public static func fuse(model: Module, layers: LoRALinearLayers, deQuantize: Bool = false) {
        for (layer, keys) in layers {
            var update = ModuleChildren()
            let children = layer.children()
            for key in keys {
                if let item = children[key], case .value(let child) = item {
                    if let lora = child as? LoRAConvertToLinear {
                        update[key] = .value(lora.toLinear(deQuantize: deQuantize))
                    }
                }
            }
            if !update.isEmpty {
                layer.update(modules: update)
            }
        }
    }

    public static func loss(model: Module, inputs: MLXArray, targets: MLXArray, lengths: MLXArray)
        -> (
            MLXArray, MLXArray
        )
    {
        // def loss(model, inputs, targets, lengths):

        // run model on inputs
        let model = model as! LLMModel
        let logits = model(inputs, cache: nil).0.asType(.float32)

        // mask padding tokens
        let lengthMask = MLXArray(0 ..< inputs.dim(1))[.newAxis, 0...] .< lengths[0..., .newAxis]

        // calculate the loss
        let ntoks = lengthMask.sum()
        let ce = (crossEntropy(logits: logits, targets: targets) * lengthMask).sum() / ntoks
        return (ce, ntoks)
    }

    /// Evaluate the model and dataset and return the loss over the entire dataset.
    ///
    /// - Parameters:
    ///   - model: the model to evaluate
    ///   - dataset: the dataset
    ///   - loss: loss function
    ///   - tokenizer: tokenizer
    ///   - batchSize: number of items from the dataset to evaluate at once
    ///   - batchCount: number of batch elements to evaluate, 0 for all
    /// - Returns: the loss over the evaluated data
    ///
    /// ### See Also
    /// - ``loadLoRAData(directory:name:)``
    public static func evaluate(
        model: Module, dataset: [String], loss: LoraLossFunction = loss, tokenizer: Tokenizer,
        batchSize: Int, batchCount: Int
    ) -> Float {
        var allLosses = [Float]()
        var tokenCount = 0

        for (iteration, (inputs, targets, lengths)) in LoRABatchIterator(
            dataset: dataset, tokenizer: tokenizer, batchSize: batchSize, train: false
        ).enumerated() {
            let (losses, tokens) = loss(model, inputs, targets, lengths)
            allLosses.append((losses * tokens).item(Float.self))
            tokenCount += tokens.item(Int.self)

            if batchCount != 0 && iteration + 1 >= batchCount {
                break
            }
        }

        return (sum(MLXArray(allLosses), stream: .cpu) / tokenCount).item(Float.self)
    }

    /// Given a model with LoRA adaptors applied, load adapter weights from a `.safetensors` file.
    ///
    /// ### See Also
    /// - ``convert(model:layers:)``
    /// - ``saveLoRAWeights(model:url:)``
    public static func loadLoRAWeights(model: Module, url: URL) throws {
        let weights = try ModuleParameters.unflattened(loadArrays(url: url))
        try model.update(parameters: weights, verify: .noUnusedKeys)
        eval(model)
    }

    /// Given a model with LoRA adaptors applied, write adapter weights to a `.safetensors` file.
    ///
    /// ### See Also
    /// - ``convert(model:layers:)``
    /// - ``loadLoRAWeights(model:url:)``
    public static func saveLoRAWeights(model: Module, url: URL) throws {
        let parameters = Dictionary(
            uniqueKeysWithValues: model.trainableParameters().flattened())
        try save(arrays: parameters, url: url)
    }

    public enum Progress: CustomStringConvertible {
        case train(
            iteration: Int, trainingLoss: Float, iterationsPerSecond: Double,
            tokensPerSecond: Double)
        case validation(iteration: Int, validationLoss: Float, validationTime: Double)
        case save(iteration: Int, url: URL)

        public var description: String {
            switch self {
            case .train(
                let iteration, let trainingLoss, let iterationsPerSecond, let tokensPerSecond):
                "Iteration \(iteration + 1): training loss \(trainingLoss.formatted()), "
                    + "iterations/sec \(iterationsPerSecond.formatted()), "
                    + "Tokens/sec \(tokensPerSecond.formatted())"
            case .validation(let iteration, let validationLoss, let validationTime):
                "Iteration \(iteration + 1): "
                    + "validation loss \(validationLoss.formatted()), "
                    + "validation time \(validationTime.formatted())s"
            case .save(let iteration, let url):
                "Iteration \(iteration + 1): saved weights to \(url.path())"
            }
        }
    }

    public enum ProgressDisposition {
        case stop
        case more
    }

    /// Train (or continue training) LoRA weights.
    ///
    /// - Parameters:
    ///   - model: model to train
    ///   - train: training dataset
    ///   - validate: validate dataset
    ///   - optimizer: optimizer used in training
    ///   - loss: loss function
    ///   - tokenizer: tokenizer
    ///   - parameters: training parameters
    ///   - progress: progress callback
    public static func train(
        model: Module, train: [String], validate: [String], optimizer: Optimizer,
        loss: @escaping LoraLossFunction = loss, tokenizer: Tokenizer, parameters: Parameters,
        progress: (Progress) async -> ProgressDisposition
    ) async throws {
        // def train(model, train_set, val_set, optimizer, loss, tokenizer, args)

        let lossValueGrad = valueAndGrad(model: model) { model, arrays in
            let (ce, ntoks) = loss(model, arrays[0], arrays[1], arrays[2])
            return [ce, ntoks]
        }

        var losses = [Float]()
        var tokenCount = 0

        var start = Date.timeIntervalSinceReferenceDate

        for (iteration, (inputs, targets, lengths)) in LoRABatchIterator(
            dataset: train, tokenizer: tokenizer, batchSize: parameters.batchSize, train: true
        ).enumerated() {
            // forward and backward pass
            let (resultArray, grad) = lossValueGrad(model, [inputs, targets, lengths])
            let lvalue = resultArray[0]
            let tokens = resultArray[1]

            // model update
            optimizer.update(model: model, gradients: grad)
            eval(model, optimizer, lvalue)

            // record loss
            losses.append(lvalue.item(Float.self))
            tokenCount += tokens.item(Int.self)

            // report training loss
            if (iteration + 1) % parameters.stepsPerReport == 0 {
                let trainingLoss = MLXArray(losses).mean(stream: .cpu).item(Float.self)
                let now = Date.timeIntervalSinceReferenceDate

                let iterationsPerSecond = Double(parameters.stepsPerReport) / (now - start)
                let tokensPerSecond = Double(tokenCount) / (now - start)

                if await progress(
                    .train(
                        iteration: iteration, trainingLoss: trainingLoss,
                        iterationsPerSecond: iterationsPerSecond, tokensPerSecond: tokensPerSecond))
                    == .stop
                {
                    break
                }

                losses.removeAll()
                tokenCount = 0
                start = Date.timeIntervalSinceReferenceDate
            }

            // report validation loss
            if iteration == 0 || (iteration + 1) % parameters.stepsPerEval == 0 {
                let validationStart = Date.timeIntervalSinceReferenceDate
                let validationLoss = evaluate(
                    model: model, dataset: validate, loss: loss, tokenizer: tokenizer,
                    batchSize: parameters.batchSize, batchCount: parameters.validationBatches)
                let now = Date.timeIntervalSinceReferenceDate

                if await progress(
                    .validation(
                        iteration: iteration, validationLoss: validationLoss,
                        validationTime: now - validationStart)) == .stop
                {
                    break
                }

                start = Date.timeIntervalSinceReferenceDate
            }

            // save adapter weights if needed
            if let adapterURL = parameters.adapterURL, (iteration + 1) % parameters.saveEvery == 0 {
                try saveLoRAWeights(model: model, url: adapterURL)

                if await progress(.save(iteration: iteration, url: adapterURL)) == .stop {
                    break
                }

                start = Date.timeIntervalSinceReferenceDate
            }

            if iteration + 1 >= parameters.iterations {
                break
            }
        }
    }
}

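The doc comment on `LoRATrain` already walks through the training flow; the post-training options it lists can be sketched like this (not from the diff: `model`, `tokenizer`, and a `test` dataset are assumed to be in scope, and the adapter path is a placeholder):

```swift
// Sketch: measure the test loss, save the adapter weights, then fuse the
// adapters back into plain Linear / QuantizedLinear layers for normal use.
let testLoss = LoRATrain.evaluate(
    model: model, dataset: test, tokenizer: tokenizer,
    batchSize: 4, batchCount: 0)  // batchCount 0 = the whole dataset
print("test loss: \(testLoss)")

try LoRATrain.saveLoRAWeights(
    model: model, url: URL(filePath: "/tmp/adapters.safetensors"))

LoRATrain.fuse(model: model, layers: model.loraLinearLayers())
```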
@@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc.

import Foundation
import Hub

/// Registry of models and any overrides that go with them, e.g. prompt augmentation.
/// If asked for an unknown configuration this will use the model/tokenizer as-is.
@@ -9,7 +10,22 @@ import Foundation
/// swift-tokenizers code handles a good chunk of that and this is a place to augment that
/// implementation, if needed.
public struct ModelConfiguration {
    public let id: String

    public enum Identifier {
        case id(String)
        case directory(URL)
    }

    public var id: Identifier

    public var name: String {
        switch id {
        case .id(let string):
            string
        case .directory(let url):
            url.deletingLastPathComponent().lastPathComponent + "/" + url.lastPathComponent
        }
    }

    /// pull the tokenizer from an alternate id
    public let tokenizerId: String?
@@ -26,7 +42,17 @@ public struct ModelConfiguration {
        id: String, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
        preparePrompt: ((String) -> String)? = nil
    ) {
        self.id = id
        self.id = .id(id)
        self.tokenizerId = tokenizerId
        self.overrideTokenizer = overrideTokenizer
        self.preparePrompt = preparePrompt
    }

    public init(
        directory: URL, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
        preparePrompt: ((String) -> String)? = nil
    ) {
        self.id = .directory(directory)
        self.tokenizerId = tokenizerId
        self.overrideTokenizer = overrideTokenizer
        self.preparePrompt = preparePrompt
@@ -36,13 +62,25 @@ public struct ModelConfiguration {
        preparePrompt?(prompt) ?? prompt
    }

    public func modelDirectory(hub: HubApi = HubApi()) -> URL {
        switch id {
        case .id(let id):
            // download the model weights and config
            let repo = Hub.Repo(id: id)
            return hub.localRepoLocation(repo)

        case .directory(let directory):
            return directory
        }
    }

    public static var registry = [String: ModelConfiguration]()

    public static func register(configurations: [ModelConfiguration]) {
        bootstrap()

        for c in configurations {
            registry[c.id] = c
            registry[c.name] = c
        }
    }

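With the `Identifier` change above, a configuration can point at either a hub repo or a local folder, and `modelDirectory(hub:)` resolves both to a location on disk; this is what the `authorizationRequired` fallback in `load()` relies on. A small sketch with placeholder values:

```swift
import Foundation

// Sketch: the two identifier forms and how they resolve.
var configuration = ModelConfiguration(id: "some-org/some-model")
print(configuration.name)              // "some-org/some-model"
print(configuration.modelDirectory())  // HubApi's local cache location for that repo

configuration.id = .directory(URL(filePath: "/models/my-finetune"))
print(configuration.name)              // "models/my-finetune"
print(configuration.modelDirectory())  // the directory itself
```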
@@ -240,3 +240,11 @@ public struct PhiConfiguration: Codable {

    }
}

// MARK: - LoRA

extension PhiModel: LoRAModel {
    public func loraLinearLayers() -> LoRALinearLayers {
        model.layers.map { ($0.selfAttention, ["q_proj", "v_proj"]) }
    }
}

@@ -251,3 +251,11 @@ public struct Qwen2Configuration: Codable {
            [String: StringOrNumber].self, forKey: Qwen2Configuration.CodingKeys.ropeScaling)
    }
}

// MARK: - LoRA

extension Qwen2Model: LoRAModel {
    public func loraLinearLayers() -> LoRALinearLayers {
        model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
    }
}

@@ -1,4 +1,4 @@
# Llama
# LLM

This is a port of several models from:

@@ -6,7 +6,7 @@ This is a port of several models from:

using the Hugging Face swift transformers package to provide tokenization:

https://github.com/huggingface/swift-transformers
- https://github.com/huggingface/swift-transformers

The [Models.swift](Models.swift) provides minor overrides and customization --
if you require overrides for the tokenizer or prompt customizations they can be
@@ -30,3 +30,12 @@ Currently supported model types are:
See [Configuration.swift](Configuration.swift) for more info.

See [llm-tool](../../Tools/llm-tool)

# LoRA

[Lora.swift](Lora.swift) contains an implementation of LoRA based on this example:

- https://github.com/ml-explore/mlx-examples/tree/main/lora

See [llm-tool/LoraCommands.swift](../../Tools/llm-tool/LoraCommands.swift) for an example of a driver and
[llm-tool](../../Tools/llm-tool) for examples of how to run it.

@@ -254,3 +254,11 @@ public struct Starcoder2Configuration: Codable {
            ?? true
    }
}

// MARK: - LoRA

extension Starcoder2Model: LoRAModel {
    public func loraLinearLayers() -> LoRALinearLayers {
        model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
    }
}

@@ -4,10 +4,20 @@ import Foundation
import Hub
import Tokenizers

public func loadTokenizer(configuration: ModelConfiguration) async throws -> Tokenizer {
public func loadTokenizer(configuration: ModelConfiguration, hub: HubApi) async throws -> Tokenizer
{
// from AutoTokenizer.from() -- this lets us override parts of the configuration
let config = LanguageModelConfigurationFromHub(
modelName: configuration.tokenizerId ?? configuration.id)

let config: LanguageModelConfigurationFromHub

switch configuration.id {
case .id(let id):
config = LanguageModelConfigurationFromHub(
modelName: configuration.tokenizerId ?? id, hubApi: hub)
case .directory(let directory):
config = LanguageModelConfigurationFromHub(modelFolder: directory, hubApi: hub)
}

guard var tokenizerConfig = try await config.tokenizerConfig else {
throw LLMError(message: "missing config")
}