Add MNIST Digit Prediction/Inference (#22)
* Add Prediction to MNISTTrainer
This commit is contained in:
@@ -7,10 +7,9 @@ import MLXRandom
|
||||
import MNIST
|
||||
import SwiftUI
|
||||
|
||||
struct ContentView: View {
|
||||
struct TrainingView: View {
|
||||
|
||||
// the training loop
|
||||
@State var trainer = Trainer()
|
||||
@Binding var trainer: Trainer
|
||||
|
||||
var body: some View {
|
||||
VStack {
|
||||
@@ -24,10 +23,16 @@ struct ContentView: View {
|
||||
|
||||
HStack {
|
||||
Spacer()
|
||||
|
||||
Button("Train") {
|
||||
Task {
|
||||
try! await trainer.run()
|
||||
switch trainer.state {
|
||||
case .untrained:
|
||||
Button("Train") {
|
||||
Task {
|
||||
try! await trainer.run()
|
||||
}
|
||||
}
|
||||
case .trained(let model), .predict(let model):
|
||||
Button("Draw a digit") {
|
||||
trainer.state = .predict(model)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,9 +44,30 @@ struct ContentView: View {
|
||||
}
|
||||
}
|
||||
|
||||
struct ContentView: View {
|
||||
// the training loop
|
||||
@State var trainer = Trainer()
|
||||
|
||||
var body: some View {
|
||||
switch trainer.state {
|
||||
case .untrained, .trained:
|
||||
TrainingView(trainer: $trainer)
|
||||
case .predict(let model):
|
||||
PredictionView(model: model)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Observable
|
||||
class Trainer {
|
||||
|
||||
enum State {
|
||||
case untrained
|
||||
case trained(LeNet)
|
||||
case predict(LeNet)
|
||||
}
|
||||
|
||||
var state: State = .untrained
|
||||
var messages = [String]()
|
||||
|
||||
func run() async throws {
|
||||
@@ -101,6 +127,8 @@ class Trainer {
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
await MainActor.run {
|
||||
state = .trained(model)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user