LeNet on MNIST + readme update (#12)
* LeNet on MNIST + readme update * tanh + remove device toggle * remove device entirely
This commit is contained in:
@@ -43,13 +43,13 @@ let files = [
|
||||
name: "train-images-idx3-ubyte.gz",
|
||||
offset: 16,
|
||||
convert: {
|
||||
$0.reshaped([-1, 28 * 28]).asType(.float32) / 255.0
|
||||
$0.reshaped([-1, 28, 28, 1]).asType(.float32) / 255.0
|
||||
}),
|
||||
FileKind(.test, .images): LoadInfo(
|
||||
name: "t10k-images-idx3-ubyte.gz",
|
||||
offset: 16,
|
||||
convert: {
|
||||
$0.reshaped([-1, 28 * 28]).asType(.float32) / 255.0
|
||||
$0.reshaped([-1, 28, 28, 1]).asType(.float32) / 255.0
|
||||
}),
|
||||
FileKind(.training, .labels): LoadInfo(
|
||||
name: "train-labels-idx1-ubyte.gz",
|
||||
|
||||
@@ -6,36 +6,43 @@ import MLXNN
|
||||
|
||||
// based on https://github.com/ml-explore/mlx-examples/blob/main/mnist/main.py
|
||||
|
||||
public class MLP: Module, UnaryLayer {
|
||||
public class LeNet: Module, UnaryLayer {
|
||||
|
||||
@ModuleInfo var layers: [Linear]
|
||||
@ModuleInfo var conv1: Conv2d
|
||||
@ModuleInfo var conv2: Conv2d
|
||||
@ModuleInfo var pool1: MaxPool2d
|
||||
@ModuleInfo var pool2: MaxPool2d
|
||||
@ModuleInfo var fc1: Linear
|
||||
@ModuleInfo var fc2: Linear
|
||||
@ModuleInfo var fc3: Linear
|
||||
|
||||
public init(layers: Int, inputDimensions: Int, hiddenDimensions: Int, outputDimensions: Int) {
|
||||
let layerSizes =
|
||||
[inputDimensions] + Array(repeating: hiddenDimensions, count: layers) + [
|
||||
outputDimensions
|
||||
]
|
||||
|
||||
self.layers = zip(layerSizes.dropLast(), layerSizes.dropFirst())
|
||||
.map {
|
||||
Linear($0, $1)
|
||||
}
|
||||
override public init() {
|
||||
conv1 = Conv2d(inputChannels: 1, outputChannels: 6, kernelSize: 5, padding: 2)
|
||||
conv2 = Conv2d(inputChannels: 6, outputChannels: 16, kernelSize: 5, padding: 0)
|
||||
pool1 = MaxPool2d(kernelSize: 2, stride: 2)
|
||||
pool2 = MaxPool2d(kernelSize: 2, stride: 2)
|
||||
fc1 = Linear(16 * 5 * 5, 120)
|
||||
fc2 = Linear(120, 84)
|
||||
fc3 = Linear(84, 10)
|
||||
}
|
||||
|
||||
public func callAsFunction(_ x: MLXArray) -> MLXArray {
|
||||
var x = x
|
||||
for l in layers.dropLast() {
|
||||
x = relu(l(x))
|
||||
}
|
||||
return layers.last!(x)
|
||||
x = pool1(tanh(conv1(x)))
|
||||
x = pool2(tanh(conv2(x)))
|
||||
x = flattened(x, start: 1)
|
||||
x = tanh(fc1(x))
|
||||
x = tanh(fc2(x))
|
||||
x = fc3(x)
|
||||
return x
|
||||
}
|
||||
}
|
||||
|
||||
public func loss(model: MLP, x: MLXArray, y: MLXArray) -> MLXArray {
|
||||
public func loss(model: LeNet, x: MLXArray, y: MLXArray) -> MLXArray {
|
||||
crossEntropy(logits: model(x), targets: y, reduction: .mean)
|
||||
}
|
||||
|
||||
public func eval(model: MLP, x: MLXArray, y: MLXArray) -> MLXArray {
|
||||
public func eval(model: LeNet, x: MLXArray, y: MLXArray) -> MLXArray {
|
||||
mean(argMax(model(x), axis: 1) .== y)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
# MNIST
|
||||
|
||||
This is a port of the MNIST model and training code from:
|
||||
|
||||
- https://github.com/ml-explore/mlx-examples/blob/main/mnist
|
||||
This is a port of the MNIST training code from the [Python MLX example](https://github.com/ml-explore/mlx-examples/blob/main/mnist). This example uses a [LeNet](https://en.wikipedia.org/wiki/LeNet) instead of an MLP.
|
||||
|
||||
It provides code to:
|
||||
|
||||
- download the test/train data
|
||||
- provides the MNIST model (MLP)
|
||||
- some functions to shuffle and batch the data
|
||||
- Download the MNIST test/train data
|
||||
- Build the LeNet
|
||||
- Some functions to shuffle and batch the data
|
||||
|
||||
See [mnist-tool](../../Tools/mnist-tool) for an example of how to run this. The training loop also lives there.
|
||||
See [mnist-tool](../../Tools/mnist-tool) for an example of how to run this. The training loop also lives there.
|
||||
|
||||
Reference in New Issue
Block a user