add MNIST training example

This commit is contained in:
David Koski
2024-03-01 15:55:36 -08:00
parent 2157333905
commit 79e0620891
11 changed files with 576 additions and 1 deletions

View File

@@ -0,0 +1,11 @@
{
"colors" : [
{
"idiom" : "universal"
}
],
"info" : {
"author" : "xcode",
"version" : 1
}
}

View File

@@ -0,0 +1,63 @@
{
"images" : [
{
"idiom" : "universal",
"platform" : "ios",
"size" : "1024x1024"
},
{
"idiom" : "mac",
"scale" : "1x",
"size" : "16x16"
},
{
"idiom" : "mac",
"scale" : "2x",
"size" : "16x16"
},
{
"idiom" : "mac",
"scale" : "1x",
"size" : "32x32"
},
{
"idiom" : "mac",
"scale" : "2x",
"size" : "32x32"
},
{
"idiom" : "mac",
"scale" : "1x",
"size" : "128x128"
},
{
"idiom" : "mac",
"scale" : "2x",
"size" : "128x128"
},
{
"idiom" : "mac",
"scale" : "1x",
"size" : "256x256"
},
{
"idiom" : "mac",
"scale" : "2x",
"size" : "256x256"
},
{
"idiom" : "mac",
"scale" : "1x",
"size" : "512x512"
},
{
"idiom" : "mac",
"scale" : "2x",
"size" : "512x512"
}
],
"info" : {
"author" : "xcode",
"version" : 1
}
}

View File

@@ -0,0 +1,6 @@
{
"info" : {
"author" : "xcode",
"version" : 1
}
}

View File

@@ -0,0 +1,116 @@
// Copyright © 2024 Apple Inc.
import MLX
import MLXNN
import MLXOptimizers
import MLXRandom
import MNIST
import SwiftUI
struct ContentView: View {
// the training loop
@State var trainer = Trainer()
// toggle for cpu/gpu training
@State var cpu = true
var body: some View {
VStack {
Spacer()
ScrollView(.vertical) {
ForEach(trainer.messages, id: \.self) {
Text($0)
}
}
HStack {
Spacer()
Button("Train") {
Task {
try! await trainer.run(device: cpu ? .cpu : .gpu)
}
}
Toggle("CPU", isOn: $cpu)
.frame(maxWidth: 150)
Spacer()
}
Spacer()
}
.padding()
}
}
@Observable
class Trainer {
var messages = [String]()
func run(device: Device = .cpu) async throws {
// Note: this is pretty close to the code in `mnist-tool`, just
// wrapped in an Observable to make it easy to display in SwiftUI
Device.setDefault(device: device)
// download & load the training data
let url = URL(fileURLWithPath: NSTemporaryDirectory(), isDirectory: true)
try await download(into: url)
let data = try load(from: url)
let trainImages = data[.init(.training, .images)]!
let trainLabels = data[.init(.training, .labels)]!
let testImages = data[.init(.test, .images)]!
let testLabels = data[.init(.test, .labels)]!
// create the model with random weights
let model = MLP(
layers: 2, inputDimensions: trainImages.dim(-1), hiddenDimensions: 32,
outputDimensions: 10)
eval(model.parameters())
// the training loop
let lg = valueAndGrad(model: model, loss)
let optimizer = SGD(learningRate: 0.1)
// using a consistent random seed so it behaves the same way each time
MLXRandom.seed(0)
var generator: RandomNumberGenerator = SplitMix64(seed: 0)
for e in 0 ..< 10 {
let start = Date.timeIntervalSinceReferenceDate
for (x, y) in iterateBatches(
batchSize: 256, x: trainImages, y: trainLabels, using: &generator)
{
// loss and gradients
let (_, grads) = lg(model, x, y)
// use SGD to update the weights
optimizer.update(model: model, gradients: grads)
// eval the parameters so the next iteration is independent
eval(model, optimizer)
}
let accuracy = eval(model: model, x: testImages, y: testLabels)
let end = Date.timeIntervalSinceReferenceDate
// add to messages -- triggers display
await MainActor.run {
messages.append(
"""
Epoch \(e): test accuracy \(accuracy.item(Float.self).formatted())
Time: \((end - start).formatted())
"""
)
}
}
}
}

View File

@@ -0,0 +1,17 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>NSAppTransportSecurity</key>
<dict>
<key>NSExceptionDomains</key>
<dict>
<key>yann.lecun.com</key>
<dict>
<key>NSExceptionAllowsInsecureHTTPLoads</key>
<true/>
</dict>
</dict>
</dict>
</dict>
</plist>

View File

@@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>com.apple.security.app-sandbox</key>
<true/>
<key>com.apple.security.files.user-selected.read-only</key>
<true/>
<key>com.apple.security.network.client</key>
<true/>
</dict>
</plist>

View File

@@ -0,0 +1,12 @@
// Copyright © 2024 Apple Inc.
import SwiftUI
@main
struct MNISTTrainerApp: App {
var body: some Scene {
WindowGroup {
ContentView()
}
}
}

View File

@@ -0,0 +1,6 @@
{
"info" : {
"author" : "xcode",
"version" : 1
}
}

View File

@@ -0,0 +1,13 @@
# MNISTTrainer
This is an example showing how to do model training on both macOS and iOS.
This will download the MNIST training data, create a new models and train
it. It will show the timing per epoch and the test accuracy as it trains.
You will need to set the Team on the MNISTTrainer target in order to build and
run on iOS.
Some notes about the setup:
- this will download test data over the network so MNISTTrainer -> Signing & Capabilities has the "Outgoing Connections (Client)" set in the App Sandbox
- the website it connects to uses http rather than https so it has a "App Transport Security Settings" in the Info.plist