initial commit
This commit is contained in:
108
Tools/mnist-tool/MNISTTool.swift
Normal file
108
Tools/mnist-tool/MNISTTool.swift
Normal file
@@ -0,0 +1,108 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
import ArgumentParser
|
||||
import Foundation
|
||||
import MLX
|
||||
import MLXNN
|
||||
import MLXOptimizers
|
||||
import MLXRandom
|
||||
import MNIST
|
||||
|
||||
/// Entry point for the `mnist-tool` command line executable.
///
/// The tool exposes a single subcommand, `Train`, which is also registered as
/// the default subcommand, so running the executable without naming a
/// subcommand starts training.
@main
struct MNISTTool: AsyncParsableCommand {
    // Declared with `let`: the configuration is never mutated after creation,
    // and an immutable static avoids exposing shared mutable global state.
    static let configuration = CommandConfiguration(
        abstract: "Command line tool for training mnist models",
        subcommands: [Train.self],
        defaultSubcommand: Train.self)
}
|
||||
|
||||
/// Allows `MLX.DeviceType` to be used as the value of a command line option.
extension MLX.DeviceType: ExpressibleByArgument {
    /// Builds a device type from the text supplied on the command line.
    ///
    /// - Parameter argument: The raw-value spelling of the device type.
    ///   Initialization fails when it does not match any raw value.
    public init?(argument: String) {
        guard let parsed = Self(rawValue: argument) else {
            return nil
        }
        self = parsed
    }
}
|
||||
|
||||
/// Subcommand that downloads MNIST, trains an `MLP` on the training split with
/// SGD and prints the accuracy on the test split after every epoch.
struct Train: AsyncParsableCommand {

    /// Failures raised while preparing a training run.
    private enum TrainError: LocalizedError, CustomStringConvertible {
        /// One or more of the four expected image/label arrays was not loaded.
        case incompleteDataset(directory: String)

        var errorDescription: String? {
            switch self {
            case .incompleteDataset(let directory):
                return
                    "The MNIST data loaded from '\(directory)' is missing "
                    + "training/test images or labels"
            }
        }

        var description: String { errorDescription ?? "Training failed" }
    }

    @Option(name: .long, help: "Directory with the training data")
    var data: String

    @Option(name: .long, help: "The PRNG seed")
    var seed: UInt64 = 0

    @Option(help: "Number of layers in the MLP")
    var layers = 2

    @Option(help: "Hidden dimensions of the MLP")
    var hidden = 32

    @Option(help: "Number of samples per training batch")
    var batchSize = 256

    @Option(help: "Number of passes over the training data")
    var epochs = 20

    @Option(help: "Learning rate for the SGD optimizer")
    var learningRate: Float = 1e-1

    @Option(help: "Output dimensions of the MLP (number of classes)")
    var classes = 10

    @Option(help: "Default device to run on")
    var device = DeviceType.cpu

    @Flag(help: "Compile the training step with MLX.compile")
    var compile = false

    /// Rejects argument values that cannot produce a valid training run.
    ///
    /// - Throws: `ValidationError` when `batchSize` is not positive or
    ///   `epochs` is negative.
    func validate() throws {
        guard batchSize > 0 else {
            throw ValidationError("--batch-size must be greater than zero")
        }
        // `0 ..< epochs` traps at runtime when the upper bound is negative.
        guard epochs >= 0 else {
            throw ValidationError("--epochs must not be negative")
        }
    }

    /// Downloads and loads the data, builds the model, and runs the training loop.
    ///
    /// - Throws: Any error raised while creating the data directory, downloading
    ///   or loading the data, or `TrainError.incompleteDataset` when the loaded
    ///   data does not contain all four expected arrays.
    func run() async throws {
        Device.setDefault(device: Device(device))

        // Seed both the MLX random state and the generator used for batching
        // from the same user-supplied value.
        MLXRandom.seed(seed)
        var generator: RandomNumberGenerator = SplitMix64(seed: seed)

        // load the data
        let url = URL(filePath: data)

        try FileManager.default.createDirectory(at: url, withIntermediateDirectories: true)
        try await download(into: url)

        // Named `dataset` so it does not shadow the `data` option read above.
        let dataset = try load(from: url)

        // Fail with a readable error instead of crashing on a missing entry.
        guard
            let trainImages = dataset[.init(.training, .images)],
            let trainLabels = dataset[.init(.training, .labels)],
            let testImages = dataset[.init(.test, .images)],
            let testLabels = dataset[.init(.test, .labels)]
        else {
            throw TrainError.incompleteDataset(directory: data)
        }

        // create the model
        let model = MLP(
            layers: layers, inputDimensions: trainImages.dim(-1), hiddenDimensions: hidden,
            outputDimensions: classes)
        eval(model.parameters())

        let lg = valueAndGrad(model: model, loss)
        let optimizer = SGD(learningRate: learningRate)

        // One optimization step: compute loss and gradients for the batch, then
        // let the optimizer update the model in place.
        func step(_ x: MLXArray, _ y: MLXArray) -> MLXArray {
            let (lossValue, grads) = lg(model, x, y)
            optimizer.update(model: model, gradients: grads)
            return lossValue
        }

        // The model and optimizer are passed as both inputs and outputs of the
        // compiled function because `step` reads and updates them.
        let resolvedStep =
            compile
            ? MLX.compile(inputs: [model, optimizer], outputs: [model, optimizer], step) : step

        for e in 0 ..< epochs {
            let start = Date.timeIntervalSinceReferenceDate

            for (x, y) in iterateBatches(
                batchSize: batchSize, x: trainImages, y: trainLabels, using: &generator)
            {
                _ = resolvedStep(x, y)

                // eval the parameters so the next iteration is independent
                eval(model, optimizer)
            }

            let accuracy = eval(model: model, x: testImages, y: testLabels)

            let end = Date.timeIntervalSinceReferenceDate

            print(
                """
                Epoch \(e): test accuracy \(accuracy.item(Float.self).formatted())
                Time: \((end - start).formatted())

                """
            )
        }
    }
}
|
||||
36
Tools/mnist-tool/README.md
Normal file
36
Tools/mnist-tool/README.md
Normal file
@@ -0,0 +1,36 @@
|
||||
# mnist-tool
|
||||
|
||||
See other README:
|
||||
|
||||
- [MNIST](../../Libraries/MNIST/README.md)
|
||||
|
||||
### Building
|
||||
|
||||
`mnist-tool` has no dependencies outside of the package dependencies
|
||||
represented in Xcode.
|
||||
|
||||
When you run the tool it will download the test/train datasets and
|
||||
store them in the directory given by the `--data` argument (required -- there is no default).
|
||||
|
||||
Simply build the project in Xcode.
|
||||
|
||||
### Running (Xcode)
|
||||
|
||||
To run this in Xcode simply press cmd-opt-r to set the scheme arguments. For example:
|
||||
|
||||
```
|
||||
--data /tmp
|
||||
```
|
||||
|
||||
Then cmd-r to run.
|
||||
|
||||
### Running (Command Line)
|
||||
|
||||
`mnist-tool` can also be run from the command line if built from Xcode, but
|
||||
the `DYLD_FRAMEWORK_PATH` must be set so that the frameworks and bundles can be found:
|
||||
|
||||
- [MLX troubleshooting](https://ml-explore.github.io/mlx-swift/MLX/documentation/mlx/troubleshooting)
|
||||
|
||||
```
|
||||
DYLD_FRAMEWORK_PATH=~/Library/Developer/Xcode/DerivedData/mlx-examples-swift-ceuohnhzsownvsbbleukxoksddja/Build/Products/Debug ~/Library/Developer/Xcode/DerivedData/mlx-examples-swift-ceuohnhzsownvsbbleukxoksddja/Build/Products/Debug/mnist-tool --data /tmp
|
||||
```
|
||||
Reference in New Issue
Block a user