llm improvements

- document the tokenizer used (https://github.com/huggingface/swift-transformers)
- provide a hook for tokenizer configuration, prompt augmentation
	- this isn't as rich as the python equivalents but it helps a little
This commit is contained in:
David Koski
2024-03-01 14:46:32 -08:00
parent 599661774a
commit 82f6a969d4
8 changed files with 250 additions and 22 deletions

View File

@@ -14,13 +14,13 @@ struct LLMError: Error {
/// Load and return the model and tokenizer
public func load(
hub: HubApi = HubApi(), name: String, progressHandler: @escaping (Progress) -> Void = { _ in }
hub: HubApi = HubApi(), configuration: ModelConfiguration, progressHandler: @escaping (Progress) -> Void = { _ in }
) async throws -> (LLMModel, Tokenizer) {
// note: this doesn't have a way to pass the HubApi
let tokenizer = try await loadTokenizer(name: name)
let tokenizer = try await loadTokenizer(configuration: configuration)
// download the model weights and config
let repo = Hub.Repo(id: name)
let repo = Hub.Repo(id: configuration.id)
let modelFiles = ["config.json", "*.safetensors"]
let modelDirectory = try await hub.snapshot(
from: repo, matching: modelFiles, progressHandler: progressHandler)

107
Libraries/LLM/Models.swift Normal file
View File

@@ -0,0 +1,107 @@
// Copyright © 2024 Apple Inc.
import Foundation
/// Registry of models and any overrides that go with them, e.g. prompt augmentation.
/// If asked for an unknown configuration this will use the model/tokenizer as-is.
///
/// The python tokenizers have a very rich set of implementations and configuration. The
/// swift-tokenizers code handles a good chunk of that and this is a place to augment that
/// implementation, if needed.
public struct ModelConfiguration {

    /// Hugging Face repo identifier, e.g. `mlx-community/phi-2-hf-4bit-mlx`
    public let id: String

    /// overrides for TokenizerModel/knownTokenizers -- useful before swift-transformers is updated
    public let overrideTokenizer: String?

    /// custom preparation logic for the prompt. custom tokenizers provide more capability, but this
    /// allows some minor formatting changes, e.g. wrapping the user input in the expected prompt
    /// format
    private let preparePrompt: ((String) -> String)?

    public init(
        id: String, overrideTokenizer: String? = nil,
        preparePrompt: ((String) -> String)? = nil
    ) {
        self.id = id
        self.overrideTokenizer = overrideTokenizer
        self.preparePrompt = preparePrompt
    }

    /// Applies this configuration's prompt augmentation, if any; otherwise
    /// returns the prompt unchanged.
    public func prepare(prompt: String) -> String {
        if let augment = preparePrompt {
            return augment(prompt)
        }
        return prompt
    }

    /// Known configurations keyed by model id -- populated by `register(configurations:)`.
    public static var registry = [String: ModelConfiguration]()

    /// Adds (or replaces) entries in the registry, bootstrapping the built-in
    /// configurations first.
    public static func register(configurations: [ModelConfiguration]) {
        bootstrap()
        configurations.forEach { registry[$0.id] = $0 }
    }

    /// Looks up the configuration for `id`; unknown ids fall back to a plain
    /// configuration with no overrides (model/tokenizer used as-is).
    public static func configuration(id: String) -> ModelConfiguration {
        bootstrap()
        return registry[id] ?? ModelConfiguration(id: id)
    }
}
extension ModelConfiguration {

    static let mistral7B4bit = ModelConfiguration(id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx")

    static let codeLlama13b4bit = ModelConfiguration(
        id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX",
        overrideTokenizer: "PreTrainedTokenizer"
    ) { prompt in
        // given the prompt: func sortArray(_ array: [Int]) -> String { <FILL_ME> }
        // the python code produces this (via its custom tokenizer):
        // <PRE> func sortArray(_ array: [Int]) -> String { <SUF> } <MID>
        "<PRE> " + prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") + " <MID>"
    }

    static let phi4bit = ModelConfiguration(id: "mlx-community/phi-2-hf-4bit-mlx") { prompt in
        "Instruct: \(prompt). Output: "
    }

    static let gemma2bQuantized = ModelConfiguration(
        id: "mlx-community/quantized-gemma-2b-it",
        overrideTokenizer: "PreTrainedTokenizer"
    ) { prompt in
        "<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
    }

    /// Tracks whether the built-in configurations above have been registered,
    /// so registration happens at most once.
    private enum BootstrapState {
        case idle
        case bootstrapping
        case bootstrapped
    }

    static private var bootstrapState = BootstrapState.idle

    /// Registers the built-in configurations on first call; subsequent (or
    /// re-entrant) calls are no-ops.
    static func bootstrap() {
        guard case .idle = bootstrapState else { return }
        bootstrapState = .bootstrapping
        register(configurations: [
            mistral7B4bit,
            codeLlama13b4bit,
            phi4bit,
            gemma2bQuantized,
        ])
        bootstrapState = .bootstrapped
    }
}

View File

@@ -4,9 +4,22 @@ This is a port of several models from:
- https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/
You can use this to load models from huggingface, e.g.:
using the Hugging Face swift transformers package to provide tokenization:
- https://huggingface.co/mlx-community/Mistral-7B-v0.1-hf-4bit-mlx
https://github.com/huggingface/swift-transformers
[Models.swift](Models.swift) provides minor overrides and customization --
if you need a tokenizer override or a prompt customization for a model, it can be
added there.
This is set up to load models from Hugging Face, e.g. https://huggingface.co/mlx-community
The following models have been tried:
- mlx-community/Mistral-7B-v0.1-hf-4bit-mlx
- mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX
- mlx-community/phi-2-hf-4bit-mlx
- mlx-community/quantized-gemma-2b-it
Currently supported model types are:

View File

@@ -49,9 +49,9 @@ public struct Tokenizer: Tokenizers.Tokenizer {
}
public func loadTokenizer(name: String) async throws -> Tokenizer {
public func loadTokenizer(configuration: ModelConfiguration) async throws -> Tokenizer {
// from AutoTokenizer.from() -- this lets us override parts of the configuration
let config = LanguageModelConfigurationFromHub(modelName: name)
let config = LanguageModelConfigurationFromHub(modelName: configuration.id)
guard var tokenizerConfig = try await config.tokenizerConfig else {
throw LLMError(message: "missing config")
}