llm improvements
- document the tokenizer used (https://github.com/huggingface/swift-transformers)
- provide a hook for tokenizer configuration and prompt augmentation -- this
  isn't as rich as the python equivalents, but it helps a little
@@ -14,13 +14,13 @@ struct LLMError: Error {
     /// Load and return the model and tokenizer
     public func load(
-        hub: HubApi = HubApi(), name: String, progressHandler: @escaping (Progress) -> Void = { _ in }
+        hub: HubApi = HubApi(), configuration: ModelConfiguration, progressHandler: @escaping (Progress) -> Void = { _ in }
     ) async throws -> (LLMModel, Tokenizer) {
         // note: this doesn't have a way to pass the HubApi
-        let tokenizer = try await loadTokenizer(name: name)
+        let tokenizer = try await loadTokenizer(configuration: configuration)
 
         // download the model weights and config
-        let repo = Hub.Repo(id: name)
+        let repo = Hub.Repo(id: configuration.id)
         let modelFiles = ["config.json", "*.safetensors"]
         let modelDirectory = try await hub.snapshot(
             from: repo, matching: modelFiles, progressHandler: progressHandler)
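With this change callers pass a ModelConfiguration rather than a bare model
name. A minimal usage sketch (error handling and UI plumbing omitted; the
model id is one of the configurations registered in Models.swift below):

    let (model, tokenizer) = try await load(
        configuration: ModelConfiguration.configuration(id: "mlx-community/phi-2-hf-4bit-mlx")
    ) { progress in
        print("download: \(Int(progress.fractionCompleted * 100))%")
    }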
Libraries/LLM/Models.swift (new file, 107 lines)
@@ -0,0 +1,107 @@
// Copyright © 2024 Apple Inc.

import Foundation

/// Registry of models and any overrides that go with them, e.g. prompt augmentation.
/// If asked for an unknown configuration this will use the model/tokenizer as-is.
///
/// The python tokenizers have a very rich set of implementations and configuration. The
/// swift-tokenizers code handles a good chunk of that and this is a place to augment that
/// implementation, if needed.
public struct ModelConfiguration {

    public let id: String

    /// overrides for TokenizerModel/knownTokenizers -- useful before swift-transformers is updated
    public let overrideTokenizer: String?

    /// custom preparation logic for the prompt. custom tokenizers provide more capability, but this
    /// allows some minor formatting changes, e.g. wrapping the user input in the expected prompt
    /// format
    private let preparePrompt: ((String) -> String)?

    public init(id: String, overrideTokenizer: String? = nil, preparePrompt: ((String) -> String)? = nil) {
        self.id = id
        self.overrideTokenizer = overrideTokenizer
        self.preparePrompt = preparePrompt
    }

    public func prepare(prompt: String) -> String {
        preparePrompt?(prompt) ?? prompt
    }

    public static var registry = [String: ModelConfiguration]()

    public static func register(configurations: [ModelConfiguration]) {
        bootstrap()

        for c in configurations {
            registry[c.id] = c
        }
    }

    public static func configuration(id: String) -> ModelConfiguration {
        bootstrap()

        if let c = registry[id] {
            return c
        } else {
            return ModelConfiguration(id: id)
        }
    }
}

extension ModelConfiguration {

    static let mistral7B4bit = ModelConfiguration(id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx")

    static let codeLlama13b4bit = ModelConfiguration(
        id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX",
        overrideTokenizer: "PreTrainedTokenizer")
    { prompt in
        // given the prompt: func sortArray(_ array: [Int]) -> String { <FILL_ME> }
        // the python code produces this (via its custom tokenizer):
        // <PRE> func sortArray(_ array: [Int]) -> String { <SUF> } <MID>

        "<PRE> " +
            prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") +
            " <MID>"
    }

    static let phi4bit = ModelConfiguration(id: "mlx-community/phi-2-hf-4bit-mlx") { prompt in
        "Instruct: \(prompt). Output: "
    }

    static let gemma2bQuantized = ModelConfiguration(
        id: "mlx-community/quantized-gemma-2b-it",
        overrideTokenizer: "PreTrainedTokenizer") { prompt in
        "<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
    }

    private enum BootstrapState {
        case idle
        case bootstrapping
        case bootstrapped
    }

    static private var bootstrapState = BootstrapState.idle

    static func bootstrap() {
        switch bootstrapState {
        case .idle:
            bootstrapState = .bootstrapping
            register(configurations: [
                mistral7B4bit,
                codeLlama13b4bit,
                phi4bit,
                gemma2bQuantized,
            ])
            bootstrapState = .bootstrapped

        case .bootstrapping:
            break

        case .bootstrapped:
            break
        }
    }
}
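The registry is open, so applications can add their own entries before
loading. A small sketch using a hypothetical model id and prompt format (not
part of this commit):

    let myModel = ModelConfiguration(
        id: "mlx-community/my-model-4bit",  // hypothetical repo id
        overrideTokenizer: "PreTrainedTokenizer"
    ) { prompt in
        "### Instruction: \(prompt)\n### Response: "
    }
    ModelConfiguration.register(configurations: [myModel])

    // unknown ids fall back to an un-augmented configuration
    let config = ModelConfiguration.configuration(id: "mlx-community/my-model-4bit")
    let prepared = config.prepare(prompt: "Write a haiku about tokenizers.")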
@@ -4,9 +4,22 @@ This is a port of several models from:
 
 - https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/
 
-You can use this to load models from huggingface, e.g.:
+using the Hugging Face swift transformers package to provide tokenization:
 
-- https://huggingface.co/mlx-community/Mistral-7B-v0.1-hf-4bit-mlx
+https://github.com/huggingface/swift-transformers
+
+[Models.swift](Models.swift) provides minor overrides and customization -- if
+you require overrides for the tokenizer or prompt customization, they can be
+added there.
+
+This is set up to load models from Hugging Face, e.g. https://huggingface.co/mlx-community
+
+The following models have been tried:
+
+- mlx-community/Mistral-7B-v0.1-hf-4bit-mlx
+- mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX
+- mlx-community/phi-2-hf-4bit-mlx
+- mlx-community/quantized-gemma-2b-it
 
 Currently supported model types are:
@@ -49,9 +49,9 @@ public struct Tokenizer: Tokenizers.Tokenizer {
 
 }
 
-public func loadTokenizer(name: String) async throws -> Tokenizer {
+public func loadTokenizer(configuration: ModelConfiguration) async throws -> Tokenizer {
     // from AutoTokenizer.from() -- this lets us override parts of the configuration
-    let config = LanguageModelConfigurationFromHub(modelName: name)
+    let config = LanguageModelConfigurationFromHub(modelName: configuration.id)
     guard var tokenizerConfig = try await config.tokenizerConfig else {
         throw LLMError(message: "missing config")
     }
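The remainder of the function is not shown in this hunk; presumably it applies
configuration.overrideTokenizer to the loaded tokenizerConfig before building
the tokenizer. From the caller's side the hook looks like this, mirroring the
gemma2bQuantized entry above:

    // substitute a generic tokenizer implementation for a repo whose
    // tokenizer_class swift-transformers does not recognize yet
    let tokenizer = try await loadTokenizer(
        configuration: ModelConfiguration(
            id: "mlx-community/quantized-gemma-2b-it",
            overrideTokenizer: "PreTrainedTokenizer"))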