Files
mlx-swift-examples/Applications/MNISTTrainer/ContentView.swift
Rounak 9e18eaa479 Add MNIST Digit Prediction/Inference (#22)
* Add Prediction to MNISTTrainer
2024-03-18 19:18:41 -07:00

135 lines
3.6 KiB
Swift

// Copyright © 2024 Apple Inc.
import MLX
import MLXNN
import MLXOptimizers
import MLXRandom
import MNIST
import SwiftUI
/// Shows the trainer's message log together with the button that advances the
/// trainer's state: "Train" while untrained, "Draw a digit" once a model exists.
struct TrainingView: View {
    /// The trainer whose `messages` and `state` this view renders and mutates.
    @Binding var trainer: Trainer

    /// `true` while a training task started from this view is still running.
    /// Left at the default access level so `TrainingView(trainer:)` remains
    /// callable exactly as before.
    @State var isTraining = false

    var body: some View {
        VStack {
            Spacer()

            ScrollView(.vertical) {
                // Rows are identified by position rather than by text, so two
                // identical messages (e.g. a repeated failure) cannot collide.
                ForEach(Array(trainer.messages.enumerated()), id: \.offset) { entry in
                    Text(entry.element)
                }
            }

            HStack {
                Spacer()

                switch trainer.state {
                case .untrained:
                    Button("Train") {
                        isTraining = true
                        Task {
                            // Re-enable the button however the run ends.
                            defer { isTraining = false }
                            do {
                                try await trainer.run()
                            } catch {
                                // Report the failure in the log instead of crashing;
                                // the state stays `.untrained`, so the user can retry.
                                trainer.messages.append(
                                    "Training failed: \(error.localizedDescription)")
                            }
                        }
                    }
                    // `state` stays `.untrained` for the whole run, so without this a
                    // second tap would start another run on the same trainer.
                    .disabled(isTraining)

                case .trained(let model), .predict(let model):
                    Button("Draw a digit") {
                        trainer.state = .predict(model)
                    }
                }

                Spacer()
            }

            Spacer()
        }
        .padding()
    }
}
/// Root view: picks the screen to show from the trainer's current state.
struct ContentView: View {
    // Owns the trainer (and with it the training loop); the training screen
    // receives a binding to it.
    @State var trainer = Trainer()

    var body: some View {
        switch trainer.state {
        case .predict(let network):
            // A model has been handed over for prediction.
            PredictionView(model: network)
        case .untrained, .trained:
            // Before and right after training, stay on the training screen.
            TrainingView(trainer: $trainer)
        }
    }
}
/// Observable wrapper around the MNIST training loop.
///
/// `run()` downloads the data, trains a `LeNet` and appends one progress message
/// per epoch to `messages`; when it finishes, `state` becomes `.trained`.
@Observable
class Trainer {

    /// Where the trainer is in its lifecycle.
    enum State {
        /// No model has been trained yet.
        case untrained
        /// Training finished; carries the resulting model.
        case trained(LeNet)
        /// The given model has been handed over for prediction.
        case predict(LeNet)
    }

    /// Errors thrown by `run()` before the training loop starts.
    enum DataError: LocalizedError {
        /// The loaded data did not contain all four image/label arrays.
        case missingSplit

        var errorDescription: String? {
            "The MNIST data is incomplete: training and test images and labels are all required."
        }
    }

    /// Current lifecycle state; updated on the main actor by `run()`.
    var state: State = .untrained

    /// Progress log, one entry per finished epoch; appended on the main actor by `run()`.
    var messages = [String]()

    /// Downloads the data into the temporary directory, trains a fresh `LeNet`
    /// for 10 epochs and publishes per-epoch results.
    ///
    /// - Throws: whatever `download(into:)` or `load(from:)` throw, or
    ///   `DataError.missingSplit` when the loaded data lacks one of the
    ///   training/test image or label arrays.
    func run() async throws {
        // Note: this is pretty close to the code in `mnist-tool`, just
        // wrapped in an Observable to make it easy to display in SwiftUI

        // download & load the training data
        let url = URL(fileURLWithPath: NSTemporaryDirectory(), isDirectory: true)
        try await download(into: url)
        let data = try load(from: url)

        // Fail with a descriptive error instead of trapping on a missing entry.
        guard
            let trainImages = data[.init(.training, .images)],
            let trainLabels = data[.init(.training, .labels)],
            let testImages = data[.init(.test, .images)],
            let testLabels = data[.init(.test, .labels)]
        else {
            throw DataError.missingSplit
        }

        // create the model with random weights
        let model = LeNet()
        eval(model.parameters())

        // the training loop
        let lg = valueAndGrad(model: model, loss)
        let optimizer = SGD(learningRate: 0.1)

        // using a consistent random seed so it behaves the same way each time
        MLXRandom.seed(0)
        var generator: RandomNumberGenerator = SplitMix64(seed: 0)

        for e in 0 ..< 10 {
            let start = Date.timeIntervalSinceReferenceDate

            for (x, y) in iterateBatches(
                batchSize: 256, x: trainImages, y: trainLabels, using: &generator)
            {
                // loss and gradients
                let (_, grads) = lg(model, x, y)

                // use SGD to update the weights
                optimizer.update(model: model, gradients: grads)

                // eval the parameters so the next iteration is independent
                eval(model, optimizer)
            }

            let accuracy = eval(model: model, x: testImages, y: testLabels)
            let end = Date.timeIntervalSinceReferenceDate

            // Build the message before hopping to the main actor. MLX arrays are
            // evaluated lazily, so reading the scalar with `item` is presumably
            // where the test-set pass actually executes; doing it here keeps that
            // work off the main actor, which then only appends a string.
            let message = """
                Epoch \(e): test accuracy \(accuracy.item(Float.self).formatted())
                Time: \((end - start).formatted())
                """

            // add to messages -- triggers display
            await MainActor.run {
                messages.append(message)
            }
        }

        await MainActor.run {
            state = .trained(model)
        }
    }
}