mlx-swift-examples/Libraries/LLM/Tokenizer.swift
David Koski 6c0b66f90a implement LoRA / QLoRA (#46)
* implement LoRA / QLoRA

- example of using MLX to fine-tune an LLM with low rank adaptation (LoRA) for a target task
- see also https://arxiv.org/abs/2106.09685
- based on https://github.com/ml-explore/mlx-examples/tree/main/lora

* add some command line flags I found useful during use
  - --quiet -- don't print decorator text, just the generated text
  - --prompt @/tmp/file.txt -- load prompt from file (see the sketch after the commit message)

* user can specify a path to a local model OR a model identifier on Hugging Face

* update mlx-swift reference

Co-authored-by: Ashraful Islam <ashraful.meche@gmail.com>
Co-authored-by: JustinMeans <46542161+JustinMeans@users.noreply.github.com>
2024-04-22 09:30:12 -07:00
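
As a rough illustration of the --prompt @/tmp/file.txt behavior described in the commit message, an argument prefixed with @ could be resolved by reading the named file. The helper below is hypothetical and not taken from this repository.

// Hypothetical helper (not from this repository): resolve a prompt argument of
// the form "@/tmp/file.txt" by reading the file; otherwise use the argument itself.
import Foundation

func resolvePrompt(_ argument: String) throws -> String {
    guard argument.hasPrefix("@") else { return argument }
    let path = String(argument.dropFirst())
    return try String(contentsOfFile: path, encoding: .utf8)
}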


// Copyright © 2024 Apple Inc.

import Foundation
import Hub
import Tokenizers

public func loadTokenizer(configuration: ModelConfiguration, hub: HubApi) async throws -> Tokenizer
{
    // from AutoTokenizer.from() -- this lets us override parts of the configuration
    let config: LanguageModelConfigurationFromHub

    switch configuration.id {
    case .id(let id):
        config = LanguageModelConfigurationFromHub(
            modelName: configuration.tokenizerId ?? id, hubApi: hub)
    case .directory(let directory):
        config = LanguageModelConfigurationFromHub(modelFolder: directory, hubApi: hub)
    }

    guard var tokenizerConfig = try await config.tokenizerConfig else {
        throw LLMError(message: "missing config")
    }
    let tokenizerData = try await config.tokenizerData
    // workaround: replacement tokenizers for tokenizer_class values unhandled by
    // swift-transformers (see replacementTokenizers below)
    if let tokenizerClass = tokenizerConfig.tokenizerClass?.stringValue,
        let replacement = replacementTokenizers[tokenizerClass]
    {
        var dictionary = tokenizerConfig.dictionary
        dictionary["tokenizer_class"] = replacement
        tokenizerConfig = Config(dictionary)
    }
    return try PreTrainedTokenizer(
        tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
}

/// overrides for TokenizerModel/knownTokenizers
let replacementTokenizers = [
    "Qwen2Tokenizer": "PreTrainedTokenizer",
    "CohereTokenizer": "PreTrainedTokenizer",
]
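
For reference, here is a minimal usage sketch of loadTokenizer covering both cases of ModelConfiguration.id: a Hugging Face model identifier and a local model directory. The ModelConfiguration initializers and the model identifier below are assumptions for illustration only; check the LLM library for the exact API.

// Usage sketch (illustrative, not part of Tokenizer.swift).
// Assumes ModelConfiguration offers init(id:) and init(directory:).
import Foundation
import Hub

func tokenizerExamples() async throws {
    let hub = HubApi()

    // by Hugging Face identifier -- tokenizer files are fetched via HubApi
    // (placeholder identifier; substitute a real repository)
    let remote = ModelConfiguration(id: "mlx-community/some-model-4bit")
    let remoteTokenizer = try await loadTokenizer(configuration: remote, hub: hub)

    // by local directory containing tokenizer.json / tokenizer_config.json
    let local = ModelConfiguration(directory: URL(fileURLWithPath: "/path/to/model"))
    let localTokenizer = try await loadTokenizer(configuration: local, hub: hub)

    print(remoteTokenizer.encode(text: "hello"))
    print(localTokenizer.encode(text: "hello"))
}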