diff --git a/Applications/LLMEval/ContentView.swift b/Applications/LLMEval/ContentView.swift
index f4193e5..0674131 100644
--- a/Applications/LLMEval/ContentView.swift
+++ b/Applications/LLMEval/ContentView.swift
@@ -160,7 +160,7 @@ class LLMEvaluator {
     let modelConfiguration = ModelConfiguration.phi4bit

     /// parameters controlling the output
-    let temperature: Float = 0.6
+    let generateParameters = GenerateParameters(temperature: 0.6)
     let maxTokens = 240

     /// update the display every N tokens -- 4 looks like it updates continuously
@@ -201,7 +201,6 @@ class LLMEvaluator {
     }

     func generate(prompt: String) async {
-        let startTime = Date()
         do {
             let (model, tokenizer) = try await load()

@@ -212,59 +211,37 @@
             // augment the prompt as needed
             let prompt = modelConfiguration.prepare(prompt: prompt)
-            let promptTokens = MLXArray(tokenizer.encode(text: prompt))
-
-            var initTime = Date()
-            let initDuration = initTime.timeIntervalSince(startTime)
-            await MainActor.run {
-                self.stat = "Init: \(String(format: "%.3f", initDuration))s"
-            }
+            let promptTokens = tokenizer.encode(text: prompt)

             // each time you generate you will get something new
             MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))

-            var outputTokens = [Int]()
-
-            for token in TokenIterator(prompt: promptTokens, model: model, temp: temperature) {
-                let tokenId = token.item(Int.self)
-
-                // to match the measurement from the command line we reset the start time
-                // after the first token is generated (called the prompt time)
-                if outputTokens.isEmpty {
-                    initTime = Date()
-                }
-
-                if tokenId == tokenizer.unknownTokenId || tokenId == tokenizer.eosTokenId {
-                    break
-                }
-
-                outputTokens.append(tokenId)
-                let text = tokenizer.decode(tokens: outputTokens)
-
+            let result = await LLM.generate(
+                promptTokens: promptTokens, parameters: generateParameters, model: model,
+                tokenizer: tokenizer
+            ) { tokens in
                 // update the output -- this will make the view show the text as it generates
-                if outputTokens.count % displayEveryNTokens == 0 {
+                if tokens.count % displayEveryNTokens == 0 {
+                    let text = tokenizer.decode(tokens: tokens)
                     await MainActor.run {
                         self.output = text
                     }
                 }

-                if outputTokens.count == maxTokens {
-                    break
+                if tokens.count >= maxTokens {
+                    return .stop
+                } else {
+                    return .more
                 }
             }

-            let tokenDuration = Date().timeIntervalSince(initTime)
-            let tokensPerSecond = Double(outputTokens.count) / tokenDuration
-
             // update the text if needed, e.g. we haven't displayed because of displayEveryNTokens
-            let finalText = tokenizer.decode(tokens: outputTokens)
-
             await MainActor.run {
-                if finalText != self.output {
-                    self.output = finalText
+                if result.output != self.output {
+                    self.output = result.output
                 }
                 running = false
-                self.stat += " Tokens/second: \(String(format: "%.3f", tokensPerSecond))"
+                self.stat = " Tokens/second: \(String(format: "%.3f", result.tokensPerSecond))"
             }
         } catch {
diff --git a/Libraries/LLM/Evaluate.swift b/Libraries/LLM/Evaluate.swift
index 9267121..94fdda7 100644
--- a/Libraries/LLM/Evaluate.swift
+++ b/Libraries/LLM/Evaluate.swift
@@ -4,6 +4,7 @@
 import AsyncAlgorithms
 import Foundation
 import MLX
 import MLXRandom
+import Tokenizers

 private func topPSampling(logits: MLXArray, topP: Float, temp: Float) -> MLXArray {
     var logits = logits
@@ -28,8 +29,6 @@ private func topPSampling(logits: MLXArray, topP: Float, temp: Float) -> MLXArra
 private func applyRepetitionPenalty(
     logits: MLXArray, repetitionContext: MLXArray, penalty: Float
 ) -> MLXArray {
-    var logits = logits
-
     if repetitionContext.shape[0] > 0 {
         let indices = repetitionContext
         var selectedLogits = take(logits, indices, axis: -1).squeezed(axis: 0)
@@ -55,37 +54,53 @@ private func sample(logits: MLXArray, temp: Float, topP: Float = 1.0) -> MLXArra
     }
 }

+/// Parameters for text generation, see ``TokenIterator``
+public struct GenerateParameters {
+    /// sampling temperature
+    public var temperature: Float = 0.6
+
+    /// top p sampling
+    public var topP: Float = 0.9
+
+    /// penalty factor for repeating tokens
+    public var repetitionPenalty: Float = 1.0
+
+    /// number of tokens to consider for repetition penalty
+    public var repetitionContextSize: Int = 20
+
+    public init(
+        temperature: Float = 0.6, topP: Float = 0.9, repetitionPenalty: Float = 1.0,
+        repetitionContextSize: Int = 20
+    ) {
+        self.temperature = temperature
+        self.topP = topP
+        self.repetitionPenalty = repetitionPenalty
+        self.repetitionContextSize = repetitionContextSize
+    }
+}
+
 /// Synchronous generator of tokens.
 ///
 /// Port of `generate_step()` from https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py
 public struct TokenIterator: Sequence, IteratorProtocol {
     let model: LLMModel
-    let temp: Float
-    let topP: Float
-    let repetitionPenalty: Float
-    let repetitionContextSize: Int
+    let parameters: GenerateParameters
     var repetitionContext: MLXArray

     var y: MLXArray
     var cache: [(MLXArray, MLXArray)]

     var first = true

-    public init(
-        prompt: MLXArray, model: LLMModel, temp: Float = 0.0, topP: Float = 1.0,
-        repetitionPenalty: Float = 1.0, repetitionContextSize: Int = 20
-    ) {
+    public init(prompt: MLXArray, model: LLMModel, parameters: GenerateParameters) {
         self.model = model
-        self.temp = temp
-        self.topP = topP
+        self.parameters = parameters
         self.y = prompt
         self.cache = []
-        self.repetitionPenalty = repetitionPenalty
-        self.repetitionContextSize = repetitionContextSize
-        if repetitionContextSize > 1 {
-            if prompt.shape[0] <= repetitionContextSize {
+        if parameters.repetitionContextSize > 1 {
+            if prompt.shape[0] <= parameters.repetitionContextSize {
                 self.repetitionContext = prompt
             } else {
-                self.repetitionContext = prompt[-repetitionContextSize ... -1]
+                self.repetitionContext = prompt[-parameters.repetitionContextSize ... -1]
             }
         } else {
             self.repetitionContext = []
@@ -96,16 +111,17 @@ public struct TokenIterator: Sequence, IteratorProtocol {
         var logits: MLXArray
         (logits, cache) = model(expandedDimensions(y, axis: 0), cache: cache.isEmpty ? nil : cache)
         logits = logits[0..., -1, 0...]
-        if repetitionPenalty > 1.0 {
+        if parameters.repetitionPenalty > 1.0 {
             // apply repetition penalty
             logits = applyRepetitionPenalty(
-                logits: logits, repetitionContext: repetitionContext, penalty: repetitionPenalty)
+                logits: logits, repetitionContext: repetitionContext,
+                penalty: parameters.repetitionPenalty)
         }
-        y = sample(logits: logits, temp: temp, topP: topP)
+        y = sample(logits: logits, temp: parameters.temperature, topP: parameters.topP)
         // append the current token to the context and check repetitionPenalty context see if need to remove the first token
-        if repetitionContextSize > 1 {
+        if parameters.repetitionContextSize > 1 {
             repetitionContext = concatenated([repetitionContext, y], axis: 0)
-            if repetitionContext.shape[0] > repetitionContextSize {
+            if repetitionContext.shape[0] > parameters.repetitionContextSize {
                 repetitionContext = repetitionContext[1...]
             }
         }
@@ -114,61 +130,88 @@
     }
 }

-/// Async generator of tokens.
+public struct GenerateResult {
+    /// input tokens
+    public let promptTokens: [Int]
+
+    /// output tokens
+    public let tokens: [Int]
+
+    /// output text
+    public let output: String
+
+    /// time to process the prompt / generate the first token
+    public let promptTime: TimeInterval
+
+    /// time to generate the remaining tokens
+    public let generateTime: TimeInterval
+
+    public var promptTokensPerSecond: Double {
+        Double(promptTokens.count) / promptTime
+    }
+
+    public var tokensPerSecond: Double {
+        Double(tokens.count - 1) / generateTime
+    }
+
+    public func summary() -> String {
+        """
+        Prompt Tokens per second: \(promptTokensPerSecond.formatted())
+        Generation tokens per second: \(tokensPerSecond.formatted())
+        """
+    }
+}
+
+public enum GenerateDisposition {
+    case more
+    case stop
+}
+
+/// Given prompt tokens generate text using the given model and parameters.
 ///
-/// Port of `generate_step()` from https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py.
-///
-/// Note that because MLXArray is not thread safe this eval's the result and sends the TokenId back
-/// to the caller.
+/// - Parameters:
+///   - promptTokens: tokenized prompt
+///   - parameters: generation parameters
+///   - model: model to evaluate
+///   - tokenizer: tokenizer to convert tokens back into strings and recognizer special tokens
+///   - didGenerate: visitor for the tokens as they are generated
 public func generate(
-    prompt: MLXArray, model: LLMModel, temp: Float = 0.0, topP: Float = 1.0,
-    repetitionPenalty: Float = 1.0, repetitionContextSize: Int = 20
-) -> (
-    Task<Void, Never>, AsyncBufferSequence<AsyncChannel<Int>>
-) {
-    let channel = AsyncChannel<Int>()
-    let buffer = channel.buffer(policy: .bounded(10))
+    promptTokens: [Int], parameters: GenerateParameters, model: LLMModel, tokenizer: Tokenizer,
+    didGenerate: ([Int]) async -> GenerateDisposition
+) async -> GenerateResult {
+    var start = Date.timeIntervalSinceReferenceDate
+    var promptTime: TimeInterval = 0

-    let task = Task {
-        var y = prompt
-        var cache = [(MLXArray, MLXArray)]()
-        var repetitionContext: MLXArray
+    var tokens = [Int]()

-        if repetitionContextSize > 1 {
-            if prompt.shape[0] <= repetitionContextSize {
-                repetitionContext = prompt
-            } else {
-                repetitionContext = prompt[-repetitionContextSize ... -1]
-            }
-        } else {
-            repetitionContext = []
+    for token in TokenIterator(
+        prompt: MLXArray(promptTokens), model: model, parameters: parameters)
+    {
+        // compute the timing for the prompt
+        if tokens.isEmpty {
+            eval(token)
+            let now = Date.timeIntervalSinceReferenceDate
+            promptTime = now - start
+            start = now
         }

-        while !Task.isCancelled {
-            var logits: MLXArray
-            (logits, cache) = model(
-                expandedDimensions(y, axis: 0), cache: cache.isEmpty ? nil : cache)
-            logits = logits[0..., -1, 0...]
-            if repetitionPenalty > 1.0 {
-                // apply repetition penalty
-                logits = applyRepetitionPenalty(
-                    logits: logits, repetitionContext: repetitionContext, penalty: repetitionPenalty
-                )
-            }
-            y = sample(logits: logits, temp: temp, topP: topP)
-            // append the current token to the context and check repetitionPenalty context see if need to remove the first token
-            if repetitionContextSize > 1 {
-                repetitionContext = concatenated([repetitionContext, y], axis: 0)
-                if repetitionContext.shape[0] > repetitionContextSize {
-                    repetitionContext = repetitionContext[1...]
-                }
-            }
+        let t = token.item(Int.self)
+        if t == tokenizer.unknownTokenId || t == tokenizer.eosTokenId {
+            break
+        }

-            eval(y)
+        tokens.append(t)

-            await channel.send(y.item(Int.self))
+        if await didGenerate(tokens) == .stop {
+            break
         }
     }

-    return (task, buffer)
+    let now = Date.timeIntervalSinceReferenceDate
+    let generateTime = now - start
+
+    return GenerateResult(
+        promptTokens: promptTokens, tokens: tokens,
+        output: tokenizer.decode(tokens: tokens),
+        promptTime: promptTime, generateTime: generateTime)
 }
diff --git a/Tools/llm-tool/LLMTool.swift b/Tools/llm-tool/LLMTool.swift
index cdbe708..4430489 100644
--- a/Tools/llm-tool/LLMTool.swift
+++ b/Tools/llm-tool/LLMTool.swift
@@ -10,16 +10,27 @@ import Tokenizers

 @main
 struct LLMTool: AsyncParsableCommand {
     static var configuration = CommandConfiguration(
-        abstract: "Command line tool for generating text using Llama models",
-        subcommands: [SyncGenerator.self, AsyncGenerator.self],
-        defaultSubcommand: SyncGenerator.self)
+        abstract: "Command line tool for generating text and manipulating LLMs",
+        subcommands: [EvaluateCommand.self],
+        defaultSubcommand: EvaluateCommand.self)
 }

-struct LLMArguments: ParsableArguments {
+/// Command line arguments for loading a model.
+struct ModelArguments: ParsableArguments {

     @Option(name: .long, help: "Name of the huggingface model")
     var model: String = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"

+    func load() async throws -> (LLMModel, Tokenizer, ModelConfiguration) {
+        let modelConfiguration = ModelConfiguration.configuration(id: model)
+        let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration)
+        return (model, tokenizer, modelConfiguration)
+    }
+}
+
+/// Command line arguments for controlling generation of text.
+struct GenerateArguments: ParsableArguments {
+
     @Option(name: .shortAndLong, help: "The message to be processed by the model")
     var prompt = "compare python and swift"
@@ -29,19 +40,67 @@ struct LLMArguments: ParsableArguments {
     @Option(name: .shortAndLong, help: "The sampling temperature")
     var temperature: Float = 0.6

-    @Option(name: .shortAndLong, help: "The top p sampling")
+    @Option(name: .long, help: "The top p sampling")
     var topP: Float = 0.9

-    @Option(name: .shortAndLong, help: "The penalty factor for repeating tokens")
+    @Option(name: .long, help: "The penalty factor for repeating tokens")
     var repetitionPenalty: Float = 1.0

-    @Option(name: .shortAndLong, help: "The number of tokens to consider for repetition penalty")
+    @Option(name: .long, help: "The number of tokens to consider for repetition penalty")
     var repetitionContextSize: Int = 20

     @Option(name: .long, help: "The PRNG seed")
     var seed: UInt64 = 0

-    @Flag(help: "Show memory stats")
+    var generateParameters: GenerateParameters {
+        GenerateParameters(
+            temperature: temperature, topP: topP, repetitionPenalty: repetitionPenalty,
+            repetitionContextSize: repetitionContextSize)
+    }
+
+    func tokenizePrompt(configuration: ModelConfiguration, tokenizer: Tokenizer) -> (String, [Int])
+    {
+        MLXRandom.seed(seed)
+
+        let prompt = configuration.prepare(prompt: self.prompt)
+        let promptTokens = tokenizer.encode(text: prompt)
+
+        return (prompt, promptTokens)
+    }
+
+    func generate(promptTokens: [Int], model: LLMModel, tokenizer: Tokenizer) async
+        -> GenerateResult
+    {
+        // track how much we have printed
+        var printed = 0
+
+        return await LLM.generate(
+            promptTokens: promptTokens, parameters: generateParameters,
+            model: model, tokenizer: tokenizer
+        ) { tokens in
+
+            // print any new parts of the string
+            let fullOutput = tokenizer.decode(tokens: tokens)
+            let emitLength = fullOutput.count - printed
+            let suffix = fullOutput.suffix(emitLength)
+            print(suffix, terminator: "")
+            fflush(stdout)
+
+            printed = fullOutput.count
+
+            if tokens.count >= maxTokens {
+                return .stop
+            } else {
+                return .more
+            }
+        }
+    }
+}
+
+/// Argument package for adjusting and reporting memory use.
+struct MemoryArguments: ParsableArguments {
+
+    @Flag(name: .long, help: "Show memory stats")
     var memoryStats = false

     @Option(name: .long, help: "Maximum cache size in M")
     var cacheSize = 1024
@@ -52,9 +111,7 @@ struct LLMArguments: ParsableArguments {

     var startMemory: GPU.Snapshot?

-    mutating func load() async throws -> (LLMModel, Tokenizer, ModelConfiguration) {
-        MLXRandom.seed(seed)
-
+    mutating func start<L>(_ load: () async throws -> L) async throws -> L {
         if let cacheSize {
             GPU.set(cacheLimit: cacheSize * 1024 * 1024)
         }
@@ -63,20 +120,29 @@
             GPU.set(memoryLimit: memorySize * 1024 * 1024)
         }

-        let modelConfiguration = ModelConfiguration.configuration(id: model)
-        let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration)
-
+        let result = try await load()
         startMemory = GPU.snapshot()

-        return (model, tokenizer, modelConfiguration)
+        return result
     }

-    func tokenizePropmpt(configuration: ModelConfiguration, tokenizer: Tokenizer) -> (String, [Int])
-    {
-        let prompt = configuration.prepare(prompt: self.prompt)
-        let promptTokens = tokenizer.encode(text: prompt)
+    mutating func start() {
+        if let cacheSize {
+            GPU.set(cacheLimit: cacheSize * 1024 * 1024)
+        }

-        return (prompt, promptTokens)
+        if let memorySize {
+            GPU.set(memoryLimit: memorySize * 1024 * 1024)
+        }
+
+        startMemory = GPU.snapshot()
+    }
+
+    func reportCurrent() {
+        if memoryStats {
+            let memory = GPU.snapshot()
+            print(memory.description)
+        }
     }

     func reportMemoryStatistics() {
@@ -106,164 +172,36 @@
     }
 }

-struct SyncGenerator: AsyncParsableCommand {
+struct EvaluateCommand: AsyncParsableCommand {

     static var configuration = CommandConfiguration(
-        commandName: "sync",
-        abstract: "Synchronous generator"
+        commandName: "eval",
+        abstract: "evaluate prompt and generate text"
     )

-    @OptionGroup var args: LLMArguments
+    @OptionGroup var args: ModelArguments
+    @OptionGroup var memory: MemoryArguments
+    @OptionGroup var generate: GenerateArguments

     @MainActor
     mutating func run() async throws {
-        let (model, tokenizer, modelConfiguration) = try await args.load()
+        let (model, tokenizer, modelConfiguration) = try await memory.start(args.load)

         print("Model loaded -> \(modelConfiguration.id)")

-        let (prompt, promptTokens) = args.tokenizePropmpt(
+        let (prompt, promptTokens) = generate.tokenizePrompt(
             configuration: modelConfiguration, tokenizer: tokenizer)

         print("Starting generation ...")
         print(prompt, terminator: "")

-        var start = Date.timeIntervalSinceReferenceDate
-        var promptTime: TimeInterval = 0
-
-        // collect the tokens and keep track of how much of the string
-        // we have printed already
-        var tokens = [Int]()
-        var printed = 0
-
-        for token in TokenIterator(
-            prompt: MLXArray(promptTokens), model: model, temp: args.temperature, topP: args.topP,
-            repetitionPenalty: args.repetitionPenalty,
-            repetitionContextSize: args.repetitionContextSize)
-        {
-            if tokens.isEmpty {
-                eval(token)
-                let now = Date.timeIntervalSinceReferenceDate
-                promptTime = now - start
-                start = now
-            }
-
-            let t = token.item(Int.self)
-            if t == tokenizer.unknownTokenId || t == tokenizer.eosTokenId {
-                break
-            }
-            tokens.append(t)
-
-            // print any new parts of the string
-            let fullOutput = tokenizer.decode(tokens: tokens)
-            let emitLength = fullOutput.count - printed
-            let suffix = fullOutput.suffix(emitLength)
-            print(suffix, terminator: "")
-            fflush(stdout)
-
-            printed = fullOutput.count
-
-            if tokens.count == args.maxTokens {
-                break
-            }
-        }
+        let result = await generate.generate(
+            promptTokens: promptTokens, model: model, tokenizer: tokenizer)

         print()
         print("------")
-        let now = Date.timeIntervalSinceReferenceDate
-        let generateTime = now - start
+        print(result.summary())

-        print(
-            """
-            Prompt Tokens per second: \((Double(promptTokens.count) / promptTime).formatted())
-            Generation tokens per second: \((Double(tokens.count - 1) / generateTime).formatted())
-            """)
-
-        args.reportMemoryStatistics()
-    }
-}
-
-/// Example of an async generator.
-///
-/// Note that all of the computation is done on another thread and TokenId (Int32) are sent
-/// rather than MLXArray.
-struct AsyncGenerator: AsyncParsableCommand {
-
-    static var configuration = CommandConfiguration(
-        commandName: "async",
-        abstract: "async generator"
-    )
-
-    @OptionGroup var args: LLMArguments
-
-    @MainActor
-    mutating func run() async throws {
-        let (model, tokenizer, modelConfiguration) = try await args.load()
-
-        print("Model loaded -> \(modelConfiguration.id)")
-
-        let (prompt, promptTokens) = args.tokenizePropmpt(
-            configuration: modelConfiguration, tokenizer: tokenizer)
-
-        print("Starting generation ...")
-        print(prompt, terminator: "")
-
-        var start = Date.timeIntervalSinceReferenceDate
-        var promptTime: TimeInterval = 0
-
-        // collect the tokens and keep track of how much of the string
-        // we have printed already
-        var tokens = [Int]()
-        var printed = 0
-
-        let (task, channel) = generate(
-            prompt: MLXArray(promptTokens), model: model, temp: args.temperature, topP: args.topP,
-            repetitionPenalty: args.repetitionPenalty,
-            repetitionContextSize: args.repetitionContextSize)
-
-        for await token in channel {
-            if tokens.isEmpty {
-                let now = Date.timeIntervalSinceReferenceDate
-                promptTime = now - start
-                start = now
-            }
-
-            if token == tokenizer.unknownTokenId || token == tokenizer.eosTokenId {
-                break
-            }
-            tokens.append(token)
-
-            // print any new parts of the string
-            let fullOutput = tokenizer.decode(tokens: tokens)
-            let emitLength = fullOutput.count - printed
-            let suffix = fullOutput.suffix(emitLength)
-            print(suffix, terminator: "")
-            fflush(stdout)
-
-            printed = fullOutput.count
-
-            if tokens.count == args.maxTokens {
-                break
-            }
-        }
-
-        // tell the task to stop
-        task.cancel()
-
-        print()
-        print("------")
-        let now = Date.timeIntervalSinceReferenceDate
-        let generateTime = now - start
-
-        print(
-            """
-            Prompt Tokens per second: \((Double(promptTokens.count) / promptTime).formatted())
-            Generation tokens per second: \((Double(tokens.count - 1) / generateTime).formatted())
-            """)
-
-        args.reportMemoryStatistics()
-
-        // wait for the task to complete -- since it is running async, it might
-        // be in the middle of running the model
-        try? await Task.sleep(for: .milliseconds(500))
+        memory.reportMemoryStatistics()
     }
 }
diff --git a/mlx-swift-examples.xcodeproj/project.pbxproj b/mlx-swift-examples.xcodeproj/project.pbxproj
index 27e0a60..02bac53 100644
--- a/mlx-swift-examples.xcodeproj/project.pbxproj
+++ b/mlx-swift-examples.xcodeproj/project.pbxproj
@@ -42,11 +42,9 @@
 		C3932D592B6A0BE400A81055 /* Random.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3932D582B6A0BE400A81055 /* Random.swift */; };
 		C397C59C2B62C6D0004B084D /* ArgumentParser in Frameworks */ = {isa = PBXBuildFile; productRef = C397C59B2B62C6D0004B084D /* ArgumentParser */; };
 		C3A8B3AC2B9283150002EFB8 /* Models.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3A8B3AB2B9283150002EFB8 /* Models.swift */; };
-		C3A8B3CA2B92951E0002EFB8 /* MNISTTrainer-Info.plist in Resources */ = {isa = PBXBuildFile; fileRef = C3A8B3C22B92951E0002EFB8 /* MNISTTrainer-Info.plist */; };
 		C3A8B3CB2B92951E0002EFB8 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C3A8B3C32B92951E0002EFB8 /* Assets.xcassets */; };
 		C3A8B3CC2B92951E0002EFB8 /* MNISTTrainerApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3A8B3C42B92951E0002EFB8 /* MNISTTrainerApp.swift */; };
 		C3A8B3CD2B92951E0002EFB8 /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C3A8B3C62B92951E0002EFB8 /* Preview Assets.xcassets */; };
-		C3A8B3CE2B92951E0002EFB8 /* README.md in Resources */ = {isa = PBXBuildFile; fileRef = C3A8B3C82B92951E0002EFB8 /* README.md */; };
 		C3A8B3CF2B92951E0002EFB8 /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3A8B3C92B92951E0002EFB8 /* ContentView.swift */; };
 		C3A8B3D22B92A0880002EFB8 /* MLXOptimizers in Frameworks */ = {isa = PBXBuildFile; productRef = C3A8B3D12B92A0880002EFB8 /* MLXOptimizers */; };
 		C3A8B3D32B92A0880002EFB8 /* MNIST.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C34E490D2B69A92900FCB841 /* MNIST.framework */; };
@@ -54,7 +52,6 @@
 		C3A8B3F32B92A2A90002EFB8 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C3A8B3EC2B92A2A90002EFB8 /* Assets.xcassets */; };
 		C3A8B3F42B92A2A90002EFB8 /* LLMEvalApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3A8B3ED2B92A2A90002EFB8 /* LLMEvalApp.swift */; };
 		C3A8B3F52B92A2A90002EFB8 /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C3A8B3EF2B92A2A90002EFB8 /* Preview Assets.xcassets */; };
-		C3A8B3F62B92A2A90002EFB8 /* README.md in Resources */ = {isa = PBXBuildFile; fileRef = C3A8B3F02B92A2A90002EFB8 /* README.md */; };
 		C3A8B3F72B92A2A90002EFB8 /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3A8B3F22B92A2A90002EFB8 /* ContentView.swift */; };
 		C3A8B3F82B92A3360002EFB8 /* LLM.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C38935C52B869C7A0037B833 /* LLM.framework */; };
 		C3A8B3F92B92A3360002EFB8 /* LLM.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = C38935C52B869C7A0037B833 /* LLM.framework */; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; };
@@ -801,8 +798,6 @@
 			isa = PBXResourcesBuildPhase;
 			buildActionMask = 2147483647;
 			files = (
-				C3A8B3CE2B92951E0002EFB8 /* README.md in Resources */,
-				C3A8B3CA2B92951E0002EFB8 /* MNISTTrainer-Info.plist in Resources */,
 				C3A8B3CD2B92951E0002EFB8 /* Preview Assets.xcassets in Resources */,
 				C3A8B3CB2B92951E0002EFB8 /* Assets.xcassets in Resources */,
 			);
@@ -813,7 +808,6 @@
 			buildActionMask = 2147483647;
 			files = (
 				C3A8B3F52B92A2A90002EFB8 /* Preview Assets.xcassets in Resources */,
-				C3A8B3F62B92A2A90002EFB8 /* README.md in Resources */,
 				C3A8B3F32B92A2A90002EFB8 /* Assets.xcassets in Resources */,
 			);
 			runOnlyForDeploymentPostprocessing = 0;
diff --git a/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
index 5b241b4..c129691 100644
--- a/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
+++ b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
@@ -1,4 +1,5 @@
 {
+  "originHash" : "da53546673b6d05016b6e5640c18814c7dba5b5af8db34715afe6d633037c758",
   "pins" : [
     {
       "identity" : "gzipswift",
@@ -15,7 +16,7 @@
       "location" : "https://github.com/ml-explore/mlx-swift",
       "state" : {
         "branch" : "main",
-        "revision" : "a1c544c817d44cfdfa1a650f521066b565c2ae4f"
+        "revision" : "b4d3e4bbbe41e6dc7c46d5ba075049ae7177961b"
       }
     },
     {
@@ -82,5 +83,5 @@
       }
     }
   ],
-  "version" : 2
+  "version" : 3
}
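
Usage sketch (illustrative, not part of the patch): the diff above replaces the channel-based async generator with a synchronous, callback-driven API. Assuming the symbols introduced or used in this patch (`ModelConfiguration.phi4bit`, `LLM.load`, `GenerateParameters`, `LLM.generate`, `GenerateResult`), a caller might drive it roughly as follows; the `import LLM` line, the wrapper function, the model configuration, and the prompt text are placeholder assumptions, not part of the change itself.

import LLM

// Minimal sketch of the new callback-based generation flow.
func runGenerationExample() async throws {
    // load a model and tokenizer (configuration chosen for illustration only)
    let configuration = ModelConfiguration.phi4bit
    let (model, tokenizer) = try await LLM.load(configuration: configuration)

    // prepare and tokenize an illustrative prompt, as EvaluateCommand does
    let prompt = configuration.prepare(prompt: "compare python and swift")
    let promptTokens = tokenizer.encode(text: prompt)

    // generation parameters introduced in Evaluate.swift
    let parameters = GenerateParameters(temperature: 0.6, topP: 0.9)
    let maxTokens = 240

    // didGenerate receives all tokens produced so far; returning .stop ends generation
    let result = await LLM.generate(
        promptTokens: promptTokens, parameters: parameters, model: model, tokenizer: tokenizer
    ) { tokens in
        tokens.count >= maxTokens ? .stop : .more
    }

    print(result.output)
    print(result.summary())
}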