LeNet on MNIST + readme update (#12)

* LeNet on MNIST + readme update

* tanh + remove device toggle

* remove device entirely
This commit is contained in:
Awni Hannun
2024-03-04 14:16:20 -08:00
committed by GitHub
parent dfc9f2fc01
commit 4ed4ec69e7
8 changed files with 56 additions and 86 deletions

View File

@@ -12,9 +12,6 @@ struct ContentView: View {
// the training loop
@State var trainer = Trainer()
// toggle for cpu/gpu training
@State var cpu = true
var body: some View {
VStack {
Spacer()
@@ -30,13 +27,10 @@ struct ContentView: View {
Button("Train") {
Task {
try! await trainer.run(device: cpu ? .cpu : .gpu)
try! await trainer.run()
}
}
Toggle("CPU", isOn: $cpu)
.frame(maxWidth: 150)
Spacer()
}
Spacer()
@@ -50,12 +44,10 @@ class Trainer {
var messages = [String]()
func run(device: Device = .cpu) async throws {
func run() async throws {
// Note: this is pretty close to the code in `mnist-tool`, just
// wrapped in an Observable to make it easy to display in SwiftUI
Device.setDefault(device: device)
// download & load the training data
let url = URL(fileURLWithPath: NSTemporaryDirectory(), isDirectory: true)
try await download(into: url)
@@ -67,9 +59,7 @@ class Trainer {
let testLabels = data[.init(.test, .labels)]!
// create the model with random weights
let model = MLP(
layers: 2, inputDimensions: trainImages.dim(-1), hiddenDimensions: 32,
outputDimensions: 10)
let model = LeNet()
eval(model.parameters())
// the training loop

View File

@@ -1,13 +1,13 @@
# MNISTTrainer
This is an example showing how to do model training on both macOS and iOS.
This will download the MNIST training data, create a new models and train
it. It will show the timing per epoch and the test accuracy as it trains.
This is an example of model training that works on both macOS and iOS.
The example will download the MNIST training data, create a LeNet, and train
it. It will show the epoch time and test accuracy as it trains.
You will need to set the Team on the MNISTTrainer target in order to build and
run on iOS.
Some notes about the setup:
- this will download test data over the network so MNISTTrainer -> Signing & Capabilities has the "Outgoing Connections (Client)" set in the App Sandbox
- the website it connects to uses http rather than https so it has a "App Transport Security Settings" in the Info.plist
- This will download test data over the network so MNISTTrainer -> Signing & Capabilities has the "Outgoing Connections (Client)" set in the App Sandbox
- The website it connects to uses http rather than https so it has a "App Transport Security Settings" in the Info.plist