llm improvements

- document the tokenizer used (https://github.com/huggingface/swift-transformers)
- provide a hook for tokenizer configuration and prompt augmentation
	- this isn't as rich as the Python equivalents, but it helps a little
David Koski
2024-03-01 14:46:32 -08:00
parent 599661774a
commit 82f6a969d4
8 changed files with 250 additions and 22 deletions
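
As a rough illustration of the new hook, a sketch only (not part of the diff below; the model id and prompt format here are hypothetical), a configuration with prompt augmentation can be registered like this:

// Sketch: register a model with a tokenizer override and a prompt-preparation hook,
// using the ModelConfiguration API added in Libraries/LLM/Models.swift below.
let myModel = ModelConfiguration(
    id: "mlx-community/My-Model-4bit",        // hypothetical model id
    overrideTokenizer: "PreTrainedTokenizer"  // optional; see Models.swift
) { prompt in
    // wrap the user input in the model's expected prompt format
    "<s>[INST] \(prompt) [/INST]"
}
ModelConfiguration.register(configurations: [myModel])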


@@ -14,13 +14,13 @@ struct LLMError: Error {
 /// Load and return the model and tokenizer
 public func load(
-    hub: HubApi = HubApi(), name: String, progressHandler: @escaping (Progress) -> Void = { _ in }
+    hub: HubApi = HubApi(), configuration: ModelConfiguration, progressHandler: @escaping (Progress) -> Void = { _ in }
 ) async throws -> (LLMModel, Tokenizer) {
     // note: this doesn't have a way to pass the HubApi
-    let tokenizer = try await loadTokenizer(name: name)
+    let tokenizer = try await loadTokenizer(configuration: configuration)

     // download the model weights and config
-    let repo = Hub.Repo(id: name)
+    let repo = Hub.Repo(id: configuration.id)
     let modelFiles = ["config.json", "*.safetensors"]
     let modelDirectory = try await hub.snapshot(
         from: repo, matching: modelFiles, progressHandler: progressHandler)
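
For reference, a minimal sketch of calling the updated entry point (assuming the signature shown above; error handling and progress reporting elided):

// Sketch: load a registered model by id and get back the model + tokenizer.
let configuration = ModelConfiguration.configuration(id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx")
let (model, tokenizer) = try await load(configuration: configuration)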

Libraries/LLM/Models.swift (new file, 107 lines)

@@ -0,0 +1,107 @@
// Copyright © 2024 Apple Inc.

import Foundation

/// Registry of models and any overrides that go with them, e.g. prompt augmentation.
/// If asked for an unknown configuration this will use the model/tokenizer as-is.
///
/// The python tokenizers have a very rich set of implementations and configuration. The
/// swift-tokenizers code handles a good chunk of that and this is a place to augment that
/// implementation, if needed.
public struct ModelConfiguration {

    public let id: String

    /// overrides for TokenizerModel/knownTokenizers -- useful before swift-transformers is updated
    public let overrideTokenizer: String?

    /// custom preparation logic for the prompt. custom tokenizers provide more capability, but this
    /// allows some minor formatting changes, e.g. wrapping the user input in the expected prompt
    /// format
    private let preparePrompt: ((String) -> String)?

    public init(
        id: String, overrideTokenizer: String? = nil, preparePrompt: ((String) -> String)? = nil
    ) {
        self.id = id
        self.overrideTokenizer = overrideTokenizer
        self.preparePrompt = preparePrompt
    }

    public func prepare(prompt: String) -> String {
        preparePrompt?(prompt) ?? prompt
    }

    public static var registry = [String: ModelConfiguration]()

    public static func register(configurations: [ModelConfiguration]) {
        bootstrap()

        for c in configurations {
            registry[c.id] = c
        }
    }

    public static func configuration(id: String) -> ModelConfiguration {
        bootstrap()

        if let c = registry[id] {
            return c
        } else {
            return ModelConfiguration(id: id)
        }
    }
}

extension ModelConfiguration {
    static let mistral7B4bit = ModelConfiguration(id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx")

    static let codeLlama13b4bit = ModelConfiguration(
        id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX",
        overrideTokenizer: "PreTrainedTokenizer"
    ) { prompt in
        // given the prompt: func sortArray(_ array: [Int]) -> String { <FILL_ME> }
        // the python code produces this (via its custom tokenizer):
        // <PRE> func sortArray(_ array: [Int]) -> String { <SUF> } <MID>
        "<PRE> " + prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") + " <MID>"
    }

    static let phi4bit = ModelConfiguration(id: "mlx-community/phi-2-hf-4bit-mlx") { prompt in
        "Instruct: \(prompt). Output: "
    }

    static let gemma2bQuantized = ModelConfiguration(
        id: "mlx-community/quantized-gemma-2b-it",
        overrideTokenizer: "PreTrainedTokenizer"
    ) { prompt in
        "<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
    }

    private enum BootstrapState {
        case idle
        case bootstrapping
        case bootstrapped
    }

    static private var bootstrapState = BootstrapState.idle

    /// register the built-in configurations exactly once; re-entrant calls are no-ops
    static func bootstrap() {
        switch bootstrapState {
        case .idle:
            bootstrapState = .bootstrapping
            register(configurations: [
                mistral7B4bit,
                codeLlama13b4bit,
                phi4bit,
                gemma2bQuantized,
            ])
            bootstrapState = .bootstrapped

        case .bootstrapping, .bootstrapped:
            break
        }
    }
}
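
For reference, a hedged sketch of how callers interact with this registry (the second model id is hypothetical, used only to show the fallback path):

// Known ids return their registered configuration, including the prompt hook:
let phi = ModelConfiguration.configuration(id: "mlx-community/phi-2-hf-4bit-mlx")
print(phi.prepare(prompt: "compare python and swift"))
// -> "Instruct: compare python and swift. Output: "

// Unknown ids fall back to a plain configuration that uses the model/tokenizer as-is:
let other = ModelConfiguration.configuration(id: "mlx-community/Some-Other-Model")  // hypothetical id
print(other.prepare(prompt: "hello"))  // -> "hello"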


@@ -4,9 +4,22 @@ This is a port of several models from:

 - https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/

-You can use this to load models from huggingface, e.g.:
+using the Hugging Face swift transformers package to provide tokenization:

-- https://huggingface.co/mlx-community/Mistral-7B-v0.1-hf-4bit-mlx
+https://github.com/huggingface/swift-transformers
+
+[Models.swift](Models.swift) provides minor overrides and customization --
+if you require overrides for the tokenizer or prompt customization they can be
+added there.
+
+This is set up to load models from Hugging Face, e.g. https://huggingface.co/mlx-community
+
+The following models have been tried:
+
+- mlx-community/Mistral-7B-v0.1-hf-4bit-mlx
+- mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX
+- mlx-community/phi-2-hf-4bit-mlx
+- mlx-community/quantized-gemma-2b-it

 Currently supported model types are:


@@ -49,9 +49,9 @@ public struct Tokenizer: Tokenizers.Tokenizer {
 }

-public func loadTokenizer(name: String) async throws -> Tokenizer {
+public func loadTokenizer(configuration: ModelConfiguration) async throws -> Tokenizer {
     // from AutoTokenizer.from() -- this lets us override parts of the configuration
-    let config = LanguageModelConfigurationFromHub(modelName: name)
+    let config = LanguageModelConfigurationFromHub(modelName: configuration.id)
     guard var tokenizerConfig = try await config.tokenizerConfig else {
         throw LLMError(message: "missing config")
     }
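
A hedged sketch of the call site with a tokenizer override (the override handling itself falls past the end of this hunk, so only the entry point is shown):

// Sketch: codeLlama13b4bit sets overrideTokenizer = "PreTrainedTokenizer" (see Models.swift),
// which this function can apply before constructing the tokenizer.
let tokenizer = try await loadTokenizer(configuration: .codeLlama13b4bit)
let tokens = tokenizer.encode(text: "<PRE> func add(a: Int, b: Int) -> Int { <SUF> } <MID>")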


@@ -25,7 +25,7 @@ struct SyncGenerator: AsyncParsableCommand {
     var model: String = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"

     @Option(name: .shortAndLong, help: "The message to be processed by the model")
-    var prompt = "compare swift and python"
+    var prompt = "compare python and swift"

     @Option(name: .shortAndLong, help: "Maximum number of tokens to generate")
     var maxTokens = 100
@@ -40,22 +40,24 @@
     func run() async throws {
         MLXRandom.seed(seed)

-        let (model, tokenizer) = try await load(name: model)
+        let modelConfiguration = ModelConfiguration.configuration(id: model)
+        let (model, tokenizer) = try await load(configuration: modelConfiguration)
+
+        let prompt = modelConfiguration.prepare(prompt: self.prompt)
+        let promptTokens = tokenizer.encode(text: prompt)

         print("Starting generation ...")
-        print(prompt, terminator: "")
+        print(self.prompt, terminator: "")

         var start = Date.timeIntervalSinceReferenceDate
         var promptTime: TimeInterval = 0

-        let prompt = MLXArray(tokenizer.encode(text: prompt))
-
         // collect the tokens and keep track of how much of the string
         // we have printed already
         var tokens = [Int]()
         var printed = 0

-        for token in TokenIterator(prompt: prompt, model: model, temp: temperature) {
+        for token in TokenIterator(prompt: MLXArray(promptTokens), model: model, temp: temperature) {
             if tokens.isEmpty {
                 eval(token)
                 let now = Date.timeIntervalSinceReferenceDate
@@ -90,7 +92,7 @@ struct SyncGenerator: AsyncParsableCommand {
         print(
             """
-            Prompt Tokens per second: \((Double(prompt.size) / promptTime).formatted())
+            Prompt Tokens per second: \((Double(promptTokens.count) / promptTime).formatted())
             Generation tokens per second: \((Double(tokens.count - 1) / generateTime).formatted())
             """)
@@ -111,7 +113,7 @@ struct AsyncGenerator: AsyncParsableCommand {
     var model: String = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"

     @Option(name: .shortAndLong, help: "The message to be processed by the model")
-    var prompt = "compare swift and python"
+    var prompt = "compare python and swift"

     @Option(name: .shortAndLong, help: "Maximum number of tokens to generate")
     var maxTokens = 100
@@ -126,22 +128,24 @@
     func run() async throws {
         MLXRandom.seed(seed)

-        let (model, tokenizer) = try await load(name: model)
+        let modelConfiguration = ModelConfiguration.configuration(id: model)
+        let (model, tokenizer) = try await load(configuration: modelConfiguration)
+
+        let prompt = modelConfiguration.prepare(prompt: self.prompt)
+        let promptTokens = tokenizer.encode(text: prompt)

         print("Starting generation ...")
-        print(prompt, terminator: "")
+        print(self.prompt, terminator: "")

         var start = Date.timeIntervalSinceReferenceDate
         var promptTime: TimeInterval = 0

-        let prompt = MLXArray(tokenizer.encode(text: prompt))
-
         // collect the tokens and keep track of how much of the string
         // we have printed already
         var tokens = [Int]()
         var printed = 0

-        let (task, channel) = generate(prompt: prompt, model: model, temp: temperature)
+        let (task, channel) = generate(prompt: MLXArray(promptTokens), model: model, temp: temperature)

         for await token in channel {
             if tokens.isEmpty {
@@ -179,7 +183,7 @@ struct AsyncGenerator: AsyncParsableCommand {
         print(
             """
-            Prompt Tokens per second: \((Double(prompt.size) / promptTime).formatted())
+            Prompt Tokens per second: \((Double(promptTokens.count) / promptTime).formatted())
             Generation tokens per second: \((Double(tokens.count - 1) / generateTime).formatted())
             """)


@@ -36,6 +36,7 @@
 		C3932D572B6A060B00A81055 /* MNIST.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3932D562B6A060B00A81055 /* MNIST.swift */; };
 		C3932D592B6A0BE400A81055 /* Random.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3932D582B6A0BE400A81055 /* Random.swift */; };
 		C397C59C2B62C6D0004B084D /* ArgumentParser in Frameworks */ = {isa = PBXBuildFile; productRef = C397C59B2B62C6D0004B084D /* ArgumentParser */; };
+		C3A8B3AC2B9283150002EFB8 /* Models.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3A8B3AB2B9283150002EFB8 /* Models.swift */; };
 		C3E786AB2B8D1AEC0004D037 /* Evaluate.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3E786AA2B8D1AEC0004D037 /* Evaluate.swift */; };
 		C3E786AD2B8D4AF50004D037 /* Tokenizer.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3E786AC2B8D4AF50004D037 /* Tokenizer.swift */; };
 		C3FBCB212B8520B80007E490 /* MLX in Frameworks */ = {isa = PBXBuildFile; productRef = C3FBCB202B8520B80007E490 /* MLX */; };
@@ -152,6 +153,7 @@
 		C3932D562B6A060B00A81055 /* MNIST.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MNIST.swift; sourceTree = "<group>"; };
 		C3932D582B6A0BE400A81055 /* Random.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Random.swift; sourceTree = "<group>"; };
 		C397C58B2B62C6A9004B084D /* llm-tool */ = {isa = PBXFileReference; explicitFileType = "compiled.mach-o.executable"; includeInIndex = 0; path = "llm-tool"; sourceTree = BUILT_PRODUCTS_DIR; };
+		C3A8B3AB2B9283150002EFB8 /* Models.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Models.swift; sourceTree = "<group>"; };
 		C3C3240B2B6CA689007D2D9A /* README.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = "<group>"; };
 		C3C3240C2B6CA792007D2D9A /* README.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = "<group>"; };
 		C3E786AA2B8D1AEC0004D037 /* Evaluate.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Evaluate.swift; sourceTree = "<group>"; };
@@ -267,6 +269,7 @@
 			isa = PBXGroup;
 			children = (
 				C34E48EF2B696E6500FCB841 /* Configuration.swift */,
+				C3A8B3AB2B9283150002EFB8 /* Models.swift */,
 				C34E48EE2B696E6500FCB841 /* Llama.swift */,
 				C38935E22B86C0FE0037B833 /* Gemma.swift */,
 				C38935C72B869C7A0037B833 /* LLM.h */,
@@ -614,6 +617,7 @@
 				C38935DF2B869DD00037B833 /* Phi.swift in Sources */,
 				C38935CE2B869C870037B833 /* Load.swift in Sources */,
 				C3E786AD2B8D4AF50004D037 /* Tokenizer.swift in Sources */,
+				C3A8B3AC2B9283150002EFB8 /* Models.swift in Sources */,
 				C3E786AB2B8D1AEC0004D037 /* Evaluate.swift in Sources */,
 				C38935CC2B869C870037B833 /* Llama.swift in Sources */,
 			);


@@ -15,7 +15,7 @@
       "location" : "https://github.com/ml-explore/mlx-swift",
       "state" : {
         "branch" : "main",
-        "revision" : "fbe215ae29ec286bdc66f4b3423e3eea1b3ef2fc"
+        "revision" : "83bd19f9fe93d77d9f89981eeead2d6d190afdba"
       }
     },
     {


@@ -0,0 +1,100 @@
<?xml version="1.0" encoding="UTF-8"?>
<Scheme
   LastUpgradeVersion = "1520"
   version = "1.7">
   <BuildAction
      parallelizeBuildables = "YES"
      buildImplicitDependencies = "YES">
      <BuildActionEntries>
         <BuildActionEntry
            buildForTesting = "YES"
            buildForRunning = "YES"
            buildForProfiling = "YES"
            buildForArchiving = "YES"
            buildForAnalyzing = "YES">
            <BuildableReference
               BuildableIdentifier = "primary"
               BlueprintIdentifier = "C397C58A2B62C6A9004B084D"
               BuildableName = "llm-tool"
               BlueprintName = "llm-tool"
               ReferencedContainer = "container:mlx-swift-examples.xcodeproj">
            </BuildableReference>
         </BuildActionEntry>
      </BuildActionEntries>
   </BuildAction>
   <TestAction
      buildConfiguration = "Debug"
      selectedDebuggerIdentifier = "Xcode.DebuggerFoundation.Debugger.LLDB"
      selectedLauncherIdentifier = "Xcode.DebuggerFoundation.Launcher.LLDB"
      shouldUseLaunchSchemeArgsEnv = "YES"
      shouldAutocreateTestPlan = "YES">
   </TestAction>
   <LaunchAction
      buildConfiguration = "Release"
      selectedDebuggerIdentifier = "Xcode.DebuggerFoundation.Debugger.LLDB"
      selectedLauncherIdentifier = "Xcode.DebuggerFoundation.Launcher.LLDB"
      launchStyle = "0"
      useCustomWorkingDirectory = "NO"
      ignoresPersistentStateOnLaunch = "NO"
      debugDocumentVersioning = "YES"
      debugServiceExtension = "internal"
      allowLocationSimulation = "YES"
      viewDebuggingEnabled = "No">
      <BuildableProductRunnable
         runnableDebuggingMode = "0">
         <BuildableReference
            BuildableIdentifier = "primary"
            BlueprintIdentifier = "C397C58A2B62C6A9004B084D"
            BuildableName = "llm-tool"
            BlueprintName = "llm-tool"
            ReferencedContainer = "container:mlx-swift-examples.xcodeproj">
         </BuildableReference>
      </BuildableProductRunnable>
      <CommandLineArguments>
         <CommandLineArgument
            argument = "--model mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX"
            isEnabled = "NO">
         </CommandLineArgument>
         <CommandLineArgument
            argument = "--prompt &apos;func sortArray(_ array: [Int]) -&gt; String { &lt;FILL_ME&gt; }&apos;"
            isEnabled = "NO">
         </CommandLineArgument>
         <CommandLineArgument
            argument = "--model mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"
            isEnabled = "NO">
         </CommandLineArgument>
         <CommandLineArgument
            argument = "--model mlx-community/quantized-gemma-2b-it"
            isEnabled = "NO">
         </CommandLineArgument>
         <CommandLineArgument
            argument = "--model mlx-community/phi-2-hf-4bit-mlx"
            isEnabled = "YES">
         </CommandLineArgument>
      </CommandLineArguments>
   </LaunchAction>
   <ProfileAction
      buildConfiguration = "Release"
      shouldUseLaunchSchemeArgsEnv = "YES"
      savedToolIdentifier = ""
      useCustomWorkingDirectory = "NO"
      debugDocumentVersioning = "YES">
      <BuildableProductRunnable
         runnableDebuggingMode = "0">
         <BuildableReference
            BuildableIdentifier = "primary"
            BlueprintIdentifier = "C397C58A2B62C6A9004B084D"
            BuildableName = "llm-tool"
            BlueprintName = "llm-tool"
            ReferencedContainer = "container:mlx-swift-examples.xcodeproj">
         </BuildableReference>
      </BuildableProductRunnable>
   </ProfileAction>
   <AnalyzeAction
      buildConfiguration = "Debug">
   </AnalyzeAction>
   <ArchiveAction
      buildConfiguration = "Release"
      revealArchiveInOrganizer = "YES">
   </ArchiveAction>
</Scheme>