LeNet on MNIST + readme update (#12)

* LeNet on MNIST + readme update

* tanh + remove device toggle

* remove device entirely
Author: Awni Hannun
Date: 2024-03-04 14:16:20 -08:00
Committed by: GitHub
Parent: dfc9f2fc01
Commit: 4ed4ec69e7
8 changed files with 56 additions and 86 deletions

View File

@@ -12,9 +12,6 @@ struct ContentView: View {
     // the training loop
     @State var trainer = Trainer()

-    // toggle for cpu/gpu training
-    @State var cpu = true
-
     var body: some View {
         VStack {
             Spacer()
@@ -30,13 +27,10 @@ struct ContentView: View {
             Button("Train") {
                 Task {
-                    try! await trainer.run(device: cpu ? .cpu : .gpu)
+                    try! await trainer.run()
                 }
             }
-            Toggle("CPU", isOn: $cpu)
-                .frame(maxWidth: 150)
             Spacer()
         }
         Spacer()
@@ -50,12 +44,10 @@ class Trainer {
     var messages = [String]()

-    func run(device: Device = .cpu) async throws {
+    func run() async throws {
         // Note: this is pretty close to the code in `mnist-tool`, just
         // wrapped in an Observable to make it easy to display in SwiftUI
-        Device.setDefault(device: device)
-
         // download & load the training data
         let url = URL(fileURLWithPath: NSTemporaryDirectory(), isDirectory: true)
         try await download(into: url)
@@ -67,9 +59,7 @@ class Trainer {
         let testLabels = data[.init(.test, .labels)]!

         // create the model with random weights
-        let model = MLP(
-            layers: 2, inputDimensions: trainImages.dim(-1), hiddenDimensions: 32,
-            outputDimensions: 10)
+        let model = LeNet()
         eval(model.parameters())

         // the training loop

View File

@@ -1,13 +1,13 @@
 # MNISTTrainer

-This is an example showing how to do model training on both macOS and iOS.
-This will download the MNIST training data, create a new models and train
-it. It will show the timing per epoch and the test accuracy as it trains.
+This is an example of model training that works on both macOS and iOS.
+The example will download the MNIST training data, create a LeNet, and train
+it. It will show the epoch time and test accuracy as it trains.

 You will need to set the Team on the MNISTTrainer target in order to build and
 run on iOS.

 Some notes about the setup:

-- this will download test data over the network so MNISTTrainer -> Signing & Capabilities has the "Outgoing Connections (Client)" set in the App Sandbox
-- the website it connects to uses http rather than https so it has a "App Transport Security Settings" in the Info.plist
+- This will download test data over the network, so MNISTTrainer -> Signing & Capabilities has the "Outgoing Connections (Client)" set in the App Sandbox
+- The website it connects to uses http rather than https, so it has an "App Transport Security Settings" entry in the Info.plist

View File

@@ -43,13 +43,13 @@ let files = [
         name: "train-images-idx3-ubyte.gz",
         offset: 16,
         convert: {
-            $0.reshaped([-1, 28 * 28]).asType(.float32) / 255.0
+            $0.reshaped([-1, 28, 28, 1]).asType(.float32) / 255.0
         }),
     FileKind(.test, .images): LoadInfo(
         name: "t10k-images-idx3-ubyte.gz",
         offset: 16,
         convert: {
-            $0.reshaped([-1, 28 * 28]).asType(.float32) / 255.0
+            $0.reshaped([-1, 28, 28, 1]).asType(.float32) / 255.0
         }),
     FileKind(.training, .labels): LoadInfo(
         name: "train-labels-idx1-ubyte.gz",
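The reshape change above is the data-side half of the MLP-to-LeNet switch: the decompressed image bytes are now kept as NHWC tensors for `Conv2d` instead of flat 784-vectors. A minimal sketch of the difference; beyond the `reshaped`/`asType` calls already shown, it assumes only the `MLXArray` initializer from a Swift array:

    import MLX

    // two fake 28x28 grayscale images, as raw bytes like the idx files provide
    let raw = MLXArray([UInt8](repeating: 0, count: 2 * 28 * 28))

    let flat = raw.reshaped([-1, 28 * 28])    // old MLP input:   [2, 784]
    let nhwc = raw.reshaped([-1, 28, 28, 1])  // new LeNet input: [2, 28, 28, 1] (NHWC)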

View File

@@ -6,36 +6,43 @@ import MLXNN
 // based on https://github.com/ml-explore/mlx-examples/blob/main/mnist/main.py

-public class MLP: Module, UnaryLayer {
-    @ModuleInfo var layers: [Linear]
+public class LeNet: Module, UnaryLayer {
+    @ModuleInfo var conv1: Conv2d
+    @ModuleInfo var conv2: Conv2d
+    @ModuleInfo var pool1: MaxPool2d
+    @ModuleInfo var pool2: MaxPool2d
+    @ModuleInfo var fc1: Linear
+    @ModuleInfo var fc2: Linear
+    @ModuleInfo var fc3: Linear

-    public init(layers: Int, inputDimensions: Int, hiddenDimensions: Int, outputDimensions: Int) {
-        let layerSizes =
-            [inputDimensions] + Array(repeating: hiddenDimensions, count: layers) + [
-                outputDimensions
-            ]
-
-        self.layers = zip(layerSizes.dropLast(), layerSizes.dropFirst())
-            .map {
-                Linear($0, $1)
-            }
+    override public init() {
+        conv1 = Conv2d(inputChannels: 1, outputChannels: 6, kernelSize: 5, padding: 2)
+        conv2 = Conv2d(inputChannels: 6, outputChannels: 16, kernelSize: 5, padding: 0)
+        pool1 = MaxPool2d(kernelSize: 2, stride: 2)
+        pool2 = MaxPool2d(kernelSize: 2, stride: 2)
+        fc1 = Linear(16 * 5 * 5, 120)
+        fc2 = Linear(120, 84)
+        fc3 = Linear(84, 10)
     }

     public func callAsFunction(_ x: MLXArray) -> MLXArray {
         var x = x
-        for l in layers.dropLast() {
-            x = relu(l(x))
-        }
-        return layers.last!(x)
+        x = pool1(tanh(conv1(x)))
+        x = pool2(tanh(conv2(x)))
+        x = flattened(x, start: 1)
+        x = tanh(fc1(x))
+        x = tanh(fc2(x))
+        x = fc3(x)
+        return x
     }
 }

-public func loss(model: MLP, x: MLXArray, y: MLXArray) -> MLXArray {
+public func loss(model: LeNet, x: MLXArray, y: MLXArray) -> MLXArray {
     crossEntropy(logits: model(x), targets: y, reduction: .mean)
 }

-public func eval(model: MLP, x: MLXArray, y: MLXArray) -> MLXArray {
+public func eval(model: LeNet, x: MLXArray, y: MLXArray) -> MLXArray {
     mean(argMax(model(x), axis: 1) .== y)
 }
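As a sanity check on `fc1 = Linear(16 * 5 * 5, 120)`, here is how a 28x28 MNIST image flows through the layers above. This is a sketch: the `zeros` factory and `.shape` property are assumed from MLX Swift; everything else comes from the diff:

    import MLX
    import MLXNN

    let model = LeNet()
    let x = zeros([1, 28, 28, 1])   // [batch, height, width, channels]
    // conv1 (5x5, padding 2): [1, 28, 28, 6]
    // pool1 (2x2, stride 2):  [1, 14, 14, 6]
    // conv2 (5x5, no pad):    [1, 10, 10, 16]
    // pool2 (2x2, stride 2):  [1, 5, 5, 16]
    // flattened(start: 1):    [1, 400] = [1, 16 * 5 * 5], matching fc1
    print(model(x).shape)           // [1, 10] class logits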

View File

@@ -1,13 +1,11 @@
 # MNIST

-This is a port of the MNIST model and training code from:
-
-- https://github.com/ml-explore/mlx-examples/blob/main/mnist
+This is a port of the MNIST training code from the [Python MLX example](https://github.com/ml-explore/mlx-examples/blob/main/mnist). This example uses a [LeNet](https://en.wikipedia.org/wiki/LeNet) instead of an MLP.

 It provides code to:

-- download the test/train data
-- provides the MNIST model (MLP)
-- some functions to shuffle and batch the data
+- Download the MNIST test/train data
+- Build the LeNet
+- Shuffle and batch the data

 See [mnist-tool](../../Tools/mnist-tool) for an example of how to run this. The training loop also lives there.
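The shuffle/batch helpers the README mentions are not part of this diff; an illustrative stand-in (not the library's actual API; it assumes `MLXArray` supports indexing by an integer index array) might look like:

    import MLX

    // shuffle row indices once per epoch, then hand out (x, y) mini-batches
    func batches(x: MLXArray, y: MLXArray, batchSize: Int) -> [(MLXArray, MLXArray)] {
        let perm = (0 ..< x.dim(0)).shuffled()
        return stride(from: 0, to: perm.count, by: batchSize).map { start in
            let end = min(start + batchSize, perm.count)
            let ids = MLXArray(perm[start ..< end].map { Int32($0) })
            return (x[ids], y[ids])
        }
    }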

View File

@@ -1,37 +1,20 @@
 # MLX Swift Examples

-Example [mlx-swift](https://github.com/ml-explore/mlx-swift) programs.
+Example [MLX Swift](https://github.com/ml-explore/mlx-swift) programs.

-## MNISTTrainer
-
-An example that runs on both iOS and macOS that downloads MNIST training
-data and trains an MNIST model.
-
-- [README](Applications/MNISTTrainer/README.md)
-
-## LLMEval
-
-An example that runs on both iOS and macOS that downloads a LLM model
-weights and tokenizer configuration from Hugging Face and evaluates
-a prompt in-process.
-
-- [README](Applications/LLMEval/README.md)
-
-## LinearModelTraining
-
-A simple linear model and a training loop.
-
-- [README](Tools/LinearModelTraining/README.md)
-
-## llm-tool
-
-A command line tool for generating text using a variety of Hugging Face models:
-
-- [README](Tools/llm-tool/README.md)
-
-## mnist-tool
-
-A command line tool for training an MNIST (MLP) model:
-
-- [README](Tools/mnist-tool/README.md)
+- [MNISTTrainer](Applications/MNISTTrainer/README.md): An example that runs on
+  both iOS and macOS that downloads MNIST training data and trains a
+  [LeNet](https://en.wikipedia.org/wiki/LeNet).
+
+- [LLMEval](Applications/LLMEval/README.md): An example that runs on both iOS
+  and macOS that downloads an LLM and tokenizer from Hugging Face and
+  generates text from a given prompt.
+
+- [LinearModelTraining](Tools/LinearModelTraining/README.md): An example that
+  trains a simple linear model.
+
+- [llm-tool](Tools/llm-tool/README.md): A command line tool for generating text
+  using a variety of LLMs available on the Hugging Face hub.
+
+- [mnist-tool](Tools/mnist-tool/README.md): A command line tool for training a
+  LeNet on MNIST.

View File

@@ -30,15 +30,11 @@ struct Train: AsyncParsableCommand {
     @Option(name: .long, help: "The PRNG seed")
     var seed: UInt64 = 0

-    @Option var layers = 2
-    @Option var hidden = 32
     @Option var batchSize = 256
     @Option var epochs = 20
     @Option var learningRate: Float = 1e-1
-    @Option var classes = 10

-    @Option var device = DeviceType.cpu
+    @Option var device = DeviceType.gpu

     @Flag var compile = false
@@ -62,9 +58,7 @@ struct Train: AsyncParsableCommand {
         let testLabels = data[.init(.test, .labels)]!

         // create the model
-        let model = MLP(
-            layers: layers, inputDimensions: trainImages.dim(-1), hiddenDimensions: hidden,
-            outputDimensions: classes)
+        let model = LeNet()
         eval(model.parameters())

         let lg = valueAndGrad(model: model, loss)
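`valueAndGrad` pairs with an optimizer update in the training loop that follows (not shown in this hunk). A sketch of one step, assuming `SGD` from MLXOptimizers plus the `model` and `loss` definitions above:

    import MLX
    import MLXNN
    import MLXOptimizers

    let optimizer = SGD(learningRate: 1e-1)   // matches the learningRate option's default
    let lg = valueAndGrad(model: model, loss)

    func step(x: MLXArray, y: MLXArray) -> MLXArray {
        let (lossValue, grads) = lg(model, x, y)          // forward + backward
        optimizer.update(model: model, gradients: grads)  // SGD step, in place
        eval(model.parameters())                          // force the lazy graph
        return lossValue
    }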

View File

@@ -1,8 +1,6 @@
 # mnist-tool

-See other README:
-
-- [MNIST](../../Libraries/MNIST/README.md)
+See the [MNIST README.md](../../Libraries/MNIST/README.md).

 ### Building