Add MNIST Digit Prediction/Inference (#22)
* Add Prediction to MNISTTrainer
This commit is contained in:
@@ -7,10 +7,9 @@ import MLXRandom
|
|||||||
import MNIST
|
import MNIST
|
||||||
import SwiftUI
|
import SwiftUI
|
||||||
|
|
||||||
struct ContentView: View {
|
struct TrainingView: View {
|
||||||
|
|
||||||
// the training loop
|
@Binding var trainer: Trainer
|
||||||
@State var trainer = Trainer()
|
|
||||||
|
|
||||||
var body: some View {
|
var body: some View {
|
||||||
VStack {
|
VStack {
|
||||||
@@ -24,12 +23,18 @@ struct ContentView: View {
|
|||||||
|
|
||||||
HStack {
|
HStack {
|
||||||
Spacer()
|
Spacer()
|
||||||
|
switch trainer.state {
|
||||||
|
case .untrained:
|
||||||
Button("Train") {
|
Button("Train") {
|
||||||
Task {
|
Task {
|
||||||
try! await trainer.run()
|
try! await trainer.run()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
case .trained(let model), .predict(let model):
|
||||||
|
Button("Draw a digit") {
|
||||||
|
trainer.state = .predict(model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Spacer()
|
Spacer()
|
||||||
}
|
}
|
||||||
@@ -39,9 +44,30 @@ struct ContentView: View {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Root view of the app.
///
/// Shows `TrainingView` while the trainer is untrained or has just finished
/// training, and swaps to `PredictionView` once the trainer enters the
/// `.predict` state.
struct ContentView: View {
    /// Owns the trainer, and with it the training loop, for this view.
    @State var trainer = Trainer()

    var body: some View {
        switch trainer.state {
        case .untrained, .trained:
            // The training screen needs write access so it can advance the state.
            TrainingView(trainer: $trainer)
        case let .predict(model):
            PredictionView(model: model)
        }
    }
}
|
||||||
|
|
||||||
@Observable
|
@Observable
|
||||||
class Trainer {
|
class Trainer {
|
||||||
|
|
||||||
|
/// Lifecycle of the trainer, used by the views to decide what to show.
enum State {
    /// Initial state: no model has been trained yet.
    case untrained
    /// Training has finished; carries the trained model.
    case trained(LeNet)
    /// The user asked to draw a digit; carries the model used for inference.
    case predict(LeNet)
}
|
||||||
|
|
||||||
|
var state: State = .untrained
|
||||||
var messages = [String]()
|
var messages = [String]()
|
||||||
|
|
||||||
func run() async throws {
|
func run() async throws {
|
||||||
@@ -101,6 +127,8 @@ class Trainer {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
await MainActor.run {
|
||||||
|
state = .trained(model)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
127
Applications/MNISTTrainer/PredictionView.swift
Normal file
127
Applications/MNISTTrainer/PredictionView.swift
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
//
|
||||||
|
// PredictionView.swift
|
||||||
|
// MNISTTrainer
|
||||||
|
//
|
||||||
|
// Created by Rounak Jain on 3/9/24.
|
||||||
|
//
|
||||||
|
|
||||||
|
import MLX
|
||||||
|
import MLXNN
|
||||||
|
import MNIST
|
||||||
|
import SwiftUI
|
||||||
|
|
||||||
|
/// A drawing surface: strokes the bound `path` in white on a black
/// background and extends it as the user drags.
struct Canvas: View {

    /// The drawing being edited; storage is owned by the parent view.
    @Binding var path: Path
    /// Location of the previous drag sample, or `nil` when no stroke is in progress.
    @State var lastPoint: CGPoint?

    var body: some View {
        path
            .stroke(.white, lineWidth: 10)
            .background(.black)
            .gesture(strokeGesture)
    }

    /// Drag gesture that feeds every sample into the path and ends the
    /// current stroke when the drag finishes.
    private var strokeGesture: some Gesture {
        DragGesture(minimumDistance: 0.05)
            .onChanged { sample in
                add(point: sample.location)
            }
            .onEnded { _ in
                lastPoint = nil
            }
    }

    /// Appends `point` to the drawing.
    ///
    /// While a stroke is in progress a line segment is added from the previous
    /// sample to `point`; otherwise the path simply moves to `point`, so a new
    /// stroke never connects to the previous one.
    ///
    /// - Parameter point: The new drag location.
    func add(point: CGPoint) {
        var updated = path
        updated.move(to: lastPoint ?? point)
        if lastPoint != nil {
            updated.addLine(to: point)
        }
        path = updated
        lastPoint = point
    }
}
|
||||||
|
|
||||||
|
extension Path {
    /// Translates the path in place so that the midpoint of its bounding
    /// rectangle lands on `newMidPoint`.
    ///
    /// - Parameter newMidPoint: The desired centre of the path's bounding rectangle.
    mutating func center(to newMidPoint: CGPoint) {
        let bounds = boundingRect
        let shift = CGSize(
            width: newMidPoint.x - bounds.midX,
            height: newMidPoint.y - bounds.midY)
        self = offsetBy(dx: shift.width, dy: shift.height)
    }
}
|
||||||
|
|
||||||
|
/// Lets the user draw a digit and classifies it with a trained `LeNet` model.
struct PredictionView: View {
    /// The user's drawing.
    @State var path: Path = Path()
    /// Most recent predicted digit, or `nil` when nothing has been predicted.
    @State var prediction: Int?
    /// Model used for inference.
    let model: LeNet
    /// Side length, in points, of the square drawing canvas.
    let canvasSize = 150.0
    /// Size the drawing is scaled to before it is fed to the model.
    let mnistImageSize: CGSize = CGSize(width: 28, height: 28)

    var body: some View {
        VStack {
            if let prediction {
                Text("You've drawn a \(prediction)")
            } else {
                Text("Draw a digit")
            }
            Canvas(path: $path)
                .frame(width: canvasSize, height: canvasSize)
            HStack {
                Button("Predict") {
                    // Centre the drawing on the canvas before classifying it.
                    path.center(to: CGPoint(x: canvasSize / 2, y: canvasSize / 2))
                    predict()
                }
                Button("Clear") {
                    path = Path()
                    prediction = nil
                }
            }
        }
    }

    /// Renders the current drawing, scales it to `mnistImageSize` in grayscale
    /// and stores the model's most likely class in `prediction`.
    ///
    /// If the drawing cannot be rendered or converted, `prediction` is reset to
    /// `nil` so the UI does not keep showing a result for an earlier drawing.
    @MainActor
    func predict() {
        // Render at the same size as the on-screen canvas so the image matches
        // what the user sees (and what the centering above assumed).
        let imageRenderer = ImageRenderer(
            content: Canvas(path: $path).frame(width: canvasSize, height: canvasSize))
        guard
            let pixelData = imageRenderer.cgImage?.grayscaleImage(with: mnistImageSize)?.pixelData()
        else {
            prediction = nil
            return
        }
        // modify input vector to match training in MNIST/Files.swift
        // Bitmap bytes are laid out row by row, hence height before width.
        let rows = Int(mnistImageSize.height)
        let columns = Int(mnistImageSize.width)
        let x = pixelData.reshaped([1, rows, columns, 1]).asType(.float32) / 255.0
        prediction = argMax(model(x)).item()
    }
}
|
||||||
|
|
||||||
|
extension CGImage {
    /// Returns a copy of the image scaled to `newSize` and converted to
    /// 8-bit grayscale without an alpha channel.
    ///
    /// - Parameter newSize: Target size in pixels; fractional parts are
    ///   truncated when sizing the bitmap.
    /// - Returns: The converted image, or `nil` if the bitmap context or the
    ///   resulting image could not be created.
    func grayscaleImage(with newSize: CGSize) -> CGImage? {
        let colorSpace = CGColorSpaceCreateDeviceGray()
        let bitmapInfo = CGBitmapInfo(rawValue: CGImageAlphaInfo.none.rawValue)

        guard
            let context = CGContext(
                data: nil,
                width: Int(newSize.width),
                height: Int(newSize.height),
                bitsPerComponent: 8,
                // One byte per pixel: rows are requested without padding.
                bytesPerRow: Int(newSize.width),
                space: colorSpace,
                bitmapInfo: bitmapInfo.rawValue)
        else {
            return nil
        }
        // Fill the whole context. The height must come from `newSize.height`;
        // using the width here would distort any non-square target.
        context.draw(self, in: CGRect(x: 0, y: 0, width: newSize.width, height: newSize.height))
        return context.makeImage()
    }

    /// Copies the image's backing bytes into an `MLXArray`.
    ///
    /// - Returns: A flat array with one element per byte of the image's data
    ///   provider, or an empty array when the image has no accessible data.
    ///
    /// NOTE(review): the bytes are copied verbatim, so any row padding in the
    /// backing store would be included — presumably absent for images created
    /// by `grayscaleImage(with:)`; confirm before using with other images.
    func pixelData() -> MLXArray {
        guard let data = self.dataProvider?.data else {
            return []
        }
        let bytePtr = CFDataGetBytePtr(data)
        let count = CFDataGetLength(data)
        return MLXArray(UnsafeBufferPointer(start: bytePtr, count: count))
    }
}
|
||||||
@@ -14,7 +14,7 @@ possible.
|
|||||||
You can also run the formatters manually as follows:
|
You can also run the formatters manually as follows:
|
||||||
|
|
||||||
```
|
```
|
||||||
swift-format format --in-place --recursive Libraries Tools
|
swift-format format --in-place --recursive Libraries Tools Applications
|
||||||
```
|
```
|
||||||
|
|
||||||
or run `pre-commit run --all-files` to check all files in the repo.
|
or run `pre-commit run --all-files` to check all files in the repo.
|
||||||
|
|||||||
@@ -7,6 +7,7 @@
|
|||||||
objects = {
|
objects = {
|
||||||
|
|
||||||
/* Begin PBXBuildFile section */
|
/* Begin PBXBuildFile section */
|
||||||
|
12305EAF2B9D864400C92FEE /* PredictionView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 12305EAE2B9D864400C92FEE /* PredictionView.swift */; };
|
||||||
525C1E9D2B9A011000B5C356 /* Starcoder2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */; };
|
525C1E9D2B9A011000B5C356 /* Starcoder2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */; };
|
||||||
52A776182B94B5EE00AA6E80 /* Qwen2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 52A776172B94B5EE00AA6E80 /* Qwen2.swift */; };
|
52A776182B94B5EE00AA6E80 /* Qwen2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 52A776172B94B5EE00AA6E80 /* Qwen2.swift */; };
|
||||||
81695B412BA373D300F260D8 /* MarkdownUI in Frameworks */ = {isa = PBXBuildFile; productRef = 81695B402BA373D300F260D8 /* MarkdownUI */; };
|
81695B412BA373D300F260D8 /* MarkdownUI in Frameworks */ = {isa = PBXBuildFile; productRef = 81695B402BA373D300F260D8 /* MarkdownUI */; };
|
||||||
@@ -183,6 +184,7 @@
|
|||||||
/* End PBXCopyFilesBuildPhase section */
|
/* End PBXCopyFilesBuildPhase section */
|
||||||
|
|
||||||
/* Begin PBXFileReference section */
|
/* Begin PBXFileReference section */
|
||||||
|
12305EAE2B9D864400C92FEE /* PredictionView.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = PredictionView.swift; sourceTree = "<group>"; };
|
||||||
525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Starcoder2.swift; sourceTree = "<group>"; };
|
525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Starcoder2.swift; sourceTree = "<group>"; };
|
||||||
52A776172B94B5EE00AA6E80 /* Qwen2.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Qwen2.swift; sourceTree = "<group>"; };
|
52A776172B94B5EE00AA6E80 /* Qwen2.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Qwen2.swift; sourceTree = "<group>"; };
|
||||||
C325DE3F2B648CDB00628871 /* README.md */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = "<group>"; };
|
C325DE3F2B648CDB00628871 /* README.md */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = "<group>"; };
|
||||||
@@ -451,6 +453,7 @@
|
|||||||
children = (
|
children = (
|
||||||
C3A8B3C32B92951E0002EFB8 /* Assets.xcassets */,
|
C3A8B3C32B92951E0002EFB8 /* Assets.xcassets */,
|
||||||
C3A8B3C92B92951E0002EFB8 /* ContentView.swift */,
|
C3A8B3C92B92951E0002EFB8 /* ContentView.swift */,
|
||||||
|
12305EAE2B9D864400C92FEE /* PredictionView.swift */,
|
||||||
C3A8B3C22B92951E0002EFB8 /* MNISTTrainer-Info.plist */,
|
C3A8B3C22B92951E0002EFB8 /* MNISTTrainer-Info.plist */,
|
||||||
C3A8B3C72B92951E0002EFB8 /* MNISTTrainer.entitlements */,
|
C3A8B3C72B92951E0002EFB8 /* MNISTTrainer.entitlements */,
|
||||||
C3A8B3C42B92951E0002EFB8 /* MNISTTrainerApp.swift */,
|
C3A8B3C42B92951E0002EFB8 /* MNISTTrainerApp.swift */,
|
||||||
@@ -866,6 +869,7 @@
|
|||||||
isa = PBXSourcesBuildPhase;
|
isa = PBXSourcesBuildPhase;
|
||||||
buildActionMask = 2147483647;
|
buildActionMask = 2147483647;
|
||||||
files = (
|
files = (
|
||||||
|
12305EAF2B9D864400C92FEE /* PredictionView.swift in Sources */,
|
||||||
C3A8B3CC2B92951E0002EFB8 /* MNISTTrainerApp.swift in Sources */,
|
C3A8B3CC2B92951E0002EFB8 /* MNISTTrainerApp.swift in Sources */,
|
||||||
C3A8B3CF2B92951E0002EFB8 /* ContentView.swift in Sources */,
|
C3A8B3CF2B92951E0002EFB8 /* ContentView.swift in Sources */,
|
||||||
);
|
);
|
||||||
|
|||||||
Reference in New Issue
Block a user