LeNet on MNIST + readme update (#12)

* LeNet on MNIST + readme update

* tanh + remove device toggle

* remove device entirely
This commit is contained in:
Awni Hannun
2024-03-04 14:16:20 -08:00
committed by GitHub
parent dfc9f2fc01
commit 4ed4ec69e7
8 changed files with 56 additions and 86 deletions

View File

@@ -6,36 +6,43 @@ import MLXNN
// based on https://github.com/ml-explore/mlx-examples/blob/main/mnist/main.py
public class MLP: Module, UnaryLayer {
public class LeNet: Module, UnaryLayer {
@ModuleInfo var layers: [Linear]
@ModuleInfo var conv1: Conv2d
@ModuleInfo var conv2: Conv2d
@ModuleInfo var pool1: MaxPool2d
@ModuleInfo var pool2: MaxPool2d
@ModuleInfo var fc1: Linear
@ModuleInfo var fc2: Linear
@ModuleInfo var fc3: Linear
public init(layers: Int, inputDimensions: Int, hiddenDimensions: Int, outputDimensions: Int) {
let layerSizes =
[inputDimensions] + Array(repeating: hiddenDimensions, count: layers) + [
outputDimensions
]
self.layers = zip(layerSizes.dropLast(), layerSizes.dropFirst())
.map {
Linear($0, $1)
}
/// Builds the network with fixed hyperparameters: two 5x5 convolutions
/// (1 -> 6 and 6 -> 16 channels), two 2x2 max-pools with stride 2, and three
/// linear layers (16 * 5 * 5 -> 120 -> 84 -> 10).
override public init() {
// padding 2 with a 5x5 kernel presumably keeps the spatial size unchanged
// (assumes the convolution's default stride is 1 — TODO confirm).
conv1 = Conv2d(inputChannels: 1, outputChannels: 6, kernelSize: 5, padding: 2)
conv2 = Conv2d(inputChannels: 6, outputChannels: 16, kernelSize: 5, padding: 0)
pool1 = MaxPool2d(kernelSize: 2, stride: 2)
pool2 = MaxPool2d(kernelSize: 2, stride: 2)
// 16 * 5 * 5: the 16 output channels of conv2 times a 5x5 feature map.
// NOTE(review): the 5x5 map presumably assumes 28x28 inputs
// (28 -> 14 -> 10 -> 5) — confirm against the data loader.
fc1 = Linear(16 * 5 * 5, 120)
fc2 = Linear(120, 84)
fc3 = Linear(84, 10)
}
/// Applies the network to `x` and returns the output of the last layer.
///
/// NOTE(review): this span comes from a rendered diff whose +/- markers were
/// stripped. The `for` loop and `return layers.last!(x)` appear to be the
/// removed MLP forward pass; the statements after them appear to be the added
/// LeNet forward pass (conv -> tanh -> pool, twice; flatten; then three linear
/// layers with tanh between them). As written, everything after the first
/// `return` is unreachable — confirm against the real file before relying on
/// either path.
public func callAsFunction(_ x: MLXArray) -> MLXArray {
var x = x
for l in layers.dropLast() {
x = relu(l(x))
}
return layers.last!(x)
x = pool1(tanh(conv1(x)))
x = pool2(tanh(conv2(x)))
// Collapse every axis from index 1 onward into one (axis 0 is presumably
// the batch axis — confirm) so the linear layers receive 2-D input.
x = flattened(x, start: 1)
x = tanh(fc1(x))
x = tanh(fc2(x))
// No activation on the last layer: its raw output is returned.
x = fc3(x)
return x
}
}
public func loss(model: MLP, x: MLXArray, y: MLXArray) -> MLXArray {
/// Computes the mean cross-entropy between the model's output for `x` and the
/// targets `y`.
///
/// - Parameters:
///   - model: The network whose output is used as the logits.
///   - x: Input passed to `model`.
///   - y: Targets handed to `crossEntropy`.
/// - Returns: The cross-entropy reduced with `.mean`.
public func loss(model: LeNet, x: MLXArray, y: MLXArray) -> MLXArray {
    let logits = model(x)
    return crossEntropy(logits: logits, targets: y, reduction: .mean)
}
public func eval(model: MLP, x: MLXArray, y: MLXArray) -> MLXArray {
/// Computes the mean of the element-wise match between the model's arg-max
/// prediction (taken along axis 1) and the targets `y`.
///
/// - Parameters:
///   - model: The network to evaluate.
///   - x: Input passed to `model`.
///   - y: Expected values compared against the arg-max indices.
/// - Returns: The mean of the element-wise equality mask (presumably the
///   classification accuracy — confirm the layout of `y` with the caller).
public func eval(model: LeNet, x: MLXArray, y: MLXArray) -> MLXArray {
    let predicted = argMax(model(x), axis: 1)
    let matches = predicted .== y
    return mean(matches)
}