LeNet on MNIST + readme update (#12)
* LeNet on MNIST + readme update * tanh + remove device toggle * remove device entirely
This commit is contained in:
@@ -12,9 +12,6 @@ struct ContentView: View {
|
|||||||
// the training loop
|
// the training loop
|
||||||
@State var trainer = Trainer()
|
@State var trainer = Trainer()
|
||||||
|
|
||||||
// toggle for cpu/gpu training
|
|
||||||
@State var cpu = true
|
|
||||||
|
|
||||||
var body: some View {
|
var body: some View {
|
||||||
VStack {
|
VStack {
|
||||||
Spacer()
|
Spacer()
|
||||||
@@ -30,13 +27,10 @@ struct ContentView: View {
|
|||||||
|
|
||||||
Button("Train") {
|
Button("Train") {
|
||||||
Task {
|
Task {
|
||||||
try! await trainer.run(device: cpu ? .cpu : .gpu)
|
try! await trainer.run()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Toggle("CPU", isOn: $cpu)
|
|
||||||
.frame(maxWidth: 150)
|
|
||||||
|
|
||||||
Spacer()
|
Spacer()
|
||||||
}
|
}
|
||||||
Spacer()
|
Spacer()
|
||||||
@@ -50,12 +44,10 @@ class Trainer {
|
|||||||
|
|
||||||
var messages = [String]()
|
var messages = [String]()
|
||||||
|
|
||||||
func run(device: Device = .cpu) async throws {
|
func run() async throws {
|
||||||
// Note: this is pretty close to the code in `mnist-tool`, just
|
// Note: this is pretty close to the code in `mnist-tool`, just
|
||||||
// wrapped in an Observable to make it easy to display in SwiftUI
|
// wrapped in an Observable to make it easy to display in SwiftUI
|
||||||
|
|
||||||
Device.setDefault(device: device)
|
|
||||||
|
|
||||||
// download & load the training data
|
// download & load the training data
|
||||||
let url = URL(fileURLWithPath: NSTemporaryDirectory(), isDirectory: true)
|
let url = URL(fileURLWithPath: NSTemporaryDirectory(), isDirectory: true)
|
||||||
try await download(into: url)
|
try await download(into: url)
|
||||||
@@ -67,9 +59,7 @@ class Trainer {
|
|||||||
let testLabels = data[.init(.test, .labels)]!
|
let testLabels = data[.init(.test, .labels)]!
|
||||||
|
|
||||||
// create the model with random weights
|
// create the model with random weights
|
||||||
let model = MLP(
|
let model = LeNet()
|
||||||
layers: 2, inputDimensions: trainImages.dim(-1), hiddenDimensions: 32,
|
|
||||||
outputDimensions: 10)
|
|
||||||
eval(model.parameters())
|
eval(model.parameters())
|
||||||
|
|
||||||
// the training loop
|
// the training loop
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
# MNISTTrainer
|
# MNISTTrainer
|
||||||
|
|
||||||
This is an example showing how to do model training on both macOS and iOS.
|
This is an example of model training that works on both macOS and iOS.
|
||||||
This will download the MNIST training data, create a new models and train
|
The example will download the MNIST training data, create a LeNet, and train
|
||||||
it. It will show the timing per epoch and the test accuracy as it trains.
|
it. It will show the epoch time and test accuracy as it trains.
|
||||||
|
|
||||||
You will need to set the Team on the MNISTTrainer target in order to build and
|
You will need to set the Team on the MNISTTrainer target in order to build and
|
||||||
run on iOS.
|
run on iOS.
|
||||||
|
|
||||||
Some notes about the setup:
|
Some notes about the setup:
|
||||||
|
|
||||||
- this will download test data over the network so MNISTTrainer -> Signing & Capabilities has the "Outgoing Connections (Client)" set in the App Sandbox
|
- This will download test data over the network so MNISTTrainer -> Signing & Capabilities has the "Outgoing Connections (Client)" set in the App Sandbox
|
||||||
- the website it connects to uses http rather than https so it has a "App Transport Security Settings" in the Info.plist
|
- The website it connects to uses http rather than https so it has a "App Transport Security Settings" in the Info.plist
|
||||||
|
|||||||
@@ -43,13 +43,13 @@ let files = [
|
|||||||
name: "train-images-idx3-ubyte.gz",
|
name: "train-images-idx3-ubyte.gz",
|
||||||
offset: 16,
|
offset: 16,
|
||||||
convert: {
|
convert: {
|
||||||
$0.reshaped([-1, 28 * 28]).asType(.float32) / 255.0
|
$0.reshaped([-1, 28, 28, 1]).asType(.float32) / 255.0
|
||||||
}),
|
}),
|
||||||
FileKind(.test, .images): LoadInfo(
|
FileKind(.test, .images): LoadInfo(
|
||||||
name: "t10k-images-idx3-ubyte.gz",
|
name: "t10k-images-idx3-ubyte.gz",
|
||||||
offset: 16,
|
offset: 16,
|
||||||
convert: {
|
convert: {
|
||||||
$0.reshaped([-1, 28 * 28]).asType(.float32) / 255.0
|
$0.reshaped([-1, 28, 28, 1]).asType(.float32) / 255.0
|
||||||
}),
|
}),
|
||||||
FileKind(.training, .labels): LoadInfo(
|
FileKind(.training, .labels): LoadInfo(
|
||||||
name: "train-labels-idx1-ubyte.gz",
|
name: "train-labels-idx1-ubyte.gz",
|
||||||
|
|||||||
@@ -6,36 +6,43 @@ import MLXNN
|
|||||||
|
|
||||||
// based on https://github.com/ml-explore/mlx-examples/blob/main/mnist/main.py
|
// based on https://github.com/ml-explore/mlx-examples/blob/main/mnist/main.py
|
||||||
|
|
||||||
public class MLP: Module, UnaryLayer {
|
public class LeNet: Module, UnaryLayer {
|
||||||
|
|
||||||
@ModuleInfo var layers: [Linear]
|
@ModuleInfo var conv1: Conv2d
|
||||||
|
@ModuleInfo var conv2: Conv2d
|
||||||
|
@ModuleInfo var pool1: MaxPool2d
|
||||||
|
@ModuleInfo var pool2: MaxPool2d
|
||||||
|
@ModuleInfo var fc1: Linear
|
||||||
|
@ModuleInfo var fc2: Linear
|
||||||
|
@ModuleInfo var fc3: Linear
|
||||||
|
|
||||||
public init(layers: Int, inputDimensions: Int, hiddenDimensions: Int, outputDimensions: Int) {
|
override public init() {
|
||||||
let layerSizes =
|
conv1 = Conv2d(inputChannels: 1, outputChannels: 6, kernelSize: 5, padding: 2)
|
||||||
[inputDimensions] + Array(repeating: hiddenDimensions, count: layers) + [
|
conv2 = Conv2d(inputChannels: 6, outputChannels: 16, kernelSize: 5, padding: 0)
|
||||||
outputDimensions
|
pool1 = MaxPool2d(kernelSize: 2, stride: 2)
|
||||||
]
|
pool2 = MaxPool2d(kernelSize: 2, stride: 2)
|
||||||
|
fc1 = Linear(16 * 5 * 5, 120)
|
||||||
self.layers = zip(layerSizes.dropLast(), layerSizes.dropFirst())
|
fc2 = Linear(120, 84)
|
||||||
.map {
|
fc3 = Linear(84, 10)
|
||||||
Linear($0, $1)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public func callAsFunction(_ x: MLXArray) -> MLXArray {
|
public func callAsFunction(_ x: MLXArray) -> MLXArray {
|
||||||
var x = x
|
var x = x
|
||||||
for l in layers.dropLast() {
|
x = pool1(tanh(conv1(x)))
|
||||||
x = relu(l(x))
|
x = pool2(tanh(conv2(x)))
|
||||||
}
|
x = flattened(x, start: 1)
|
||||||
return layers.last!(x)
|
x = tanh(fc1(x))
|
||||||
|
x = tanh(fc2(x))
|
||||||
|
x = fc3(x)
|
||||||
|
return x
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public func loss(model: MLP, x: MLXArray, y: MLXArray) -> MLXArray {
|
public func loss(model: LeNet, x: MLXArray, y: MLXArray) -> MLXArray {
|
||||||
crossEntropy(logits: model(x), targets: y, reduction: .mean)
|
crossEntropy(logits: model(x), targets: y, reduction: .mean)
|
||||||
}
|
}
|
||||||
|
|
||||||
public func eval(model: MLP, x: MLXArray, y: MLXArray) -> MLXArray {
|
public func eval(model: LeNet, x: MLXArray, y: MLXArray) -> MLXArray {
|
||||||
mean(argMax(model(x), axis: 1) .== y)
|
mean(argMax(model(x), axis: 1) .== y)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,13 +1,11 @@
|
|||||||
# MNIST
|
# MNIST
|
||||||
|
|
||||||
This is a port of the MNIST model and training code from:
|
This is a port of the MNIST training code from the [Python MLX example](https://github.com/ml-explore/mlx-examples/blob/main/mnist). This example uses a [LeNet](https://en.wikipedia.org/wiki/LeNet) instead of an MLP.
|
||||||
|
|
||||||
- https://github.com/ml-explore/mlx-examples/blob/main/mnist
|
|
||||||
|
|
||||||
It provides code to:
|
It provides code to:
|
||||||
|
|
||||||
- download the test/train data
|
- Download the MNIST test/train data
|
||||||
- provides the MNIST model (MLP)
|
- Build the LeNet
|
||||||
- some functions to shuffle and batch the data
|
- Some functions to shuffle and batch the data
|
||||||
|
|
||||||
See [mnist-tool](../../Tools/mnist-tool) for an example of how to run this. The training loop also lives there.
|
See [mnist-tool](../../Tools/mnist-tool) for an example of how to run this. The training loop also lives there.
|
||||||
|
|||||||
43
README.md
43
README.md
@@ -1,37 +1,20 @@
|
|||||||
# MLX Swift Examples
|
# MLX Swift Examples
|
||||||
|
|
||||||
Example [mlx-swift](https://github.com/ml-explore/mlx-swift) programs.
|
Example [MLX Swift](https://github.com/ml-explore/mlx-swift) programs.
|
||||||
|
|
||||||
## MNISTTrainer
|
- [MNISTTrainer](Applications/MNISTTrainer/README.md): An example that runs on
|
||||||
|
both iOS and macOS that downloads MNIST training data and trains a
|
||||||
|
[LeNet](https://en.wikipedia.org/wiki/LeNet).
|
||||||
|
|
||||||
An example that runs on both iOS and macOS that downloads MNIST training
|
- [LLMEval](Applications/LLMEval/README.md): An example that runs on both iOS
|
||||||
data and trains an MNIST model.
|
and macOS that downloads an LLM and tokenizer from Hugging Face and and
|
||||||
|
generates text from a given prompt.
|
||||||
|
|
||||||
- [README](Applications/MNISTTrainer/README.md)
|
- [LinearModelTraining](Tools/LinearModelTraining/README.md): An example that
|
||||||
|
trains a simple linear model.
|
||||||
|
|
||||||
## LLMEval
|
- [llm-tool](Tools/llm-tool/README.md): A command line tool for generating text
|
||||||
|
using a variety of LLMs available on the Hugging Face hub.
|
||||||
An example that runs on both iOS and macOS that downloads a LLM model
|
|
||||||
weights and tokenizer configuration from Hugging Face and evaluates
|
|
||||||
a prompt in-process.
|
|
||||||
|
|
||||||
- [README](Applications/LLMEval/README.md)
|
|
||||||
|
|
||||||
## LinearModelTraining
|
|
||||||
|
|
||||||
A simple linear model and a training loop.
|
|
||||||
|
|
||||||
- [README](Tools/LinearModelTraining/README.md)
|
|
||||||
|
|
||||||
## llm-tool
|
|
||||||
|
|
||||||
A command line tool for generating text using a variety of Hugging Face models:
|
|
||||||
|
|
||||||
- [README](Tools/llm-tool/README.md)
|
|
||||||
|
|
||||||
## mnist-tool
|
|
||||||
|
|
||||||
A command line tool for training an MNIST (MLP) model:
|
|
||||||
|
|
||||||
- [README](Tools/mnist-tool/README.md)
|
|
||||||
|
|
||||||
|
- [mnist-tool](Tools/mnist-tool/README.md): A command line tool for training a
|
||||||
|
a LeNet on MNIST.
|
||||||
|
|||||||
@@ -30,15 +30,11 @@ struct Train: AsyncParsableCommand {
|
|||||||
@Option(name: .long, help: "The PRNG seed")
|
@Option(name: .long, help: "The PRNG seed")
|
||||||
var seed: UInt64 = 0
|
var seed: UInt64 = 0
|
||||||
|
|
||||||
@Option var layers = 2
|
|
||||||
@Option var hidden = 32
|
|
||||||
@Option var batchSize = 256
|
@Option var batchSize = 256
|
||||||
@Option var epochs = 20
|
@Option var epochs = 20
|
||||||
@Option var learningRate: Float = 1e-1
|
@Option var learningRate: Float = 1e-1
|
||||||
|
|
||||||
@Option var classes = 10
|
@Option var device = DeviceType.gpu
|
||||||
|
|
||||||
@Option var device = DeviceType.cpu
|
|
||||||
|
|
||||||
@Flag var compile = false
|
@Flag var compile = false
|
||||||
|
|
||||||
@@ -62,9 +58,7 @@ struct Train: AsyncParsableCommand {
|
|||||||
let testLabels = data[.init(.test, .labels)]!
|
let testLabels = data[.init(.test, .labels)]!
|
||||||
|
|
||||||
// create the model
|
// create the model
|
||||||
let model = MLP(
|
let model = LeNet()
|
||||||
layers: layers, inputDimensions: trainImages.dim(-1), hiddenDimensions: hidden,
|
|
||||||
outputDimensions: classes)
|
|
||||||
eval(model.parameters())
|
eval(model.parameters())
|
||||||
|
|
||||||
let lg = valueAndGrad(model: model, loss)
|
let lg = valueAndGrad(model: model, loss)
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
# mnist-tool
|
# mnist-tool
|
||||||
|
|
||||||
See other README:
|
See the [MNIST README.md](../../Libraries/MNIST/README.md).
|
||||||
|
|
||||||
- [MNIST](../../Libraries/MNIST/README.md)
|
|
||||||
|
|
||||||
### Building
|
### Building
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user