Files
mlx-swift-examples/Libraries/MNIST/Files.swift
Awni Hannun 4ed4ec69e7 LeNet on MNIST + readme update (#12)
* LeNet on MNIST + readme update

* tanh + remove device toggle

* remove device entirely
2024-03-04 14:16:20 -08:00

103 lines
2.7 KiB
Swift

// Copyright © 2024 Apple Inc.
import Foundation
import Gzip
import MLX
// based on https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py
/// Selects between the test files and the training files.
///
/// Each case's raw value is its own name (`"test"`, `"training"`).
public enum Use: String, Hashable {
    case test, training
}
/// Selects between image files and label files.
///
/// Each case's raw value is its own name (`"images"`, `"labels"`).
public enum DataKind: String, Hashable {
    case images, labels
}
/// Hashable key that pairs a `Use` with a `DataKind`, e.g. `FileKind(.training, .images)`.
public struct FileKind: Hashable, CustomStringConvertible {
    /// Which set of files (test or training) this key refers to.
    let use: Use

    /// Whether this key refers to images or labels.
    let data: DataKind

    /// Creates a key from its two components.
    ///
    /// - Parameters:
    ///   - use: The test/training selector.
    ///   - data: The images/labels selector.
    public init(_ use: Use, _ data: DataKind) {
        self.data = data
        self.use = use
    }

    /// The raw values of `use` and `data` joined by a hyphen, e.g. `"training-images"`.
    public var description: String {
        [use.rawValue, data.rawValue].joined(separator: "-")
    }
}
/// Describes how one data file is located and turned into an `MLXArray`.
struct LoadInfo {
    /// File name; appended both to the remote base URL and to the local directory.
    let name: String

    /// Number of leading bytes dropped from the decompressed file before the array is built.
    let offset: Int

    /// Transformation applied to the raw `UInt8` array built from the remaining bytes.
    let convert: (MLXArray) -> MLXArray

    /// Spelled-out equivalent of the synthesized memberwise initializer.
    init(name: String, offset: Int, convert: @escaping (MLXArray) -> MLXArray) {
        self.name = name
        self.offset = offset
        self.convert = convert
    }
}
/// Remote directory that each `LoadInfo.name` is appended to when downloading.
// NOTE(review): this is a plain-`http` URL. Confirm the host still serves these files and that the
// host app's transport-security settings permit the request.
let baseURL = URL(string: "http://yann.lecun.com/exdb/mnist/")!

/// Conversion shared by both image files: reshape to `[-1, 28, 28, 1]`, cast to `float32`,
/// then divide by 255.
private let convertImages: (MLXArray) -> MLXArray = { bytes in
    bytes.reshaped([-1, 28, 28, 1]).asType(.float32) / 255.0
}

/// Conversion shared by both label files: cast to `uint32`.
private let convertLabels: (MLXArray) -> MLXArray = { bytes in
    bytes.asType(.uint32)
}

/// Table of every file handled by `download(into:)` and `load(from:)`, keyed by `FileKind`.
///
/// Image entries skip 16 leading bytes and label entries skip 8 before conversion.
let files: [FileKind: LoadInfo] = [
    FileKind(.training, .images): LoadInfo(
        name: "train-images-idx3-ubyte.gz", offset: 16, convert: convertImages),
    FileKind(.test, .images): LoadInfo(
        name: "t10k-images-idx3-ubyte.gz", offset: 16, convert: convertImages),
    FileKind(.training, .labels): LoadInfo(
        name: "train-labels-idx1-ubyte.gz", offset: 8, convert: convertLabels),
    FileKind(.test, .labels): LoadInfo(
        name: "t10k-labels-idx1-ubyte.gz", offset: 8, convert: convertLabels),
]
/// Downloads every file listed in `files` into a directory, skipping files that already exist.
///
/// - Parameter into: Directory that receives the files; each one is stored under its
///   `LoadInfo.name`.
/// - Throws: Any error raised by `URLSession` or by writing the file, and
///   `URLError(.badServerResponse)` when the reply is not an HTTP response with status code 200.
public func download(into: URL) async throws {
    for info in files.values {
        let fileURL = into.appending(component: info.name)

        // `path()` percent-encodes by default; `FileManager` takes the string literally, so an
        // encoded path would look missing whenever the directory name contains spaces or other
        // escaped characters, and the file would be fetched again on every call.
        guard !FileManager.default.fileExists(atPath: fileURL.path(percentEncoded: false)) else {
            continue
        }

        print("Download: \(info.name)")
        let url = baseURL.appending(component: info.name)
        let (data, response) = try await URLSession.shared.data(from: url)

        // A bad reply is a recoverable condition: surface it to the caller through the existing
        // `throws` channel rather than terminating the process.
        guard let httpResponse = response as? HTTPURLResponse, httpResponse.statusCode == 200 else {
            throw URLError(
                .badServerResponse,
                userInfo: [
                    NSURLErrorFailingURLErrorKey: url,
                    NSLocalizedDescriptionKey: "Unable to download \(url): \(response)",
                ])
        }

        // Atomic write: an interrupted write must not leave a truncated file behind, because the
        // existence check above would then treat it as already downloaded.
        try data.write(to: fileURL, options: .atomic)
    }
}
/// Reads every file listed in `files` from a directory and converts each into an `MLXArray`.
///
/// For each entry the file is read, passed through `gunzipped()`, stripped of its first
/// `offset` bytes, wrapped in a one-dimensional `UInt8` array and handed to the entry's
/// `convert` closure.
///
/// - Parameter from: Directory containing the files, named by their `LoadInfo.name`.
/// - Returns: The converted arrays keyed by `FileKind`.
/// - Throws: Any error raised while reading or decompressing a file.
public func load(from: URL) throws -> [FileKind: MLXArray] {
    return try files.reduce(into: [FileKind: MLXArray]()) { loaded, entry in
        let (kind, info) = entry

        let location = from.appending(component: info.name)
        let contents = try Data(contentsOf: location).gunzipped()

        // Everything after the first `offset` bytes is treated as the payload.
        let payload = contents.dropFirst(info.offset)
        let length = contents.count - info.offset
        let bytes = MLXArray(payload, [length], type: UInt8.self)

        loaded[kind] = info.convert(bytes)
    }
}