initial commit
This commit is contained in:
102
Libraries/MNIST/Files.swift
Normal file
102
Libraries/MNIST/Files.swift
Normal file
@@ -0,0 +1,102 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
import Foundation
|
||||
import Gzip
|
||||
import MLX
|
||||
|
||||
// based on https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py
|
||||
|
||||
/// The dataset split a file belongs to.
///
/// The raw value of each case is its own name (`"test"`, `"training"`); it is
/// used when rendering `FileKind.description`.
public enum Use: String, Hashable {
    /// The held-out split.
    case test = "test"
    /// The split used for fitting the model.
    case training = "training"
}
|
||||
|
||||
/// The kind of content stored in a file.
///
/// The raw value of each case is its own name (`"images"`, `"labels"`); it is
/// used when rendering `FileKind.description`.
public enum DataKind: String, Hashable {
    /// Image data.
    case images = "images"
    /// Label data.
    case labels = "labels"
}
|
||||
|
||||
/// Dictionary key identifying one data file by its split (`Use`) and its
/// content (`DataKind`), e.g. `FileKind(.training, .images)`.
public struct FileKind: Hashable, CustomStringConvertible {
    /// Which split the file belongs to.
    let use: Use
    /// What the file contains.
    let data: DataKind

    /// Creates a key from a split and a content kind.
    ///
    /// - Parameters:
    ///   - use: The dataset split.
    ///   - data: The kind of content.
    public init(_ use: Use, _ data: DataKind) {
        self.data = data
        self.use = use
    }

    /// The two raw values joined by a hyphen, e.g. `"training-images"`.
    public var description: String {
        [use.rawValue, data.rawValue].joined(separator: "-")
    }
}
|
||||
|
||||
/// Describes how one data file is located and turned into an `MLXArray`.
struct LoadInfo {
    /// File name, appended both to the download base URL and to the local directory.
    let name: String
    /// Number of leading bytes dropped from the decompressed data before it is
    /// wrapped in an array.
    let offset: Int
    /// Transformation applied to the raw `UInt8` array to produce the final array.
    let convert: (MLXArray) -> MLXArray

    /// Equivalent to the synthesized memberwise initializer.
    init(name: String, offset: Int, convert: @escaping (MLXArray) -> MLXArray) {
        self.name = name
        self.offset = offset
        self.convert = convert
    }
}
|
||||
|
||||
/// Root URL that each file name is appended to when downloading.
let baseURL = URL(string: "http://yann.lecun.com/exdb/mnist/")!

/// Loading recipe for every (split, content) combination.
let files: [FileKind: LoadInfo] = {
    // Image bytes: reshaped to rows of 28 * 28 values, converted to float32
    // and divided by 255.
    let imagesToFloat: (MLXArray) -> MLXArray = { raw in
        raw.reshaped([-1, 28 * 28]).asType(.float32) / 255.0
    }

    // Label bytes: converted to uint32.
    let labelsToUInt32: (MLXArray) -> MLXArray = { raw in
        raw.asType(.uint32)
    }

    return [
        FileKind(.training, .images): LoadInfo(
            name: "train-images-idx3-ubyte.gz", offset: 16, convert: imagesToFloat),
        FileKind(.test, .images): LoadInfo(
            name: "t10k-images-idx3-ubyte.gz", offset: 16, convert: imagesToFloat),
        FileKind(.training, .labels): LoadInfo(
            name: "train-labels-idx1-ubyte.gz", offset: 8, convert: labelsToUInt32),
        FileKind(.test, .labels): LoadInfo(
            name: "t10k-labels-idx1-ubyte.gz", offset: 8, convert: labelsToUInt32),
    ]
}()
|
||||
|
||||
/// Downloads every file listed in `files` that is not already present in `into`.
///
/// Files that already exist at the destination are skipped. Each missing file
/// is fetched from `baseURL` and written atomically, so an interrupted write
/// cannot leave a partial file that a later run would mistake for a complete one.
///
/// - Parameter into: Directory the files are written to.
/// - Throws: Any error from `URLSession` or from writing the file, or a
///   `URLError(.badServerResponse)` when the response is not HTTP or its
///   status code is not 200.
public func download(into: URL) async throws {
    for info in files.values {
        let fileURL = into.appending(component: info.name)

        // `path()` percent-encodes by default, which never matches an on-disk
        // path containing spaces or other escaped characters; use the decoded
        // form for the file-system lookup.
        let localPath = fileURL.path(percentEncoded: false)
        guard !FileManager.default.fileExists(atPath: localPath) else {
            continue
        }

        print("Download: \(info.name)")
        let url = baseURL.appending(component: info.name)
        let (data, response) = try await URLSession.shared.data(from: url)

        // Network failures are recoverable for the caller, so report them by
        // throwing rather than terminating the process.
        guard let httpResponse = response as? HTTPURLResponse else {
            throw URLError(
                .badServerResponse,
                userInfo: [
                    NSLocalizedDescriptionKey:
                        "Unable to download \(url), not an http response: \(response)",
                    NSURLErrorFailingURLErrorKey: url,
                ])
        }
        guard httpResponse.statusCode == 200 else {
            throw URLError(
                .badServerResponse,
                userInfo: [
                    NSLocalizedDescriptionKey: "Unable to download \(url): \(httpResponse)",
                    NSURLErrorFailingURLErrorKey: url,
                ])
        }

        try data.write(to: fileURL, options: .atomic)
    }
}
|
||||
|
||||
/// Reads every file listed in `files` from a directory and converts it to an `MLXArray`.
///
/// For each entry the gzip archive is read and decompressed, the first
/// `offset` bytes are dropped, the remaining bytes are wrapped in a flat
/// `UInt8` array, and the entry's `convert` closure produces the final array.
///
/// - Parameter from: Directory containing the downloaded archives.
/// - Returns: One array per `FileKind`.
/// - Throws: Any error from reading or decompressing a file.
public func load(from: URL) throws -> [FileKind: MLXArray] {
    try files.reduce(into: [FileKind: MLXArray]()) { arrays, entry in
        let (kind, info) = entry

        let archiveURL = from.appending(component: info.name)
        let bytes = try Data(contentsOf: archiveURL).gunzipped()

        // Skip the leading `offset` bytes; everything after them is payload.
        let payload = bytes.dropFirst(info.offset)
        let raw = MLXArray(payload, [bytes.count - info.offset], type: UInt8.self)

        arrays[kind] = info.convert(raw)
    }
}
|
||||
1
Libraries/MNIST/MNIST.h
Normal file
1
Libraries/MNIST/MNIST.h
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
73
Libraries/MNIST/MNIST.swift
Normal file
73
Libraries/MNIST/MNIST.swift
Normal file
@@ -0,0 +1,73 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
import Foundation
|
||||
import MLX
|
||||
import MLXNN
|
||||
|
||||
// based on https://github.com/ml-explore/mlx-examples/blob/main/mnist/main.py
|
||||
|
||||
/// A stack of `Linear` layers with `relu` applied after every layer except the last.
public class MLP: Module, UnaryLayer {

    /// The linear layers, in the order they are applied.
    @ModuleInfo var layers: [Linear]

    /// Builds the layer stack.
    ///
    /// The layer widths are `inputDimensions`, then `hiddenDimensions` repeated
    /// `layers` times, then `outputDimensions`; one `Linear` is created for each
    /// adjacent pair of widths, giving `layers + 1` linear layers in total.
    ///
    /// - Parameters:
    ///   - layers: Number of hidden widths inserted between input and output.
    ///   - inputDimensions: Width of the input.
    ///   - hiddenDimensions: Width of each hidden layer.
    ///   - outputDimensions: Width of the output.
    public init(layers: Int, inputDimensions: Int, hiddenDimensions: Int, outputDimensions: Int) {
        let hiddenWidths = [Int](repeating: hiddenDimensions, count: layers)
        let widths = [inputDimensions] + hiddenWidths + [outputDimensions]

        // Pair each width with its successor; `zip` stops at the shorter sequence.
        self.layers = zip(widths, widths.dropFirst()).map { fanIn, fanOut in
            Linear(fanIn, fanOut)
        }
    }

    /// Applies every layer but the last followed by `relu`, then the last layer
    /// with no activation.
    ///
    /// - Parameter x: Input array.
    /// - Returns: Output of the final linear layer.
    public func callAsFunction(_ x: MLXArray) -> MLXArray {
        let hiddenOutput = layers.dropLast().reduce(x) { activations, layer in
            relu(layer(activations))
        }
        // `init` always creates at least one layer, so `last` is non-nil.
        return layers.last!(hiddenOutput)
    }
}
|
||||
|
||||
/// Mean cross-entropy between the model's output for `x` and the targets `y`.
///
/// - Parameters:
///   - model: The network to evaluate.
///   - x: Input passed to the model.
///   - y: Targets passed to `crossEntropy`.
/// - Returns: The loss reduced with `.mean`.
public func loss(model: MLP, x: MLXArray, y: MLXArray) -> MLXArray {
    let logits = model(x)
    return crossEntropy(logits: logits, targets: y, reduction: .mean)
}
|
||||
|
||||
/// Mean of the element-wise comparison between the model's arg-max prediction
/// (along axis 1) and `y`.
///
/// - Parameters:
///   - model: The network to evaluate.
///   - x: Input passed to the model.
///   - y: Expected values compared against the predictions.
/// - Returns: The mean of the comparison result.
public func eval(model: MLP, x: MLXArray, y: MLXArray) -> MLXArray {
    let predictions = argMax(model(x), axis: 1)
    let matches = predictions .== y
    return mean(matches)
}
|
||||
|
||||
/// Single-pass sequence that yields `(x, y)` batches in a shuffled order.
///
/// A permutation of `0 ..< y.size` is drawn once at construction; each call to
/// `next()` takes the next `batchSize` indices from it (fewer for the final
/// batch) and uses them to index both `x` and `y`.
private struct BatchSequence: Sequence, IteratorProtocol {

    /// Maximum number of samples per batch; always positive.
    let batchSize: Int
    /// Inputs, indexed by the shuffled ids.
    let x: MLXArray
    /// Targets, indexed by the shuffled ids.
    let y: MLXArray

    /// Shuffled sample indices covering `0 ..< y.size`.
    let indexes: MLXArray
    /// Position in `indexes` where the next batch starts.
    var index = 0

    /// - Parameters:
    ///   - batchSize: Samples per batch; must be greater than zero.
    ///   - x: Inputs.
    ///   - y: Targets; `y.size` determines how many samples are iterated.
    ///   - generator: Source of randomness for the shuffle.
    init(batchSize: Int, x: MLXArray, y: MLXArray, using generator: inout any RandomNumberGenerator)
    {
        // A zero batch size would yield empty batches forever and a negative one
        // would trap on an invalid range inside `next()`; fail early and clearly.
        precondition(batchSize > 0, "batchSize must be positive, got \(batchSize)")

        self.batchSize = batchSize
        self.x = x
        self.y = y
        self.indexes = MLXArray(Array(0 ..< y.size).shuffled(using: &generator))
    }

    /// Returns the next batch, or `nil` once every index has been consumed.
    mutating func next() -> (MLXArray, MLXArray)? {
        guard index < y.size else { return nil }

        // Clamp the upper bound so the final batch may be shorter.
        let range = index ..< Swift.min(index + batchSize, y.size)
        index += batchSize
        let ids = indexes[range]
        return (x[ids], y[ids])
    }
}
|
||||
|
||||
/// Returns a sequence of `(x, y)` batches drawn in a shuffled order.
///
/// - Parameters:
///   - batchSize: Number of samples per batch.
///   - x: Inputs.
///   - y: Targets.
///   - generator: Random number generator used for the shuffle.
/// - Returns: A sequence of `(inputs, targets)` pairs.
public func iterateBatches(
    batchSize: Int, x: MLXArray, y: MLXArray, using generator: inout any RandomNumberGenerator
) -> some Sequence<(MLXArray, MLXArray)> {
    let batches = BatchSequence(batchSize: batchSize, x: x, y: y, using: &generator)
    return batches
}
|
||||
13
Libraries/MNIST/README.md
Normal file
13
Libraries/MNIST/README.md
Normal file
@@ -0,0 +1,13 @@
|
||||
# MNIST
|
||||
|
||||
This is a port of the MNIST model and training code from:
|
||||
|
||||
- https://github.com/ml-explore/mlx-examples/blob/main/mnist
|
||||
|
||||
It provides code to:
|
||||
|
||||
- download the test/train data
|
||||
- define the MNIST model (MLP)
|
||||
- shuffle and batch the data
|
||||
|
||||
See [mnist-tool](../../Tools/mnist-tool) for an example of how to run this. The training loop also lives there.
|
||||
30
Libraries/MNIST/Random.swift
Normal file
30
Libraries/MNIST/Random.swift
Normal file
@@ -0,0 +1,30 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
import Foundation
|
||||
|
||||
// From https://github.com/apple/swift/blob/cb0fb1ea051631219c0b944b84c78571448d58c2/benchmark/utils/TestsUtils.swift#L254
|
||||
//
|
||||
// This is just a seedable RandomNumberGenerator for shuffle()
|
||||
|
||||
// This is a fixed-increment version of Java 8's SplittableRandom generator.
|
||||
// It is a very fast generator passing BigCrush, with 64 bits of state.
|
||||
// See http://dx.doi.org/10.1145/2714064.2660195 and
|
||||
// http://docs.oracle.com/javase/8/docs/api/java/util/SplittableRandom.html
|
||||
//
|
||||
// Derived from public domain C implementation by Sebastiano Vigna
|
||||
// See http://xoshiro.di.unimi.it/splitmix64.c
|
||||
/// Seedable `RandomNumberGenerator` with 64 bits of state.
public struct SplitMix64: RandomNumberGenerator {
    /// Current generator state; advanced by a fixed increment on every call to `next()`.
    private var state: UInt64

    /// Creates a generator whose output sequence is fully determined by `seed`.
    public init(seed: UInt64) {
        state = seed
    }

    /// Advances the state and returns the next 64-bit value.
    public mutating func next() -> UInt64 {
        // Wrapping add of the fixed increment.
        state = state &+ 0x9e37_79b9_7f4a_7c15

        // Two xor-shift / wrapping-multiply rounds followed by a final xor-shift.
        let firstRound = (state ^ (state &>> 30)) &* 0xbf58_476d_1ce4_e5b9
        let secondRound = (firstRound ^ (firstRound &>> 27)) &* 0x94d0_49bb_1331_11eb
        return secondRound ^ (secondRound &>> 31)
    }
}
|
||||
Reference in New Issue
Block a user