Add MNIST Digit Prediction/Inference (#22)

* Add Prediction to MNISTTrainer
This commit is contained in:
Rounak
2024-03-18 19:18:41 -07:00
committed by GitHub
parent 0588abec77
commit 9e18eaa479
4 changed files with 168 additions and 9 deletions

View File

@@ -7,10 +7,9 @@ import MLXRandom
import MNIST
import SwiftUI
struct ContentView: View {
struct TrainingView: View {
// the training loop
@State var trainer = Trainer()
@Binding var trainer: Trainer
var body: some View {
VStack {
@@ -24,10 +23,16 @@ struct ContentView: View {
HStack {
Spacer()
Button("Train") {
Task {
try! await trainer.run()
switch trainer.state {
case .untrained:
Button("Train") {
Task {
try! await trainer.run()
}
}
case .trained(let model), .predict(let model):
Button("Draw a digit") {
trainer.state = .predict(model)
}
}
@@ -39,9 +44,30 @@ struct ContentView: View {
}
}
struct ContentView: View {
// the training loop
@State var trainer = Trainer()
var body: some View {
switch trainer.state {
case .untrained, .trained:
TrainingView(trainer: $trainer)
case .predict(let model):
PredictionView(model: model)
}
}
}
@Observable
class Trainer {
enum State {
case untrained
case trained(LeNet)
case predict(LeNet)
}
var state: State = .untrained
var messages = [String]()
func run() async throws {
@@ -101,6 +127,8 @@ class Trainer {
)
}
}
await MainActor.run {
state = .trained(model)
}
}
}

View File

@@ -0,0 +1,127 @@
//
// PredictionView.swift
// MNISTTrainer
//
// Created by Rounak Jain on 3/9/24.
//
import MLX
import MLXNN
import MNIST
import SwiftUI
struct Canvas: View {
@Binding var path: Path
@State var lastPoint: CGPoint?
var body: some View {
path
.stroke(.white, lineWidth: 10)
.background(.black)
.gesture(
DragGesture(minimumDistance: 0.05)
.onChanged { touch in
add(point: touch.location)
}
.onEnded { touch in
lastPoint = nil
}
)
}
func add(point: CGPoint) {
var newPath = path
if let lastPoint {
newPath.move(to: lastPoint)
newPath.addLine(to: point)
} else {
newPath.move(to: point)
}
self.path = newPath
lastPoint = point
}
}
extension Path {
mutating func center(to newMidPoint: CGPoint) {
let middleX = boundingRect.midX
let middleY = boundingRect.midY
self = offsetBy(dx: newMidPoint.x - middleX, dy: newMidPoint.y - middleY)
}
}
struct PredictionView: View {
@State var path: Path = Path()
@State var prediction: Int?
let model: LeNet
let canvasSize = 150.0
let mnistImageSize: CGSize = CGSize(width: 28, height: 28)
var body: some View {
VStack {
if let prediction {
Text("You've drawn a \(prediction)")
} else {
Text("Draw a digit")
}
Canvas(path: $path)
.frame(width: canvasSize, height: canvasSize)
HStack {
Button("Predict") {
path.center(to: CGPoint(x: canvasSize / 2, y: canvasSize / 2))
predict()
}
Button("Clear") {
path = Path()
prediction = nil
}
}
}
}
@MainActor
func predict() {
let imageRenderer = ImageRenderer(
content: Canvas(path: $path).frame(width: 150, height: 150))
guard
let pixelData = imageRenderer.cgImage?.grayscaleImage(with: mnistImageSize)?.pixelData()
else {
return
}
// modify input vector to match training in MNIST/Files.swift
let x = pixelData.reshaped([1, 28, 28, 1]).asType(.float32) / 255.0
prediction = argMax(model(x)).item()
}
}
extension CGImage {
func grayscaleImage(with newSize: CGSize) -> CGImage? {
let colorSpace = CGColorSpaceCreateDeviceGray()
let bitmapInfo = CGBitmapInfo(rawValue: CGImageAlphaInfo.none.rawValue)
guard
let context = CGContext(
data: nil,
width: Int(newSize.width),
height: Int(newSize.height),
bitsPerComponent: 8,
bytesPerRow: Int(newSize.width),
space: colorSpace,
bitmapInfo: bitmapInfo.rawValue)
else {
return nil
}
context.draw(self, in: CGRect(x: 0, y: 0, width: newSize.width, height: newSize.width))
return context.makeImage()
}
func pixelData() -> MLXArray {
guard let data = self.dataProvider?.data else {
return []
}
let bytePtr = CFDataGetBytePtr(data)
let count = CFDataGetLength(data)
return MLXArray(UnsafeBufferPointer(start: bytePtr, count: count))
}
}