LeNet on MNIST + readme update (#12)
* LeNet on MNIST + readme update * tanh + remove device toggle * remove device entirely
This commit is contained in:
@@ -30,15 +30,11 @@ struct Train: AsyncParsableCommand {
|
||||
@Option(name: .long, help: "The PRNG seed")
|
||||
var seed: UInt64 = 0
|
||||
|
||||
@Option var layers = 2
|
||||
@Option var hidden = 32
|
||||
@Option var batchSize = 256
|
||||
@Option var epochs = 20
|
||||
@Option var learningRate: Float = 1e-1
|
||||
|
||||
@Option var classes = 10
|
||||
|
||||
@Option var device = DeviceType.cpu
|
||||
@Option var device = DeviceType.gpu
|
||||
|
||||
@Flag var compile = false
|
||||
|
||||
@@ -62,9 +58,7 @@ struct Train: AsyncParsableCommand {
|
||||
let testLabels = data[.init(.test, .labels)]!
|
||||
|
||||
// create the model
|
||||
let model = MLP(
|
||||
layers: layers, inputDimensions: trainImages.dim(-1), hiddenDimensions: hidden,
|
||||
outputDimensions: classes)
|
||||
let model = LeNet()
|
||||
eval(model.parameters())
|
||||
|
||||
let lg = valueAndGrad(model: model, loss)
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
# mnist-tool
|
||||
|
||||
See other README:
|
||||
|
||||
- [MNIST](../../Libraries/MNIST/README.md)
|
||||
See the [MNIST README.md](../../Libraries/MNIST/README.md).
|
||||
|
||||
### Building
|
||||
|
||||
|
||||
Reference in New Issue
Block a user