LeNet on MNIST + readme update (#12)
* LeNet on MNIST + readme update * tanh + remove device toggle * remove device entirely
This commit is contained in:
@@ -12,9 +12,6 @@ struct ContentView: View {
|
||||
// the training loop
|
||||
@State var trainer = Trainer()
|
||||
|
||||
// toggle for cpu/gpu training
|
||||
@State var cpu = true
|
||||
|
||||
var body: some View {
|
||||
VStack {
|
||||
Spacer()
|
||||
@@ -30,13 +27,10 @@ struct ContentView: View {
|
||||
|
||||
Button("Train") {
|
||||
Task {
|
||||
try! await trainer.run(device: cpu ? .cpu : .gpu)
|
||||
try! await trainer.run()
|
||||
}
|
||||
}
|
||||
|
||||
Toggle("CPU", isOn: $cpu)
|
||||
.frame(maxWidth: 150)
|
||||
|
||||
Spacer()
|
||||
}
|
||||
Spacer()
|
||||
@@ -50,12 +44,10 @@ class Trainer {
|
||||
|
||||
var messages = [String]()
|
||||
|
||||
func run(device: Device = .cpu) async throws {
|
||||
func run() async throws {
|
||||
// Note: this is pretty close to the code in `mnist-tool`, just
|
||||
// wrapped in an Observable to make it easy to display in SwiftUI
|
||||
|
||||
Device.setDefault(device: device)
|
||||
|
||||
// download & load the training data
|
||||
let url = URL(fileURLWithPath: NSTemporaryDirectory(), isDirectory: true)
|
||||
try await download(into: url)
|
||||
@@ -67,9 +59,7 @@ class Trainer {
|
||||
let testLabels = data[.init(.test, .labels)]!
|
||||
|
||||
// create the model with random weights
|
||||
let model = MLP(
|
||||
layers: 2, inputDimensions: trainImages.dim(-1), hiddenDimensions: 32,
|
||||
outputDimensions: 10)
|
||||
let model = LeNet()
|
||||
eval(model.parameters())
|
||||
|
||||
// the training loop
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
# MNISTTrainer
|
||||
|
||||
This is an example showing how to do model training on both macOS and iOS.
|
||||
This will download the MNIST training data, create a new models and train
|
||||
it. It will show the timing per epoch and the test accuracy as it trains.
|
||||
This is an example of model training that works on both macOS and iOS.
|
||||
The example will download the MNIST training data, create a LeNet, and train
|
||||
it. It will show the epoch time and test accuracy as it trains.
|
||||
|
||||
You will need to set the Team on the MNISTTrainer target in order to build and
|
||||
run on iOS.
|
||||
|
||||
Some notes about the setup:
|
||||
|
||||
- this will download test data over the network so MNISTTrainer -> Signing & Capabilities has the "Outgoing Connections (Client)" set in the App Sandbox
|
||||
- the website it connects to uses http rather than https so it has a "App Transport Security Settings" in the Info.plist
|
||||
- This will download test data over the network so MNISTTrainer -> Signing & Capabilities has the "Outgoing Connections (Client)" set in the App Sandbox
|
||||
- The website it connects to uses http rather than https so it has a "App Transport Security Settings" in the Info.plist
|
||||
|
||||
Reference in New Issue
Block a user