initial commit

This commit is contained in:
David Koski
2024-02-22 10:41:02 -08:00
commit b6d1e14465
29 changed files with 3856 additions and 0 deletions

View File

@@ -0,0 +1,108 @@
// Copyright © 2024 Apple Inc.
import ArgumentParser
import Foundation
import MLX
import MLXNN
import MLXOptimizers
import MLXRandom
import MNIST
@main
struct MNISTTool: AsyncParsableCommand {
static var configuration = CommandConfiguration(
abstract: "Command line tool for training mnist models",
subcommands: [Train.self],
defaultSubcommand: Train.self)
}
extension MLX.DeviceType: ExpressibleByArgument {
public init?(argument: String) {
self.init(rawValue: argument)
}
}
struct Train: AsyncParsableCommand {
@Option(name: .long, help: "Directory with the training data")
var data: String
@Option(name: .long, help: "The PRNG seed")
var seed: UInt64 = 0
@Option var layers = 2
@Option var hidden = 32
@Option var batchSize = 256
@Option var epochs = 20
@Option var learningRate: Float = 1e-1
@Option var classes = 10
@Option var device = DeviceType.cpu
@Flag var compile = false
func run() async throws {
Device.setDefault(device: Device(device))
MLXRandom.seed(seed)
var generator: RandomNumberGenerator = SplitMix64(seed: seed)
// load the data
let url = URL(filePath: data)
try FileManager.default.createDirectory(at: url, withIntermediateDirectories: true)
try await download(into: url)
let data = try load(from: url)
let trainImages = data[.init(.training, .images)]!
let trainLabels = data[.init(.training, .labels)]!
let testImages = data[.init(.test, .images)]!
let testLabels = data[.init(.test, .labels)]!
// create the model
let model = MLP(
layers: layers, inputDimensions: trainImages.dim(-1), hiddenDimensions: hidden,
outputDimensions: classes)
eval(model.parameters())
let lg = valueAndGrad(model: model, loss)
let optimizer = SGD(learningRate: learningRate)
func step(_ x: MLXArray, _ y: MLXArray) -> MLXArray {
let (loss, grads) = lg(model, x, y)
optimizer.update(model: model, gradients: grads)
return loss
}
let resolvedStep =
compile
? MLX.compile(inputs: [model, optimizer], outputs: [model, optimizer], step) : step
for e in 0 ..< epochs {
let start = Date.timeIntervalSinceReferenceDate
for (x, y) in iterateBatches(
batchSize: batchSize, x: trainImages, y: trainLabels, using: &generator)
{
_ = resolvedStep(x, y)
// eval the parameters so the next iteration is independent
eval(model, optimizer)
}
let accuracy = eval(model: model, x: testImages, y: testLabels)
let end = Date.timeIntervalSinceReferenceDate
print(
"""
Epoch \(e): test accuracy \(accuracy.item(Float.self).formatted())
Time: \((end - start).formatted())
"""
)
}
}
}

View File

@@ -0,0 +1,36 @@
# mnist-tool
See other README:
- [MNIST](../../Libraries/MNIST/README.md)
### Building
`mnist-tool` has no dependencies outside of the package dependencies
represented in xcode.
When you run the tool it will download the test/train datasets and
store them in a specified directory (see run arguments -- default is /tmp).
Simply build the project in xcode.
### Running (Xcode)
To run this in Xcode simply press cmd-opt-r to set the scheme arguments. For example:
```
--data /tmp
```
Then cmd-r to run.
### Running (CommandLine)
`mnist-tool` can also be run from the command line if built from Xcode, but
the `DYLD_FRAMEWORK_PATH` must be set so that the frameworks and bundles can be found:
- [MLX troubleshooting](https://ml-explore.github.io/mlx-swift/MLX/documentation/mlx/troubleshooting)
```
DYLD_FRAMEWORK_PATH=~/Library/Developer/Xcode/DerivedData/mlx-examples-swift-ceuohnhzsownvsbbleukxoksddja/Build/Products/Debug ~/Library/Developer/Xcode/DerivedData/mlx-examples-swift-ceuohnhzsownvsbbleukxoksddja/Build/Products/Debug/mnist-tool --data /tmp
```