128 lines
3.5 KiB
Swift
128 lines
3.5 KiB
Swift
//
|
|
// PredictionView.swift
|
|
// MNISTTrainer
|
|
//
|
|
// Created by Rounak Jain on 3/9/24.
|
|
//
|
|
|
|
import MLX
|
|
import MLXNN
|
|
import MNIST
|
|
import SwiftUI
|
|
|
|
/// A black drawing surface that records freehand strokes into a bound `Path`.
///
/// Every drag sample extends the current stroke with a straight segment from
/// the previous sample. Ending the drag clears the remembered sample, so the
/// next drag starts a new sub-path instead of connecting to the old one.
struct Canvas: View {

    /// The accumulated strokes; the parent view owns this value.
    @Binding var path: Path

    /// Location of the previous drag sample in the stroke being drawn, or
    /// `nil` when no stroke is in progress.
    @State var lastPoint: CGPoint?

    var body: some View {
        path
            .stroke(.white, lineWidth: 10)
            .background(.black)
            .gesture(drawingGesture)
    }

    /// Drag gesture that forwards each sample to `add(point:)` and resets the
    /// stroke when the finger lifts.
    private var drawingGesture: some Gesture {
        DragGesture(minimumDistance: 0.05)
            .onChanged { value in add(point: value.location) }
            .onEnded { _ in lastPoint = nil }
    }

    /// Appends `point` to the drawing.
    ///
    /// If a stroke is in progress, a line from the previous sample to `point`
    /// is added. Otherwise a new sub-path begins at `point`.
    ///
    /// - Parameter point: Touch location in the canvas's coordinate space.
    func add(point: CGPoint) {
        var updated = path
        if let previous = lastPoint {
            updated.move(to: previous)
            updated.addLine(to: point)
        } else {
            updated.move(to: point)
        }
        path = updated
        lastPoint = point
    }
}
|
|
|
|
extension Path {
    /// Translates the whole path so that the center of its bounding rectangle
    /// lands on `newMidPoint`.
    ///
    /// - Parameter newMidPoint: Where the bounding box's center should end up.
    mutating func center(to newMidPoint: CGPoint) {
        let bounds = boundingRect
        let shift = CGSize(
            width: newMidPoint.x - bounds.midX,
            height: newMidPoint.y - bounds.midY)
        self = offsetBy(dx: shift.width, dy: shift.height)
    }
}
|
|
|
|
/// Lets the user draw a digit and classifies it with a `LeNet` model.
///
/// On "Predict", the drawing is centered on the canvas, rendered to an image,
/// and converted to a grayscale bitmap of `mnistImageSize`. The model's
/// arg-max class is then shown above the canvas.
struct PredictionView: View {

    /// Strokes drawn on the canvas.
    @State var path: Path = Path()

    /// Most recently predicted digit; `nil` until a prediction is made and
    /// after "Clear".
    @State var prediction: Int?

    /// Classifier used for inference.
    let model: LeNet

    /// Side length, in points, of the square drawing canvas. `predict()` also
    /// renders at this size so the image matches what is on screen.
    let canvasSize = 150.0

    /// Pixel size the rendered drawing is downsampled to before inference.
    let mnistImageSize: CGSize = CGSize(width: 28, height: 28)

    var body: some View {
        VStack {
            if let prediction {
                Text("You've drawn a \(prediction)")
            } else {
                Text("Draw a digit")
            }
            Canvas(path: $path)
                .frame(width: canvasSize, height: canvasSize)
            HStack {
                Button("Predict") {
                    path.center(to: CGPoint(x: canvasSize / 2, y: canvasSize / 2))
                    predict()
                }
                Button("Clear") {
                    path = Path()
                    prediction = nil
                }
            }
        }
    }

    /// Renders the current drawing, converts it to a normalized model input,
    /// and stores the arg-max class in `prediction`.
    ///
    /// Leaves `prediction` unchanged when nothing has been drawn or when the
    /// render/grayscale conversion fails.
    @MainActor
    func predict() {
        // A blank canvas would still produce a (meaningless) class; skip it.
        guard !path.isEmpty else { return }

        // Render at the on-screen canvas size (not a hard-coded literal) so the
        // stroke width and centering match what the user sees.
        let imageRenderer = ImageRenderer(
            content: Canvas(path: $path).frame(width: canvasSize, height: canvasSize))
        guard
            let pixelData = imageRenderer.cgImage?.grayscaleImage(with: mnistImageSize)?.pixelData()
        else {
            return
        }
        // modify input vector to match training in MNIST/Files.swift
        // The grayscale bitmap is row-major, so rows come before columns.
        let rows = Int(mnistImageSize.height)
        let columns = Int(mnistImageSize.width)
        let x = pixelData.reshaped([1, rows, columns, 1]).asType(.float32) / 255.0
        prediction = argMax(model(x)).item()
    }
}
|
|
|
|
extension CGImage {
    /// Redraws the image into a new 8-bit, single-channel grayscale bitmap with
    /// no alpha, scaled to fill all of `newSize`.
    ///
    /// - Parameter newSize: Target size in pixels; fractional parts are
    ///   truncated when sizing the bitmap.
    /// - Returns: The grayscale image, or `nil` if the bitmap context or the
    ///   resulting image cannot be created (for example, a zero-sized target).
    func grayscaleImage(with newSize: CGSize) -> CGImage? {
        let pixelWidth = Int(newSize.width)
        let pixelHeight = Int(newSize.height)
        let colorSpace = CGColorSpaceCreateDeviceGray()
        let bitmapInfo = CGBitmapInfo(rawValue: CGImageAlphaInfo.none.rawValue)

        guard
            let context = CGContext(
                data: nil,
                width: pixelWidth,
                height: pixelHeight,
                bitsPerComponent: 8,
                // One byte per pixel with no row padding, so the backing
                // buffer is tightly packed.
                bytesPerRow: pixelWidth,
                space: colorSpace,
                bitmapInfo: bitmapInfo.rawValue)
        else {
            return nil
        }
        // Use the target height for the vertical extent so non-square targets
        // are filled instead of distorted or cropped.
        context.draw(self, in: CGRect(x: 0, y: 0, width: newSize.width, height: newSize.height))
        return context.makeImage()
    }

    /// Returns the image's raw backing bytes as a flat `MLXArray`, row by row,
    /// with any per-row padding removed.
    ///
    /// No channel or color conversion is done. For images from
    /// `grayscaleImage(with:)` this is one byte per pixel; multi-channel images
    /// yield every channel byte.
    ///
    /// - Returns: The packed bytes, or an empty array if the image's data
    ///   cannot be accessed.
    func pixelData() -> MLXArray {
        guard let data = dataProvider?.data, let base = CFDataGetBytePtr(data) else {
            return []
        }
        let byteCount = CFDataGetLength(data)
        // Bytes that actually hold pixels in one row, rounded up for sub-byte
        // formats.
        let packedRowLength = (width * bitsPerPixel + 7) / 8

        // Rows are already tightly packed (the usual case): pass the buffer
        // through unchanged.
        if bytesPerRow == packedRowLength || packedRowLength <= 0 {
            return MLXArray(UnsafeBufferPointer(start: base, count: byteCount))
        }

        // Core Graphics may pad rows for alignment. Copy only the pixel bytes
        // so callers can reshape the result to (height, width, ...) directly.
        var packed: [UInt8] = []
        packed.reserveCapacity(packedRowLength * height)
        for row in 0..<height {
            let rowStart = row * bytesPerRow
            guard rowStart + packedRowLength <= byteCount else { break }
            packed.append(
                contentsOf: UnsafeBufferPointer(start: base + rowStart, count: packedRowLength))
        }
        return packed.withUnsafeBufferPointer { MLXArray($0) }
    }
}
|