Add MNIST Digit Prediction/Inference (#22)
* Add Prediction to MNISTTrainer
This commit is contained in:
127
Applications/MNISTTrainer/PredictionView.swift
Normal file
127
Applications/MNISTTrainer/PredictionView.swift
Normal file
@@ -0,0 +1,127 @@
|
||||
//
|
||||
// PredictionView.swift
|
||||
// MNISTTrainer
|
||||
//
|
||||
// Created by Rounak Jain on 3/9/24.
|
||||
//
|
||||
|
||||
import MLX
|
||||
import MLXNN
|
||||
import MNIST
|
||||
import SwiftUI
|
||||
|
||||
struct Canvas: View {
|
||||
|
||||
@Binding var path: Path
|
||||
@State var lastPoint: CGPoint?
|
||||
|
||||
var body: some View {
|
||||
path
|
||||
.stroke(.white, lineWidth: 10)
|
||||
.background(.black)
|
||||
.gesture(
|
||||
DragGesture(minimumDistance: 0.05)
|
||||
.onChanged { touch in
|
||||
add(point: touch.location)
|
||||
}
|
||||
.onEnded { touch in
|
||||
lastPoint = nil
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
func add(point: CGPoint) {
|
||||
var newPath = path
|
||||
if let lastPoint {
|
||||
newPath.move(to: lastPoint)
|
||||
newPath.addLine(to: point)
|
||||
} else {
|
||||
newPath.move(to: point)
|
||||
}
|
||||
self.path = newPath
|
||||
lastPoint = point
|
||||
}
|
||||
}
|
||||
|
||||
extension Path {
|
||||
mutating func center(to newMidPoint: CGPoint) {
|
||||
let middleX = boundingRect.midX
|
||||
let middleY = boundingRect.midY
|
||||
self = offsetBy(dx: newMidPoint.x - middleX, dy: newMidPoint.y - middleY)
|
||||
}
|
||||
}
|
||||
|
||||
struct PredictionView: View {
|
||||
@State var path: Path = Path()
|
||||
@State var prediction: Int?
|
||||
let model: LeNet
|
||||
let canvasSize = 150.0
|
||||
let mnistImageSize: CGSize = CGSize(width: 28, height: 28)
|
||||
|
||||
var body: some View {
|
||||
VStack {
|
||||
if let prediction {
|
||||
Text("You've drawn a \(prediction)")
|
||||
} else {
|
||||
Text("Draw a digit")
|
||||
}
|
||||
Canvas(path: $path)
|
||||
.frame(width: canvasSize, height: canvasSize)
|
||||
HStack {
|
||||
Button("Predict") {
|
||||
path.center(to: CGPoint(x: canvasSize / 2, y: canvasSize / 2))
|
||||
predict()
|
||||
}
|
||||
Button("Clear") {
|
||||
path = Path()
|
||||
prediction = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@MainActor
|
||||
func predict() {
|
||||
let imageRenderer = ImageRenderer(
|
||||
content: Canvas(path: $path).frame(width: 150, height: 150))
|
||||
guard
|
||||
let pixelData = imageRenderer.cgImage?.grayscaleImage(with: mnistImageSize)?.pixelData()
|
||||
else {
|
||||
return
|
||||
}
|
||||
// modify input vector to match training in MNIST/Files.swift
|
||||
let x = pixelData.reshaped([1, 28, 28, 1]).asType(.float32) / 255.0
|
||||
prediction = argMax(model(x)).item()
|
||||
}
|
||||
}
|
||||
|
||||
extension CGImage {
|
||||
func grayscaleImage(with newSize: CGSize) -> CGImage? {
|
||||
let colorSpace = CGColorSpaceCreateDeviceGray()
|
||||
let bitmapInfo = CGBitmapInfo(rawValue: CGImageAlphaInfo.none.rawValue)
|
||||
|
||||
guard
|
||||
let context = CGContext(
|
||||
data: nil,
|
||||
width: Int(newSize.width),
|
||||
height: Int(newSize.height),
|
||||
bitsPerComponent: 8,
|
||||
bytesPerRow: Int(newSize.width),
|
||||
space: colorSpace,
|
||||
bitmapInfo: bitmapInfo.rawValue)
|
||||
else {
|
||||
return nil
|
||||
}
|
||||
context.draw(self, in: CGRect(x: 0, y: 0, width: newSize.width, height: newSize.width))
|
||||
return context.makeImage()
|
||||
}
|
||||
|
||||
func pixelData() -> MLXArray {
|
||||
guard let data = self.dataProvider?.data else {
|
||||
return []
|
||||
}
|
||||
let bytePtr = CFDataGetBytePtr(data)
|
||||
let count = CFDataGetLength(data)
|
||||
return MLXArray(UnsafeBufferPointer(start: bytePtr, count: count))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user