LeNet on MNIST + readme update (#12)

* LeNet on MNIST + readme update

* tanh + remove device toggle

* remove device entirely
This commit is contained in:
Awni Hannun
2024-03-04 14:16:20 -08:00
committed by GitHub
parent dfc9f2fc01
commit 4ed4ec69e7
8 changed files with 56 additions and 86 deletions

View File

@@ -30,15 +30,11 @@ struct Train: AsyncParsableCommand {
@Option(name: .long, help: "The PRNG seed")
var seed: UInt64 = 0
@Option var layers = 2
@Option var hidden = 32
@Option var batchSize = 256
@Option var epochs = 20
@Option var learningRate: Float = 1e-1
@Option var classes = 10
@Option var device = DeviceType.cpu
@Option var device = DeviceType.gpu
@Flag var compile = false
@@ -62,9 +58,7 @@ struct Train: AsyncParsableCommand {
let testLabels = data[.init(.test, .labels)]!
// create the model
let model = MLP(
layers: layers, inputDimensions: trainImages.dim(-1), hiddenDimensions: hidden,
outputDimensions: classes)
let model = LeNet()
eval(model.parameters())
let lg = valueAndGrad(model: model, loss)

View File

@@ -1,8 +1,6 @@
# mnist-tool
See other README:
- [MNIST](../../Libraries/MNIST/README.md)
See the [MNIST README.md](../../Libraries/MNIST/README.md).
### Building