LeNet on MNIST + readme update (#12)

* LeNet on MNIST + readme update

* tanh + remove device toggle

* remove device entirely
This commit is contained in:
Awni Hannun
2024-03-04 14:16:20 -08:00
committed by GitHub
parent dfc9f2fc01
commit 4ed4ec69e7
8 changed files with 56 additions and 86 deletions

View File

@@ -12,9 +12,6 @@ struct ContentView: View {
// the training loop
@State var trainer = Trainer()
- // toggle for cpu/gpu training
- @State var cpu = true
var body: some View {
VStack {
Spacer()
@@ -30,13 +27,10 @@ struct ContentView: View {
Button("Train") {
Task {
- try! await trainer.run(device: cpu ? .cpu : .gpu)
+ try! await trainer.run()
}
}
- Toggle("CPU", isOn: $cpu)
- .frame(maxWidth: 150)
Spacer()
}
Spacer()
@@ -50,12 +44,10 @@ class Trainer {
var messages = [String]()
- func run(device: Device = .cpu) async throws {
+ func run() async throws {
// Note: this is pretty close to the code in `mnist-tool`, just
// wrapped in an Observable to make it easy to display in SwiftUI
- Device.setDefault(device: device)
// download & load the training data
let url = URL(fileURLWithPath: NSTemporaryDirectory(), isDirectory: true)
try await download(into: url)
@@ -67,9 +59,7 @@ class Trainer {
let testLabels = data[.init(.test, .labels)]!
// create the model with random weights
- let model = MLP(
- layers: 2, inputDimensions: trainImages.dim(-1), hiddenDimensions: 32,
- outputDimensions: 10)
+ let model = LeNet()
eval(model.parameters())
// the training loop