// Copyright © 2024 Apple Inc. import ArgumentParser import Foundation import MLX import MLXNN import MLXOptimizers import MLXRandom import MNIST @main struct MNISTTool: AsyncParsableCommand { static var configuration = CommandConfiguration( abstract: "Command line tool for training mnist models", subcommands: [Train.self], defaultSubcommand: Train.self) } extension MLX.DeviceType: ExpressibleByArgument { public init?(argument: String) { self.init(rawValue: argument) } } struct Train: AsyncParsableCommand { @Option(name: .long, help: "Directory with the training data") var data: String @Option(name: .long, help: "The PRNG seed") var seed: UInt64 = 0 @Option var layers = 2 @Option var hidden = 32 @Option var batchSize = 256 @Option var epochs = 20 @Option var learningRate: Float = 1e-1 @Option var classes = 10 @Option var device = DeviceType.cpu @Flag var compile = false func run() async throws { Device.setDefault(device: Device(device)) MLXRandom.seed(seed) var generator: RandomNumberGenerator = SplitMix64(seed: seed) // load the data let url = URL(filePath: data) try FileManager.default.createDirectory(at: url, withIntermediateDirectories: true) try await download(into: url) let data = try load(from: url) let trainImages = data[.init(.training, .images)]! let trainLabels = data[.init(.training, .labels)]! let testImages = data[.init(.test, .images)]! let testLabels = data[.init(.test, .labels)]! // create the model let model = MLP( layers: layers, inputDimensions: trainImages.dim(-1), hiddenDimensions: hidden, outputDimensions: classes) eval(model.parameters()) let lg = valueAndGrad(model: model, loss) let optimizer = SGD(learningRate: learningRate) func step(_ x: MLXArray, _ y: MLXArray) -> MLXArray { let (loss, grads) = lg(model, x, y) optimizer.update(model: model, gradients: grads) return loss } let resolvedStep = compile ? MLX.compile(inputs: [model, optimizer], outputs: [model, optimizer], step) : step for e in 0 ..< epochs { let start = Date.timeIntervalSinceReferenceDate for (x, y) in iterateBatches( batchSize: batchSize, x: trainImages, y: trainLabels, using: &generator) { _ = resolvedStep(x, y) // eval the parameters so the next iteration is independent eval(model, optimizer) } let accuracy = eval(model: model, x: testImages, y: testLabels) let end = Date.timeIntervalSinceReferenceDate print( """ Epoch \(e): test accuracy \(accuracy.item(Float.self).formatted()) Time: \((end - start).formatted()) """ ) } } }