add MNIST training example
This commit is contained in:
@@ -0,0 +1,11 @@
|
||||
{
|
||||
"colors" : [
|
||||
{
|
||||
"idiom" : "universal"
|
||||
}
|
||||
],
|
||||
"info" : {
|
||||
"author" : "xcode",
|
||||
"version" : 1
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
{
|
||||
"images" : [
|
||||
{
|
||||
"idiom" : "universal",
|
||||
"platform" : "ios",
|
||||
"size" : "1024x1024"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "1x",
|
||||
"size" : "16x16"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "2x",
|
||||
"size" : "16x16"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "1x",
|
||||
"size" : "32x32"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "2x",
|
||||
"size" : "32x32"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "1x",
|
||||
"size" : "128x128"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "2x",
|
||||
"size" : "128x128"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "1x",
|
||||
"size" : "256x256"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "2x",
|
||||
"size" : "256x256"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "1x",
|
||||
"size" : "512x512"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "2x",
|
||||
"size" : "512x512"
|
||||
}
|
||||
],
|
||||
"info" : {
|
||||
"author" : "xcode",
|
||||
"version" : 1
|
||||
}
|
||||
}
|
||||
6
Applications/MNISTTrainer/Assets.xcassets/Contents.json
Normal file
6
Applications/MNISTTrainer/Assets.xcassets/Contents.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"info" : {
|
||||
"author" : "xcode",
|
||||
"version" : 1
|
||||
}
|
||||
}
|
||||
116
Applications/MNISTTrainer/ContentView.swift
Normal file
116
Applications/MNISTTrainer/ContentView.swift
Normal file
@@ -0,0 +1,116 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
import MLX
|
||||
import MLXNN
|
||||
import MLXOptimizers
|
||||
import MLXRandom
|
||||
import MNIST
|
||||
import SwiftUI
|
||||
|
||||
struct ContentView: View {
|
||||
|
||||
// the training loop
|
||||
@State var trainer = Trainer()
|
||||
|
||||
// toggle for cpu/gpu training
|
||||
@State var cpu = true
|
||||
|
||||
var body: some View {
|
||||
VStack {
|
||||
Spacer()
|
||||
|
||||
ScrollView(.vertical) {
|
||||
ForEach(trainer.messages, id: \.self) {
|
||||
Text($0)
|
||||
}
|
||||
}
|
||||
|
||||
HStack {
|
||||
Spacer()
|
||||
|
||||
Button("Train") {
|
||||
Task {
|
||||
try! await trainer.run(device: cpu ? .cpu : .gpu)
|
||||
}
|
||||
}
|
||||
|
||||
Toggle("CPU", isOn: $cpu)
|
||||
.frame(maxWidth: 150)
|
||||
|
||||
Spacer()
|
||||
}
|
||||
Spacer()
|
||||
}
|
||||
.padding()
|
||||
}
|
||||
}
|
||||
|
||||
@Observable
|
||||
class Trainer {
|
||||
|
||||
var messages = [String]()
|
||||
|
||||
func run(device: Device = .cpu) async throws {
|
||||
// Note: this is pretty close to the code in `mnist-tool`, just
|
||||
// wrapped in an Observable to make it easy to display in SwiftUI
|
||||
|
||||
Device.setDefault(device: device)
|
||||
|
||||
// download & load the training data
|
||||
let url = URL(fileURLWithPath: NSTemporaryDirectory(), isDirectory: true)
|
||||
try await download(into: url)
|
||||
let data = try load(from: url)
|
||||
|
||||
let trainImages = data[.init(.training, .images)]!
|
||||
let trainLabels = data[.init(.training, .labels)]!
|
||||
let testImages = data[.init(.test, .images)]!
|
||||
let testLabels = data[.init(.test, .labels)]!
|
||||
|
||||
// create the model with random weights
|
||||
let model = MLP(
|
||||
layers: 2, inputDimensions: trainImages.dim(-1), hiddenDimensions: 32,
|
||||
outputDimensions: 10)
|
||||
eval(model.parameters())
|
||||
|
||||
// the training loop
|
||||
let lg = valueAndGrad(model: model, loss)
|
||||
let optimizer = SGD(learningRate: 0.1)
|
||||
|
||||
// using a consistent random seed so it behaves the same way each time
|
||||
MLXRandom.seed(0)
|
||||
var generator: RandomNumberGenerator = SplitMix64(seed: 0)
|
||||
|
||||
for e in 0 ..< 10 {
|
||||
let start = Date.timeIntervalSinceReferenceDate
|
||||
|
||||
for (x, y) in iterateBatches(
|
||||
batchSize: 256, x: trainImages, y: trainLabels, using: &generator)
|
||||
{
|
||||
// loss and gradients
|
||||
let (_, grads) = lg(model, x, y)
|
||||
|
||||
// use SGD to update the weights
|
||||
optimizer.update(model: model, gradients: grads)
|
||||
|
||||
// eval the parameters so the next iteration is independent
|
||||
eval(model, optimizer)
|
||||
}
|
||||
|
||||
let accuracy = eval(model: model, x: testImages, y: testLabels)
|
||||
|
||||
let end = Date.timeIntervalSinceReferenceDate
|
||||
|
||||
// add to messages -- triggers display
|
||||
await MainActor.run {
|
||||
messages.append(
|
||||
"""
|
||||
Epoch \(e): test accuracy \(accuracy.item(Float.self).formatted())
|
||||
Time: \((end - start).formatted())
|
||||
|
||||
"""
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
17
Applications/MNISTTrainer/MNISTTrainer-Info.plist
Normal file
17
Applications/MNISTTrainer/MNISTTrainer-Info.plist
Normal file
@@ -0,0 +1,17 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>NSAppTransportSecurity</key>
|
||||
<dict>
|
||||
<key>NSExceptionDomains</key>
|
||||
<dict>
|
||||
<key>yann.lecun.com</key>
|
||||
<dict>
|
||||
<key>NSExceptionAllowsInsecureHTTPLoads</key>
|
||||
<true/>
|
||||
</dict>
|
||||
</dict>
|
||||
</dict>
|
||||
</dict>
|
||||
</plist>
|
||||
12
Applications/MNISTTrainer/MNISTTrainer.entitlements
Normal file
12
Applications/MNISTTrainer/MNISTTrainer.entitlements
Normal file
@@ -0,0 +1,12 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>com.apple.security.app-sandbox</key>
|
||||
<true/>
|
||||
<key>com.apple.security.files.user-selected.read-only</key>
|
||||
<true/>
|
||||
<key>com.apple.security.network.client</key>
|
||||
<true/>
|
||||
</dict>
|
||||
</plist>
|
||||
12
Applications/MNISTTrainer/MNISTTrainerApp.swift
Normal file
12
Applications/MNISTTrainer/MNISTTrainerApp.swift
Normal file
@@ -0,0 +1,12 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
import SwiftUI
|
||||
|
||||
@main
|
||||
struct MNISTTrainerApp: App {
|
||||
var body: some Scene {
|
||||
WindowGroup {
|
||||
ContentView()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"info" : {
|
||||
"author" : "xcode",
|
||||
"version" : 1
|
||||
}
|
||||
}
|
||||
13
Applications/MNISTTrainer/README.md
Normal file
13
Applications/MNISTTrainer/README.md
Normal file
@@ -0,0 +1,13 @@
|
||||
# MNISTTrainer
|
||||
|
||||
This is an example showing how to do model training on both macOS and iOS.
|
||||
This will download the MNIST training data, create a new models and train
|
||||
it. It will show the timing per epoch and the test accuracy as it trains.
|
||||
|
||||
You will need to set the Team on the MNISTTrainer target in order to build and
|
||||
run on iOS.
|
||||
|
||||
Some notes about the setup:
|
||||
|
||||
- this will download test data over the network so MNISTTrainer -> Signing & Capabilities has the "Outgoing Connections (Client)" set in the App Sandbox
|
||||
- the website it connects to uses http rather than https so it has a "App Transport Security Settings" in the Info.plist
|
||||
Reference in New Issue
Block a user