diff --git a/Applications/MNISTTrainer/ContentView.swift b/Applications/MNISTTrainer/ContentView.swift index b9fd4ee..4747386 100644 --- a/Applications/MNISTTrainer/ContentView.swift +++ b/Applications/MNISTTrainer/ContentView.swift @@ -12,9 +12,6 @@ struct ContentView: View { // the training loop @State var trainer = Trainer() - // toggle for cpu/gpu training - @State var cpu = true - var body: some View { VStack { Spacer() @@ -30,13 +27,10 @@ struct ContentView: View { Button("Train") { Task { - try! await trainer.run(device: cpu ? .cpu : .gpu) + try! await trainer.run() } } - Toggle("CPU", isOn: $cpu) - .frame(maxWidth: 150) - Spacer() } Spacer() @@ -50,12 +44,10 @@ class Trainer { var messages = [String]() - func run(device: Device = .cpu) async throws { + func run() async throws { // Note: this is pretty close to the code in `mnist-tool`, just // wrapped in an Observable to make it easy to display in SwiftUI - Device.setDefault(device: device) - // download & load the training data let url = URL(fileURLWithPath: NSTemporaryDirectory(), isDirectory: true) try await download(into: url) @@ -67,9 +59,7 @@ class Trainer { let testLabels = data[.init(.test, .labels)]! // create the model with random weights - let model = MLP( - layers: 2, inputDimensions: trainImages.dim(-1), hiddenDimensions: 32, - outputDimensions: 10) + let model = LeNet() eval(model.parameters()) // the training loop diff --git a/Applications/MNISTTrainer/README.md b/Applications/MNISTTrainer/README.md index 04e2b83..bd2adb8 100644 --- a/Applications/MNISTTrainer/README.md +++ b/Applications/MNISTTrainer/README.md @@ -1,13 +1,13 @@ # MNISTTrainer -This is an example showing how to do model training on both macOS and iOS. -This will download the MNIST training data, create a new models and train -it. It will show the timing per epoch and the test accuracy as it trains. +This is an example of model training that works on both macOS and iOS. +The example will download the MNIST training data, create a LeNet, and train +it. It will show the epoch time and test accuracy as it trains. You will need to set the Team on the MNISTTrainer target in order to build and run on iOS. Some notes about the setup: -- this will download test data over the network so MNISTTrainer -> Signing & Capabilities has the "Outgoing Connections (Client)" set in the App Sandbox -- the website it connects to uses http rather than https so it has a "App Transport Security Settings" in the Info.plist +- This will download test data over the network so MNISTTrainer -> Signing & Capabilities has the "Outgoing Connections (Client)" set in the App Sandbox +- The website it connects to uses http rather than https so it has a "App Transport Security Settings" in the Info.plist diff --git a/Libraries/MNIST/Files.swift b/Libraries/MNIST/Files.swift index 84957c2..9164f69 100644 --- a/Libraries/MNIST/Files.swift +++ b/Libraries/MNIST/Files.swift @@ -43,13 +43,13 @@ let files = [ name: "train-images-idx3-ubyte.gz", offset: 16, convert: { - $0.reshaped([-1, 28 * 28]).asType(.float32) / 255.0 + $0.reshaped([-1, 28, 28, 1]).asType(.float32) / 255.0 }), FileKind(.test, .images): LoadInfo( name: "t10k-images-idx3-ubyte.gz", offset: 16, convert: { - $0.reshaped([-1, 28 * 28]).asType(.float32) / 255.0 + $0.reshaped([-1, 28, 28, 1]).asType(.float32) / 255.0 }), FileKind(.training, .labels): LoadInfo( name: "train-labels-idx1-ubyte.gz", diff --git a/Libraries/MNIST/MNIST.swift b/Libraries/MNIST/MNIST.swift index 78d912b..a4db8e7 100644 --- a/Libraries/MNIST/MNIST.swift +++ b/Libraries/MNIST/MNIST.swift @@ -6,36 +6,43 @@ import MLXNN // based on https://github.com/ml-explore/mlx-examples/blob/main/mnist/main.py -public class MLP: Module, UnaryLayer { +public class LeNet: Module, UnaryLayer { - @ModuleInfo var layers: [Linear] + @ModuleInfo var conv1: Conv2d + @ModuleInfo var conv2: Conv2d + @ModuleInfo var pool1: MaxPool2d + @ModuleInfo var pool2: MaxPool2d + @ModuleInfo var fc1: Linear + @ModuleInfo var fc2: Linear + @ModuleInfo var fc3: Linear - public init(layers: Int, inputDimensions: Int, hiddenDimensions: Int, outputDimensions: Int) { - let layerSizes = - [inputDimensions] + Array(repeating: hiddenDimensions, count: layers) + [ - outputDimensions - ] - - self.layers = zip(layerSizes.dropLast(), layerSizes.dropFirst()) - .map { - Linear($0, $1) - } + override public init() { + conv1 = Conv2d(inputChannels: 1, outputChannels: 6, kernelSize: 5, padding: 2) + conv2 = Conv2d(inputChannels: 6, outputChannels: 16, kernelSize: 5, padding: 0) + pool1 = MaxPool2d(kernelSize: 2, stride: 2) + pool2 = MaxPool2d(kernelSize: 2, stride: 2) + fc1 = Linear(16 * 5 * 5, 120) + fc2 = Linear(120, 84) + fc3 = Linear(84, 10) } public func callAsFunction(_ x: MLXArray) -> MLXArray { var x = x - for l in layers.dropLast() { - x = relu(l(x)) - } - return layers.last!(x) + x = pool1(tanh(conv1(x))) + x = pool2(tanh(conv2(x))) + x = flattened(x, start: 1) + x = tanh(fc1(x)) + x = tanh(fc2(x)) + x = fc3(x) + return x } } -public func loss(model: MLP, x: MLXArray, y: MLXArray) -> MLXArray { +public func loss(model: LeNet, x: MLXArray, y: MLXArray) -> MLXArray { crossEntropy(logits: model(x), targets: y, reduction: .mean) } -public func eval(model: MLP, x: MLXArray, y: MLXArray) -> MLXArray { +public func eval(model: LeNet, x: MLXArray, y: MLXArray) -> MLXArray { mean(argMax(model(x), axis: 1) .== y) } diff --git a/Libraries/MNIST/README.md b/Libraries/MNIST/README.md index 5d7918d..a94b1ce 100644 --- a/Libraries/MNIST/README.md +++ b/Libraries/MNIST/README.md @@ -1,13 +1,11 @@ # MNIST -This is a port of the MNIST model and training code from: - -- https://github.com/ml-explore/mlx-examples/blob/main/mnist +This is a port of the MNIST training code from the [Python MLX example](https://github.com/ml-explore/mlx-examples/blob/main/mnist). This example uses a [LeNet](https://en.wikipedia.org/wiki/LeNet) instead of an MLP. It provides code to: -- download the test/train data -- provides the MNIST model (MLP) -- some functions to shuffle and batch the data +- Download the MNIST test/train data +- Build the LeNet +- Some functions to shuffle and batch the data -See [mnist-tool](../../Tools/mnist-tool) for an example of how to run this. The training loop also lives there. +See [mnist-tool](../../Tools/mnist-tool) for an example of how to run this. The training loop also lives there. diff --git a/README.md b/README.md index a6c0435..95914f6 100644 --- a/README.md +++ b/README.md @@ -1,37 +1,20 @@ # MLX Swift Examples -Example [mlx-swift](https://github.com/ml-explore/mlx-swift) programs. +Example [MLX Swift](https://github.com/ml-explore/mlx-swift) programs. -## MNISTTrainer +- [MNISTTrainer](Applications/MNISTTrainer/README.md): An example that runs on + both iOS and macOS that downloads MNIST training data and trains a + [LeNet](https://en.wikipedia.org/wiki/LeNet). -An example that runs on both iOS and macOS that downloads MNIST training -data and trains an MNIST model. +- [LLMEval](Applications/LLMEval/README.md): An example that runs on both iOS + and macOS that downloads an LLM and tokenizer from Hugging Face and and + generates text from a given prompt. -- [README](Applications/MNISTTrainer/README.md) +- [LinearModelTraining](Tools/LinearModelTraining/README.md): An example that + trains a simple linear model. -## LLMEval - -An example that runs on both iOS and macOS that downloads a LLM model -weights and tokenizer configuration from Hugging Face and evaluates -a prompt in-process. - -- [README](Applications/LLMEval/README.md) - -## LinearModelTraining - -A simple linear model and a training loop. - -- [README](Tools/LinearModelTraining/README.md) - -## llm-tool - -A command line tool for generating text using a variety of Hugging Face models: - -- [README](Tools/llm-tool/README.md) - -## mnist-tool - -A command line tool for training an MNIST (MLP) model: - -- [README](Tools/mnist-tool/README.md) +- [llm-tool](Tools/llm-tool/README.md): A command line tool for generating text + using a variety of LLMs available on the Hugging Face hub. +- [mnist-tool](Tools/mnist-tool/README.md): A command line tool for training a + a LeNet on MNIST. diff --git a/Tools/mnist-tool/MNISTTool.swift b/Tools/mnist-tool/MNISTTool.swift index 68fd4b3..1199784 100644 --- a/Tools/mnist-tool/MNISTTool.swift +++ b/Tools/mnist-tool/MNISTTool.swift @@ -30,15 +30,11 @@ struct Train: AsyncParsableCommand { @Option(name: .long, help: "The PRNG seed") var seed: UInt64 = 0 - @Option var layers = 2 - @Option var hidden = 32 @Option var batchSize = 256 @Option var epochs = 20 @Option var learningRate: Float = 1e-1 - @Option var classes = 10 - - @Option var device = DeviceType.cpu + @Option var device = DeviceType.gpu @Flag var compile = false @@ -62,9 +58,7 @@ struct Train: AsyncParsableCommand { let testLabels = data[.init(.test, .labels)]! // create the model - let model = MLP( - layers: layers, inputDimensions: trainImages.dim(-1), hiddenDimensions: hidden, - outputDimensions: classes) + let model = LeNet() eval(model.parameters()) let lg = valueAndGrad(model: model, loss) diff --git a/Tools/mnist-tool/README.md b/Tools/mnist-tool/README.md index 3bd745a..a815fe6 100644 --- a/Tools/mnist-tool/README.md +++ b/Tools/mnist-tool/README.md @@ -1,8 +1,6 @@ # mnist-tool -See other README: - -- [MNIST](../../Libraries/MNIST/README.md) +See the [MNIST README.md](../../Libraries/MNIST/README.md). ### Building