LeNet on MNIST + readme update (#12)
* LeNet on MNIST + readme update * tanh + remove device toggle * remove device entirely
This commit is contained in:
@@ -12,9 +12,6 @@ struct ContentView: View {
|
||||
// the training loop
|
||||
@State var trainer = Trainer()
|
||||
|
||||
// toggle for cpu/gpu training
|
||||
@State var cpu = true
|
||||
|
||||
var body: some View {
|
||||
VStack {
|
||||
Spacer()
|
||||
@@ -30,13 +27,10 @@ struct ContentView: View {
|
||||
|
||||
Button("Train") {
|
||||
Task {
|
||||
try! await trainer.run(device: cpu ? .cpu : .gpu)
|
||||
try! await trainer.run()
|
||||
}
|
||||
}
|
||||
|
||||
Toggle("CPU", isOn: $cpu)
|
||||
.frame(maxWidth: 150)
|
||||
|
||||
Spacer()
|
||||
}
|
||||
Spacer()
|
||||
@@ -50,12 +44,10 @@ class Trainer {
|
||||
|
||||
var messages = [String]()
|
||||
|
||||
func run(device: Device = .cpu) async throws {
|
||||
func run() async throws {
|
||||
// Note: this is pretty close to the code in `mnist-tool`, just
|
||||
// wrapped in an Observable to make it easy to display in SwiftUI
|
||||
|
||||
Device.setDefault(device: device)
|
||||
|
||||
// download & load the training data
|
||||
let url = URL(fileURLWithPath: NSTemporaryDirectory(), isDirectory: true)
|
||||
try await download(into: url)
|
||||
@@ -67,9 +59,7 @@ class Trainer {
|
||||
let testLabels = data[.init(.test, .labels)]!
|
||||
|
||||
// create the model with random weights
|
||||
let model = MLP(
|
||||
layers: 2, inputDimensions: trainImages.dim(-1), hiddenDimensions: 32,
|
||||
outputDimensions: 10)
|
||||
let model = LeNet()
|
||||
eval(model.parameters())
|
||||
|
||||
// the training loop
|
||||
|
||||
Reference in New Issue
Block a user