initial commit

This commit is contained in:
David Koski
2024-02-22 10:41:02 -08:00
commit b6d1e14465
29 changed files with 3856 additions and 0 deletions

102
Libraries/MNIST/Files.swift Normal file
View File

@@ -0,0 +1,102 @@
// Copyright © 2024 Apple Inc.
import Foundation
import Gzip
import MLX
// based on https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py
/// The portion of the dataset a file belongs to.
///
/// The raw value is the case name and is used when rendering a `FileKind`.
public enum Use: String, Hashable {
    case test, training
}
/// The kind of payload stored in a dataset file.
///
/// The raw value is the case name and is used when rendering a `FileKind`.
public enum DataKind: String, Hashable {
    case images, labels
}
/// Identifies one dataset file by the split it belongs to and the payload it holds.
///
/// Used as the key of the dictionaries that map files to their load
/// instructions and to their loaded arrays.
public struct FileKind: Hashable {
    let use: Use
    let data: DataKind

    /// Creates a key for the given split and payload kind.
    public init(_ use: Use, _ data: DataKind) {
        self.use = use
        self.data = data
    }
}

extension FileKind: CustomStringConvertible {
    /// The two raw values joined by a dash, e.g. `training-images`.
    public var description: String {
        [use.rawValue, data.rawValue].joined(separator: "-")
    }
}
/// Describes how one dataset file is fetched and decoded.
struct LoadInfo {
    /// File name; appended to both the remote base URL and the local directory.
    let name: String
    /// Number of leading bytes of the decompressed file to skip before the payload.
    let offset: Int
    /// Transforms the flat `UInt8` array read from the file into its final form.
    let convert: (MLXArray) -> MLXArray
}
/// Remote directory the dataset archives are downloaded from.
// NOTE(review): this is a plain-HTTP endpoint; confirm it is still reachable and
// acceptable for the platforms this runs on.
let baseURL = URL(string: "http://yann.lecun.com/exdb/mnist/")!

/// Load instructions for every dataset file, keyed by what the file contains.
let files: [FileKind: LoadInfo] = {
    // Bytes skipped at the start of each decompressed file, per payload kind.
    let imageOffset = 16
    let labelOffset = 8

    // Images: one flattened 28x28 row per sample, scaled by 1/255 as float32.
    let toImages: (MLXArray) -> MLXArray = { raw in
        raw.reshaped([-1, 28 * 28]).asType(.float32) / 255.0
    }
    // Labels: widened to uint32.
    let toLabels: (MLXArray) -> MLXArray = { raw in
        raw.asType(.uint32)
    }

    return [
        FileKind(.training, .images): LoadInfo(
            name: "train-images-idx3-ubyte.gz", offset: imageOffset, convert: toImages),
        FileKind(.test, .images): LoadInfo(
            name: "t10k-images-idx3-ubyte.gz", offset: imageOffset, convert: toImages),
        FileKind(.training, .labels): LoadInfo(
            name: "train-labels-idx1-ubyte.gz", offset: labelOffset, convert: toLabels),
        FileKind(.test, .labels): LoadInfo(
            name: "t10k-labels-idx1-ubyte.gz", offset: labelOffset, convert: toLabels),
    ]
}()
/// Downloads every dataset archive that is not already present in `into`.
///
/// Files that already exist in the directory are left untouched, so the call
/// is cheap once everything has been fetched.
///
/// - Parameter into: Directory the archives are written to.
/// - Throws: Any error from `URLSession` or from writing the file, or a
///   `URLError(.badServerResponse)` when the server does not answer with
///   HTTP status 200.
public func download(into: URL) async throws {
    for info in files.values {
        let fileURL = into.appending(component: info.name)

        // `path()` percent-encodes by default, which would make the existence
        // check fail for directories containing spaces or other escaped
        // characters and cause a re-download on every call.
        let localPath = fileURL.path(percentEncoded: false)
        guard !FileManager.default.fileExists(atPath: localPath) else { continue }

        print("Download: \(info.name)")
        let url = baseURL.appending(component: info.name)
        let (data, response) = try await URLSession.shared.data(from: url)

        // Report failures to the caller instead of terminating the process.
        guard let httpResponse = response as? HTTPURLResponse else {
            throw URLError(
                .badServerResponse,
                userInfo: [
                    NSURLErrorFailingURLErrorKey: url,
                    NSLocalizedDescriptionKey:
                        "Unable to download \(url), not an http response: \(response)",
                ])
        }
        guard httpResponse.statusCode == 200 else {
            throw URLError(
                .badServerResponse,
                userInfo: [
                    NSURLErrorFailingURLErrorKey: url,
                    NSLocalizedDescriptionKey: "Unable to download \(url): \(httpResponse)",
                ])
        }

        // Write atomically so an interrupted write cannot leave a truncated
        // file that later passes the existence check.
        try data.write(to: fileURL, options: .atomic)
    }
}
/// Reads every dataset archive from `from` and decodes it into an array.
///
/// Each file is gunzipped, its leading `offset` bytes are skipped, the rest is
/// wrapped as a flat `UInt8` array, and the file's `convert` closure is applied.
///
/// - Parameter from: Directory that contains the downloaded archives.
/// - Returns: The converted arrays keyed by the kind of file they came from.
/// - Throws: Any error from reading or decompressing a file.
public func load(from: URL) throws -> [FileKind: MLXArray] {
    try files.reduce(into: [FileKind: MLXArray]()) { arrays, entry in
        let (kind, info) = entry
        let archiveURL = from.appending(component: info.name)
        let bytes = try Data(contentsOf: archiveURL).gunzipped()

        let payloadLength = bytes.count - info.offset
        let raw = MLXArray(bytes.dropFirst(info.offset), [payloadLength], type: UInt8.self)
        arrays[kind] = info.convert(raw)
    }
}

1
Libraries/MNIST/MNIST.h Normal file
View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,73 @@
// Copyright © 2024 Apple Inc.
import Foundation
import MLX
import MLXNN
// based on https://github.com/ml-explore/mlx-examples/blob/main/mnist/main.py
/// A multi-layer perceptron: a chain of `Linear` layers with `relu` applied
/// after every layer except the last.
public class MLP: Module, UnaryLayer {

    @ModuleInfo var layers: [Linear]

    /// Builds the chain of linear layers.
    ///
    /// The layer widths are `inputDimensions`, then `hiddenDimensions` repeated
    /// `layers` times, then `outputDimensions`; one `Linear` connects each
    /// adjacent pair, giving `layers + 1` linear layers in total.
    ///
    /// - Parameters:
    ///   - layers: How many times the hidden width is repeated.
    ///   - inputDimensions: Width of the input.
    ///   - hiddenDimensions: Width of each hidden layer.
    ///   - outputDimensions: Width of the output.
    public init(layers: Int, inputDimensions: Int, hiddenDimensions: Int, outputDimensions: Int) {
        var sizes = [inputDimensions]
        sizes += Array(repeating: hiddenDimensions, count: layers)
        sizes.append(outputDimensions)

        // Pair each width with its successor: (in, hidden), (hidden, hidden), ..., (hidden, out).
        self.layers = zip(sizes, sizes.dropFirst()).map { fanIn, fanOut in
            Linear(fanIn, fanOut)
        }
    }

    /// Runs `x` through all layers; the final layer's output is returned without `relu`.
    public func callAsFunction(_ x: MLXArray) -> MLXArray {
        let hidden = layers.dropLast().reduce(x) { activations, layer in
            relu(layer(activations))
        }
        // `layers` always holds at least one element, see `init`.
        return layers.last!(hidden)
    }
}
/// Mean cross-entropy between the model's output for `x` and the targets `y`.
///
/// - Parameters:
///   - model: The network to evaluate.
///   - x: Model inputs.
///   - y: Targets passed to `crossEntropy`.
/// - Returns: The loss reduced with `.mean`.
public func loss(model: MLP, x: MLXArray, y: MLXArray) -> MLXArray {
    let logits = model(x)
    return crossEntropy(logits: logits, targets: y, reduction: .mean)
}
/// Fraction of inputs whose arg-max prediction along axis 1 equals the entry in `y`.
///
/// - Parameters:
///   - model: The network to evaluate.
///   - x: Model inputs.
///   - y: Expected values compared element-wise against the predictions.
/// - Returns: The mean of the element-wise equality between predictions and `y`.
public func eval(model: MLP, x: MLXArray, y: MLXArray) -> MLXArray {
    let predictions = argMax(model(x), axis: 1)
    let matches = predictions .== y
    return mean(matches)
}
/// Single-pass sequence that yields `(x, y)` batches in a shuffled order.
///
/// A permutation of `0 ..< y.size` is drawn once at construction; each call to
/// `next()` takes the next `batchSize` entries of it and gathers the matching
/// elements of `x` and `y`. The last batch may be shorter.
private struct BatchSequence: Sequence, IteratorProtocol {

    let batchSize: Int
    let x: MLXArray
    let y: MLXArray

    /// Shuffled sample indices, fixed for the lifetime of the sequence.
    let permutation: MLXArray

    /// Position in `permutation` where the next batch starts.
    var cursor = 0

    init(batchSize: Int, x: MLXArray, y: MLXArray, using generator: inout any RandomNumberGenerator)
    {
        self.batchSize = batchSize
        self.x = x
        self.y = y

        let order = Array(0 ..< y.size).shuffled(using: &generator)
        self.permutation = MLXArray(order)
    }

    mutating func next() -> (MLXArray, MLXArray)? {
        let total = y.size
        if cursor >= total {
            return nil
        }

        // Clamp the upper bound so the last batch stops at the end of the data.
        let upper = Swift.min(cursor + batchSize, total)
        let ids = permutation[cursor ..< upper]
        cursor += batchSize

        return (x[ids], y[ids])
    }
}
/// Returns a sequence of shuffled `(x, y)` batches.
///
/// - Parameters:
///   - batchSize: Maximum number of samples per batch.
///   - x: Inputs, indexed along with `y`.
///   - y: Targets; `y.size` determines how many samples are iterated.
///   - generator: Random number generator used to shuffle the sample order.
/// - Returns: A sequence of `(inputs, targets)` pairs.
public func iterateBatches(
    batchSize: Int, x: MLXArray, y: MLXArray, using generator: inout any RandomNumberGenerator
) -> some Sequence<(MLXArray, MLXArray)> {
    let batches = BatchSequence(batchSize: batchSize, x: x, y: y, using: &generator)
    return batches
}

13
Libraries/MNIST/README.md Normal file
View File

@@ -0,0 +1,13 @@
# MNIST
This is a port of the MNIST model and training code from:
- https://github.com/ml-explore/mlx-examples/blob/main/mnist
It provides:
- code to download the test/train data
- the MNIST model (MLP)
- functions to shuffle and batch the data
See [mnist-tool](../../Tools/mnist-tool) for an example of how to run this. The training loop also lives there.

View File

@@ -0,0 +1,30 @@
// Copyright © 2024 Apple Inc.
import Foundation
// From https://github.com/apple/swift/blob/cb0fb1ea051631219c0b944b84c78571448d58c2/benchmark/utils/TestsUtils.swift#L254
//
// This is just a seedable RandomNumberGenerator for shuffle()
// This is a fixed-increment version of Java 8's SplittableRandom generator.
// It is a very fast generator passing BigCrush, with 64 bits of state.
// See http://dx.doi.org/10.1145/2714064.2660195 and
// http://docs.oracle.com/javase/8/docs/api/java/util/SplittableRandom.html
//
// Derived from public domain C implementation by Sebastiano Vigna
// See http://xoshiro.di.unimi.it/splitmix64.c
/// A seedable random number generator with 64 bits of state.
///
/// Two generators created with the same seed produce the same sequence.
public struct SplitMix64: RandomNumberGenerator {

    private var state: UInt64

    /// Creates a generator whose sequence is fully determined by `seed`.
    public init(seed: UInt64) {
        state = seed
    }

    /// Advances the state by a fixed increment and returns a mixed copy of it.
    public mutating func next() -> UInt64 {
        state &+= 0x9e37_79b9_7f4a_7c15

        // Two xor-shift/multiply rounds followed by a final xor-shift.
        var mixed = state
        mixed ^= mixed &>> 30
        mixed &*= 0xbf58_476d_1ce4_e5b9
        mixed ^= mixed &>> 27
        mixed &*= 0x94d0_49bb_1331_11eb
        mixed ^= mixed &>> 31
        return mixed
    }
}