LeNet on MNIST + readme update (#12)
* LeNet on MNIST + readme update * tanh + remove device toggle * remove device entirely
This commit is contained in:
@@ -12,9 +12,6 @@ struct ContentView: View {
|
||||
// the training loop
|
||||
@State var trainer = Trainer()
|
||||
|
||||
// toggle for cpu/gpu training
|
||||
@State var cpu = true
|
||||
|
||||
var body: some View {
|
||||
VStack {
|
||||
Spacer()
|
||||
@@ -30,13 +27,10 @@ struct ContentView: View {
|
||||
|
||||
Button("Train") {
|
||||
Task {
|
||||
try! await trainer.run(device: cpu ? .cpu : .gpu)
|
||||
try! await trainer.run()
|
||||
}
|
||||
}
|
||||
|
||||
Toggle("CPU", isOn: $cpu)
|
||||
.frame(maxWidth: 150)
|
||||
|
||||
Spacer()
|
||||
}
|
||||
Spacer()
|
||||
@@ -50,12 +44,10 @@ class Trainer {
|
||||
|
||||
var messages = [String]()
|
||||
|
||||
func run(device: Device = .cpu) async throws {
|
||||
func run() async throws {
|
||||
// Note: this is pretty close to the code in `mnist-tool`, just
|
||||
// wrapped in an Observable to make it easy to display in SwiftUI
|
||||
|
||||
Device.setDefault(device: device)
|
||||
|
||||
// download & load the training data
|
||||
let url = URL(fileURLWithPath: NSTemporaryDirectory(), isDirectory: true)
|
||||
try await download(into: url)
|
||||
@@ -67,9 +59,7 @@ class Trainer {
|
||||
let testLabels = data[.init(.test, .labels)]!
|
||||
|
||||
// create the model with random weights
|
||||
let model = MLP(
|
||||
layers: 2, inputDimensions: trainImages.dim(-1), hiddenDimensions: 32,
|
||||
outputDimensions: 10)
|
||||
let model = LeNet()
|
||||
eval(model.parameters())
|
||||
|
||||
// the training loop
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
# MNISTTrainer
|
||||
|
||||
This is an example showing how to do model training on both macOS and iOS.
|
||||
This will download the MNIST training data, create a new model and train
|
||||
it. It will show the timing per epoch and the test accuracy as it trains.
|
||||
This is an example of model training that works on both macOS and iOS.
|
||||
The example will download the MNIST training data, create a LeNet, and train
|
||||
it. It will show the epoch time and test accuracy as it trains.
|
||||
|
||||
You will need to set the Team on the MNISTTrainer target in order to build and
|
||||
run on iOS.
|
||||
|
||||
Some notes about the setup:
|
||||
|
||||
- this will download test data over the network so MNISTTrainer -> Signing & Capabilities has the "Outgoing Connections (Client)" set in the App Sandbox
|
||||
- the website it connects to uses http rather than https so it has a "App Transport Security Settings" in the Info.plist
|
||||
- This will download test data over the network so MNISTTrainer -> Signing & Capabilities has the "Outgoing Connections (Client)" set in the App Sandbox
|
||||
- The website it connects to uses http rather than https so it has an "App Transport Security Settings" entry in the Info.plist
|
||||
|
||||
@@ -43,13 +43,13 @@ let files = [
|
||||
name: "train-images-idx3-ubyte.gz",
|
||||
offset: 16,
|
||||
convert: {
|
||||
$0.reshaped([-1, 28 * 28]).asType(.float32) / 255.0
|
||||
$0.reshaped([-1, 28, 28, 1]).asType(.float32) / 255.0
|
||||
}),
|
||||
FileKind(.test, .images): LoadInfo(
|
||||
name: "t10k-images-idx3-ubyte.gz",
|
||||
offset: 16,
|
||||
convert: {
|
||||
$0.reshaped([-1, 28 * 28]).asType(.float32) / 255.0
|
||||
$0.reshaped([-1, 28, 28, 1]).asType(.float32) / 255.0
|
||||
}),
|
||||
FileKind(.training, .labels): LoadInfo(
|
||||
name: "train-labels-idx1-ubyte.gz",
|
||||
|
||||
@@ -6,36 +6,43 @@ import MLXNN
|
||||
|
||||
// based on https://github.com/ml-explore/mlx-examples/blob/main/mnist/main.py
|
||||
|
||||
public class MLP: Module, UnaryLayer {
|
||||
public class LeNet: Module, UnaryLayer {
|
||||
|
||||
@ModuleInfo var layers: [Linear]
|
||||
@ModuleInfo var conv1: Conv2d
|
||||
@ModuleInfo var conv2: Conv2d
|
||||
@ModuleInfo var pool1: MaxPool2d
|
||||
@ModuleInfo var pool2: MaxPool2d
|
||||
@ModuleInfo var fc1: Linear
|
||||
@ModuleInfo var fc2: Linear
|
||||
@ModuleInfo var fc3: Linear
|
||||
|
||||
public init(layers: Int, inputDimensions: Int, hiddenDimensions: Int, outputDimensions: Int) {
|
||||
let layerSizes =
|
||||
[inputDimensions] + Array(repeating: hiddenDimensions, count: layers) + [
|
||||
outputDimensions
|
||||
]
|
||||
|
||||
self.layers = zip(layerSizes.dropLast(), layerSizes.dropFirst())
|
||||
.map {
|
||||
Linear($0, $1)
|
||||
}
|
||||
override public init() {
|
||||
conv1 = Conv2d(inputChannels: 1, outputChannels: 6, kernelSize: 5, padding: 2)
|
||||
conv2 = Conv2d(inputChannels: 6, outputChannels: 16, kernelSize: 5, padding: 0)
|
||||
pool1 = MaxPool2d(kernelSize: 2, stride: 2)
|
||||
pool2 = MaxPool2d(kernelSize: 2, stride: 2)
|
||||
fc1 = Linear(16 * 5 * 5, 120)
|
||||
fc2 = Linear(120, 84)
|
||||
fc3 = Linear(84, 10)
|
||||
}
|
||||
|
||||
public func callAsFunction(_ x: MLXArray) -> MLXArray {
|
||||
var x = x
|
||||
for l in layers.dropLast() {
|
||||
x = relu(l(x))
|
||||
}
|
||||
return layers.last!(x)
|
||||
x = pool1(tanh(conv1(x)))
|
||||
x = pool2(tanh(conv2(x)))
|
||||
x = flattened(x, start: 1)
|
||||
x = tanh(fc1(x))
|
||||
x = tanh(fc2(x))
|
||||
x = fc3(x)
|
||||
return x
|
||||
}
|
||||
}
|
||||
|
||||
public func loss(model: MLP, x: MLXArray, y: MLXArray) -> MLXArray {
|
||||
public func loss(model: LeNet, x: MLXArray, y: MLXArray) -> MLXArray {
|
||||
crossEntropy(logits: model(x), targets: y, reduction: .mean)
|
||||
}
|
||||
|
||||
public func eval(model: MLP, x: MLXArray, y: MLXArray) -> MLXArray {
|
||||
public func eval(model: LeNet, x: MLXArray, y: MLXArray) -> MLXArray {
|
||||
mean(argMax(model(x), axis: 1) .== y)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
# MNIST
|
||||
|
||||
This is a port of the MNIST model and training code from:
|
||||
|
||||
- https://github.com/ml-explore/mlx-examples/blob/main/mnist
|
||||
This is a port of the MNIST training code from the [Python MLX example](https://github.com/ml-explore/mlx-examples/blob/main/mnist). This example uses a [LeNet](https://en.wikipedia.org/wiki/LeNet) instead of an MLP.
|
||||
|
||||
It provides code to:
|
||||
|
||||
- download the test/train data
|
||||
- provides the MNIST model (MLP)
|
||||
- some functions to shuffle and batch the data
|
||||
- Download the MNIST test/train data
|
||||
- Build the LeNet
|
||||
- Shuffle and batch the data
|
||||
|
||||
See [mnist-tool](../../Tools/mnist-tool) for an example of how to run this. The training loop also lives there.
|
||||
See [mnist-tool](../../Tools/mnist-tool) for an example of how to run this. The training loop also lives there.
|
||||
|
||||
43
README.md
43
README.md
@@ -1,37 +1,20 @@
|
||||
# MLX Swift Examples
|
||||
|
||||
Example [mlx-swift](https://github.com/ml-explore/mlx-swift) programs.
|
||||
Example [MLX Swift](https://github.com/ml-explore/mlx-swift) programs.
|
||||
|
||||
## MNISTTrainer
|
||||
- [MNISTTrainer](Applications/MNISTTrainer/README.md): An example that runs on
|
||||
both iOS and macOS that downloads MNIST training data and trains a
|
||||
[LeNet](https://en.wikipedia.org/wiki/LeNet).
|
||||
|
||||
An example that runs on both iOS and macOS that downloads MNIST training
|
||||
data and trains an MNIST model.
|
||||
- [LLMEval](Applications/LLMEval/README.md): An example that runs on both iOS
|
||||
and macOS that downloads an LLM and tokenizer from Hugging Face and
|
||||
generates text from a given prompt.
|
||||
|
||||
- [README](Applications/MNISTTrainer/README.md)
|
||||
- [LinearModelTraining](Tools/LinearModelTraining/README.md): An example that
|
||||
trains a simple linear model.
|
||||
|
||||
## LLMEval
|
||||
|
||||
An example that runs on both iOS and macOS that downloads LLM model
|
||||
weights and tokenizer configuration from Hugging Face and evaluates
|
||||
a prompt in-process.
|
||||
|
||||
- [README](Applications/LLMEval/README.md)
|
||||
|
||||
## LinearModelTraining
|
||||
|
||||
A simple linear model and a training loop.
|
||||
|
||||
- [README](Tools/LinearModelTraining/README.md)
|
||||
|
||||
## llm-tool
|
||||
|
||||
A command line tool for generating text using a variety of Hugging Face models:
|
||||
|
||||
- [README](Tools/llm-tool/README.md)
|
||||
|
||||
## mnist-tool
|
||||
|
||||
A command line tool for training an MNIST (MLP) model:
|
||||
|
||||
- [README](Tools/mnist-tool/README.md)
|
||||
- [llm-tool](Tools/llm-tool/README.md): A command line tool for generating text
|
||||
using a variety of LLMs available on the Hugging Face hub.
|
||||
|
||||
- [mnist-tool](Tools/mnist-tool/README.md): A command line tool for training a
|
||||
LeNet on MNIST.
|
||||
|
||||
@@ -30,15 +30,11 @@ struct Train: AsyncParsableCommand {
|
||||
@Option(name: .long, help: "The PRNG seed")
|
||||
var seed: UInt64 = 0
|
||||
|
||||
@Option var layers = 2
|
||||
@Option var hidden = 32
|
||||
@Option var batchSize = 256
|
||||
@Option var epochs = 20
|
||||
@Option var learningRate: Float = 1e-1
|
||||
|
||||
@Option var classes = 10
|
||||
|
||||
@Option var device = DeviceType.cpu
|
||||
@Option var device = DeviceType.gpu
|
||||
|
||||
@Flag var compile = false
|
||||
|
||||
@@ -62,9 +58,7 @@ struct Train: AsyncParsableCommand {
|
||||
let testLabels = data[.init(.test, .labels)]!
|
||||
|
||||
// create the model
|
||||
let model = MLP(
|
||||
layers: layers, inputDimensions: trainImages.dim(-1), hiddenDimensions: hidden,
|
||||
outputDimensions: classes)
|
||||
let model = LeNet()
|
||||
eval(model.parameters())
|
||||
|
||||
let lg = valueAndGrad(model: model, loss)
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
# mnist-tool
|
||||
|
||||
See other README:
|
||||
|
||||
- [MNIST](../../Libraries/MNIST/README.md)
|
||||
See the [MNIST README.md](../../Libraries/MNIST/README.md).
|
||||
|
||||
### Building
|
||||
|
||||
|
||||
Reference in New Issue
Block a user