Files
mlx-swift-examples/Libraries/LLM/Load.swift
David Koski 9d74afd119 handle partially quantized models (#76)
* handle partially quantized models

- fix for #53 #71 #69 #74
- in order to test the models
	- I added a default prompt of an appropriate form
	- while working on the model configuration also added additional stop tokens (#74)
- fixed the repetitionPenalty code (#71)
2024-05-28 16:35:11 -07:00

85 lines
2.8 KiB
Swift

// Copyright © 2024 Apple Inc.
import AsyncAlgorithms
import Foundation
import Hub
import MLX
import MLXNN
import MLXRandom
import Tokenizers
struct LLMError: Error {
let message: String
}
/// Load and return the model and tokenizer
public func load(
hub: HubApi = HubApi(), configuration: ModelConfiguration,
progressHandler: @escaping (Progress) -> Void = { _ in }
) async throws -> (LLMModel, Tokenizer) {
do {
let tokenizer = try await loadTokenizer(configuration: configuration, hub: hub)
let modelDirectory: URL
switch configuration.id {
case .id(let id):
// download the model weights
let repo = Hub.Repo(id: id)
let modelFiles = ["*.safetensors"]
modelDirectory = try await hub.snapshot(
from: repo, matching: modelFiles, progressHandler: progressHandler)
case .directory(let directory):
modelDirectory = directory
}
// create the model (no weights loaded)
let configurationURL = modelDirectory.appending(component: "config.json")
let baseConfig = try JSONDecoder().decode(
BaseConfiguration.self, from: Data(contentsOf: configurationURL))
let model = try baseConfig.modelType.createModel(configuration: configurationURL)
// load the weights
var weights = [String: MLXArray]()
let enumerator = FileManager.default.enumerator(
at: modelDirectory, includingPropertiesForKeys: nil)!
for case let url as URL in enumerator {
if url.pathExtension == "safetensors" {
let w = try loadArrays(url: url)
for (key, value) in w {
weights[key] = value
}
}
}
// per-model cleanup
weights = model.sanitize(weights: weights)
// quantize if needed
if let quantization = baseConfig.quantization {
quantize(model: model, groupSize: quantization.groupSize, bits: quantization.bits) {
path, module in
weights["\(path).scales"] != nil
}
}
// apply the loaded weights
let parameters = ModuleParameters.unflattened(weights)
try model.update(parameters: parameters, verify: [.all])
eval(model)
return (model, tokenizer)
} catch Hub.HubClientError.authorizationRequired {
// an authorizationRequired means (typically) that the named repo doesn't exist on
// on the server so retry with local only configuration
var newConfiguration = configuration
newConfiguration.id = .directory(configuration.modelDirectory(hub: hub))
return try await load(
hub: hub, configuration: newConfiguration, progressHandler: progressHandler)
}
}