From 3f02fcc1cb427aa6621a5eda5e5480d31843f68d Mon Sep 17 00:00:00 2001 From: David Koski Date: Mon, 26 Feb 2024 14:58:51 -0800 Subject: [PATCH] expose eosToken --- Libraries/LLM/Tokenizer.swift | 50 ++++++++++++++++++++++++++++++++++- Tools/llm-tool/LLMTool.swift | 4 +-- 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/Libraries/LLM/Tokenizer.swift b/Libraries/LLM/Tokenizer.swift index fc062c7..b85a990 100644 --- a/Libraries/LLM/Tokenizer.swift +++ b/Libraries/LLM/Tokenizer.swift @@ -4,6 +4,51 @@ import Foundation import Hub import Tokenizers +/// Wrapper for `Tokenizers.Tokenizer` that provides access to config +/// like ``eosToken``. +public struct Tokenizer: Tokenizers.Tokenizer { + + let tokenizer: Tokenizers.Tokenizer + + public let eosToken: String? + public let eosTokenId: Int? + + internal init(tokenizer: Tokenizers.Tokenizer, tokenizerConfig: Config) { + self.tokenizer = tokenizer + self.eosToken = tokenizerConfig.eosToken?.stringValue + if let eosToken { + self.eosTokenId = tokenizer.convertTokenToId(eosToken) + } else { + self.eosTokenId = nil + } + } + + public func tokenize(text: String) -> [String] { + tokenizer.tokenize(text: text) + } + + public func encode(text: String) -> [Int] { + tokenizer.encode(text: text) + } + + public func decode(tokens: [Int]) -> String { + tokenizer.decode(tokens: tokens) + } + + public func convertTokenToId(_ token: String) -> Int? { + tokenizer.convertTokenToId(token) + } + + public func convertIdToToken(_ id: Int) -> String? { + tokenizer.convertIdToToken(id) + } + + public var unknownToken: String? { tokenizer.unknownToken } + + public var unknownTokenId: Int? { tokenizer.unknownTokenId } + +} + public func loadTokenizer(name: String) async throws -> Tokenizer { // from AutoTokenizer.from() -- this lets us override parts of the configuration let config = LanguageModelConfigurationFromHub(modelName: name) @@ -31,7 +76,10 @@ public func loadTokenizer(name: String) async throws -> Tokenizer { } } - return try PreTrainedTokenizer(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData) + let impl = try PreTrainedTokenizer( + tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData) + + return Tokenizer(tokenizer: impl, tokenizerConfig: tokenizerConfig) } public func discardUnhandledMerges(tokenizerData: Config) -> Config { diff --git a/Tools/llm-tool/LLMTool.swift b/Tools/llm-tool/LLMTool.swift index 74f23f5..542bbda 100644 --- a/Tools/llm-tool/LLMTool.swift +++ b/Tools/llm-tool/LLMTool.swift @@ -64,7 +64,7 @@ struct SyncGenerator: AsyncParsableCommand { } let t = token.item(Int.self) - if t == tokenizer.unknownTokenId { + if t == tokenizer.unknownTokenId || t == tokenizer.eosTokenId { break } tokens.append(t) @@ -150,7 +150,7 @@ struct AsyncGenerator: AsyncParsableCommand { start = now } - if token == tokenizer.unknownTokenId { + if token == tokenizer.unknownTokenId || token == tokenizer.eosTokenId { break } tokens.append(token)