Add MNIST Digit Prediction/Inference (#22)

* Add Prediction to MNISTTrainer
This commit is contained in:
Rounak
2024-03-18 19:18:41 -07:00
committed by GitHub
parent 0588abec77
commit 9e18eaa479
4 changed files with 168 additions and 9 deletions

View File

@@ -7,10 +7,9 @@ import MLXRandom
import MNIST
import SwiftUI
struct ContentView: View {
struct TrainingView: View {
// the training loop
@State var trainer = Trainer()
@Binding var trainer: Trainer
var body: some View {
VStack {
@@ -24,10 +23,16 @@ struct ContentView: View {
HStack {
Spacer()
Button("Train") {
Task {
try! await trainer.run()
switch trainer.state {
case .untrained:
Button("Train") {
Task {
try! await trainer.run()
}
}
case .trained(let model), .predict(let model):
Button("Draw a digit") {
trainer.state = .predict(model)
}
}
@@ -39,9 +44,30 @@ struct ContentView: View {
}
}
struct ContentView: View {
// the training loop
@State var trainer = Trainer()
var body: some View {
switch trainer.state {
case .untrained, .trained:
TrainingView(trainer: $trainer)
case .predict(let model):
PredictionView(model: model)
}
}
}
@Observable
class Trainer {
enum State {
case untrained
case trained(LeNet)
case predict(LeNet)
}
var state: State = .untrained
var messages = [String]()
func run() async throws {
@@ -101,6 +127,8 @@ class Trainer {
)
}
}
await MainActor.run {
state = .trained(model)
}
}
}