// Copyright © 2024 Apple Inc.

import Foundation

/// Registry of models and any overrides that go with them, e.g. prompt augmentation.
/// If asked for an unknown configuration this will use the model/tokenizer as-is.
///
/// The python tokenizers have a very rich set of implementations and configuration. The
/// swift-tokenizers code handles a good chunk of that and this is a place to augment that
/// implementation, if needed.
public struct ModelConfiguration {

    /// Identifier of the model. Also the key under which the configuration is stored
    /// in ``registry``.
    public let id: String

    /// Overrides for TokenizerModel/knownTokenizers -- useful before swift-transformers is updated.
    public let overrideTokenizer: String?

    /// Custom preparation logic for the prompt. Custom tokenizers provide more capability, but
    /// this allows some minor formatting changes, e.g. wrapping the user input in the expected
    /// prompt format.
    private let preparePrompt: ((String) -> String)?

    /// Creates a configuration.
    ///
    /// - Parameters:
    ///   - id: Identifier of the model.
    ///   - overrideTokenizer: Optional tokenizer name to use instead of the model's own.
    ///   - preparePrompt: Optional transform applied to a prompt by ``prepare(prompt:)``.
    public init(
        id: String, overrideTokenizer: String? = nil, preparePrompt: ((String) -> String)? = nil
    ) {
        self.id = id
        self.overrideTokenizer = overrideTokenizer
        self.preparePrompt = preparePrompt
    }

    /// Returns `prompt` run through the custom preparation closure, or `prompt` unchanged
    /// when this configuration has no such closure.
    public func prepare(prompt: String) -> String {
        preparePrompt?(prompt) ?? prompt
    }

    /// All registered configurations, keyed by ``id``.
    ///
    /// NOTE(review): this and the bootstrap state are unsynchronized mutable statics; nothing
    /// in this file guards them against concurrent access.
    public static var registry = [String: ModelConfiguration]()

    /// Adds `configurations` to the registry, replacing any existing entry with the same `id`.
    ///
    /// The built-in configurations are registered first, so a caller-supplied configuration
    /// that reuses a built-in `id` takes precedence over it.
    public static func register(configurations: [ModelConfiguration]) {
        bootstrap()

        for configuration in configurations {
            registry[configuration.id] = configuration
        }
    }

    /// Looks up the configuration registered for `id`.
    ///
    /// - Returns: The registered configuration, or a plain `ModelConfiguration(id: id)` with no
    ///   overrides when `id` is unknown.
    public static func configuration(id: String) -> ModelConfiguration {
        bootstrap()

        return registry[id] ?? ModelConfiguration(id: id)
    }
}

extension ModelConfiguration {

    /// Mistral 7B, 4-bit: no tokenizer override and no prompt preparation.
    public static let mistral7B4bit = ModelConfiguration(
        id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx")

    /// Code Llama 13B Instruct, 4-bit: rewrites the prompt into the infill format.
    public static let codeLlama13b4bit = ModelConfiguration(
        id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX",
        overrideTokenizer: "PreTrainedTokenizer"
    ) { prompt in
        // given the prompt: func sortArray(_ array: [Int]) -> String { <FILL_ME> }
        // the python code produces this (via its custom tokenizer):
        // <PRE> func sortArray(_ array: [Int]) -> String {  <SUF> } <MID>

        "<PRE> " + prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") + " <MID>"
    }

    /// Phi-2, 4-bit: wraps the prompt in the `Instruct: … Output: ` form.
    public static let phi4bit = ModelConfiguration(id: "mlx-community/phi-2-hf-4bit-mlx") {
        prompt in
        "Instruct: \(prompt). Output: "
    }

    /// Gemma 2B instruction-tuned, quantized: wraps the prompt in user/model turn markers.
    public static let gemma2bQuantized = ModelConfiguration(
        id: "mlx-community/quantized-gemma-2b-it",
        overrideTokenizer: "PreTrainedTokenizer"
    ) { prompt in
        "<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
    }

    /// Progress of the one-time registration of the built-in configurations.
    private enum BootstrapState {
        case idle
        case bootstrapping
        case bootstrapped
    }

    static private var bootstrapState = BootstrapState.idle

    /// Registers the built-in configurations the first time it is called; later calls do nothing.
    ///
    /// `register(configurations:)` itself calls `bootstrap()`, so the `.bootstrapping` state
    /// exists to turn that nested call into a no-op instead of recursing.
    static func bootstrap() {
        switch bootstrapState {
        case .idle:
            // Flip the state before registering so the nested bootstrap() call falls through.
            bootstrapState = .bootstrapping
            register(configurations: [
                mistral7B4bit,
                codeLlama13b4bit,
                phi4bit,
                gemma2bQuantized,
            ])
            bootstrapState = .bootstrapped

        case .bootstrapping, .bootstrapped:
            break
        }
    }
}