diff --git a/Libraries/LLM/Load.swift b/Libraries/LLM/Load.swift index e733f2d..add4552 100644 --- a/Libraries/LLM/Load.swift +++ b/Libraries/LLM/Load.swift @@ -59,73 +59,6 @@ public func load( return (model, tokenizer) } -// MARK: - Tokenizers - -public func loadTokenizer(name: String) async throws -> Tokenizer { - // from AutoTokenizer.from() -- this lets us override parts of the configuration - let config = LanguageModelConfigurationFromHub(modelName: name) - guard var tokenizerConfig = try await config.tokenizerConfig else { - throw LLMError(message: "missing config") - } - var tokenizerData = try await config.tokenizerData - - // workaround: replacement tokenizers for unhandled values in swift-transform - if let tokenizerClass = tokenizerConfig.tokenizerClass?.stringValue, - let replacement = replacementTokenizers[tokenizerClass] - { - var dictionary = tokenizerConfig.dictionary - dictionary["tokenizer_class"] = replacement - tokenizerConfig = Config(dictionary) - } - - // workaround: some merges can't be split on space in BPETokenizer - if let tokenizerClass = tokenizerConfig.tokenizerClass?.stringValue { - switch tokenizerClass { - case "T5Tokenizer": - break - default: - tokenizerData = discardUnhandledMerges(tokenizerData: tokenizerData) - } - } - - return try PreTrainedTokenizer(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData) -} - -public func discardUnhandledMerges(tokenizerData: Config) -> Config { - // see https://github.com/ml-explore/mlx-swift-examples/issues/1 - // and https://github.com/huggingface/swift-transformers/issues/51 - - if let model = tokenizerData.model { - if let merges = model.dictionary["merges"] as? [String] { - // discard any merges that can't be split on a space - // (required by BPETokenizer) - let newMerges = - merges - .filter { - $0.split(separator: " ").count == 2 - } - - if newMerges.count != merges.count { - var newModel = model.dictionary - newModel["merges"] = newMerges - - var newTokenizerData = tokenizerData.dictionary - newTokenizerData["model"] = newModel - - return Config(newTokenizerData) - } - } - } - - return tokenizerData -} - -/// overrides for TokenizerModel/knownTokenizers -let replacementTokenizers = [ - "CodeLlamaTokenizer": "LlamaTokenizer", - "GemmaTokenizer": "PreTrainedTokenizer", -] - // MARK: - Quantization private func quantizeIfNeeded( diff --git a/Libraries/LLM/Tokenizer.swift b/Libraries/LLM/Tokenizer.swift new file mode 100644 index 0000000..fc062c7 --- /dev/null +++ b/Libraries/LLM/Tokenizer.swift @@ -0,0 +1,70 @@ +// Copyright © 2024 Apple Inc. + +import Foundation +import Hub +import Tokenizers + +public func loadTokenizer(name: String) async throws -> Tokenizer { + // from AutoTokenizer.from() -- this lets us override parts of the configuration + let config = LanguageModelConfigurationFromHub(modelName: name) + guard var tokenizerConfig = try await config.tokenizerConfig else { + throw LLMError(message: "missing config") + } + var tokenizerData = try await config.tokenizerData + + // workaround: replacement tokenizers for unhandled values in swift-transform + if let tokenizerClass = tokenizerConfig.tokenizerClass?.stringValue, + let replacement = replacementTokenizers[tokenizerClass] + { + var dictionary = tokenizerConfig.dictionary + dictionary["tokenizer_class"] = replacement + tokenizerConfig = Config(dictionary) + } + + // workaround: some merges can't be split on space in BPETokenizer + if let tokenizerClass = tokenizerConfig.tokenizerClass?.stringValue { + switch tokenizerClass { + case "T5Tokenizer": + break + default: + tokenizerData = discardUnhandledMerges(tokenizerData: tokenizerData) + } + } + + return try PreTrainedTokenizer(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData) +} + +public func discardUnhandledMerges(tokenizerData: Config) -> Config { + // see https://github.com/ml-explore/mlx-swift-examples/issues/1 + // and https://github.com/huggingface/swift-transformers/issues/51 + + if let model = tokenizerData.model { + if let merges = model.dictionary["merges"] as? [String] { + // discard any merges that can't be split on a space + // (required by BPETokenizer) + let newMerges = + merges + .filter { + $0.split(separator: " ").count == 2 + } + + if newMerges.count != merges.count { + var newModel = model.dictionary + newModel["merges"] = newMerges + + var newTokenizerData = tokenizerData.dictionary + newTokenizerData["model"] = newModel + + return Config(newTokenizerData) + } + } + } + + return tokenizerData +} + +/// overrides for TokenizerModel/knownTokenizers +let replacementTokenizers = [ + "CodeLlamaTokenizer": "LlamaTokenizer", + "GemmaTokenizer": "PreTrainedTokenizer", +] diff --git a/mlx-swift-examples.xcodeproj/project.pbxproj b/mlx-swift-examples.xcodeproj/project.pbxproj index f033265..4f68254 100644 --- a/mlx-swift-examples.xcodeproj/project.pbxproj +++ b/mlx-swift-examples.xcodeproj/project.pbxproj @@ -37,6 +37,7 @@ C3932D592B6A0BE400A81055 /* Random.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3932D582B6A0BE400A81055 /* Random.swift */; }; C397C59C2B62C6D0004B084D /* ArgumentParser in Frameworks */ = {isa = PBXBuildFile; productRef = C397C59B2B62C6D0004B084D /* ArgumentParser */; }; C3E786AB2B8D1AEC0004D037 /* Evaluate.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3E786AA2B8D1AEC0004D037 /* Evaluate.swift */; }; + C3E786AD2B8D4AF50004D037 /* Tokenizer.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3E786AC2B8D4AF50004D037 /* Tokenizer.swift */; }; C3FBCB212B8520B80007E490 /* MLX in Frameworks */ = {isa = PBXBuildFile; productRef = C3FBCB202B8520B80007E490 /* MLX */; }; C3FBCB292B8520DA0007E490 /* MLX in Frameworks */ = {isa = PBXBuildFile; productRef = C3FBCB282B8520DA0007E490 /* MLX */; }; C3FBCB2B2B8520DA0007E490 /* MLXNN in Frameworks */ = {isa = PBXBuildFile; productRef = C3FBCB2A2B8520DA0007E490 /* MLXNN */; }; @@ -154,6 +155,7 @@ C3C3240B2B6CA689007D2D9A /* README.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = ""; }; C3C3240C2B6CA792007D2D9A /* README.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = ""; }; C3E786AA2B8D1AEC0004D037 /* Evaluate.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Evaluate.swift; sourceTree = ""; }; + C3E786AC2B8D4AF50004D037 /* Tokenizer.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Tokenizer.swift; sourceTree = ""; }; /* End PBXFileReference section */ /* Begin PBXFrameworksBuildPhase section */ @@ -273,6 +275,7 @@ C34E48F62B69832600FCB841 /* README.md */, C34E48ED2B696E6500FCB841 /* Load.swift */, C3E786AA2B8D1AEC0004D037 /* Evaluate.swift */, + C3E786AC2B8D4AF50004D037 /* Tokenizer.swift */, ); path = LLM; sourceTree = ""; @@ -610,6 +613,7 @@ C38935CD2B869C870037B833 /* Configuration.swift in Sources */, C38935DF2B869DD00037B833 /* Phi.swift in Sources */, C38935CE2B869C870037B833 /* Load.swift in Sources */, + C3E786AD2B8D4AF50004D037 /* Tokenizer.swift in Sources */, C3E786AB2B8D1AEC0004D037 /* Evaluate.swift in Sources */, C38935CC2B869C870037B833 /* Llama.swift in Sources */, );