initial commit
This commit is contained in:
113
Tools/LinearModelTraining/LinearModelTraining.swift
Normal file
113
Tools/LinearModelTraining/LinearModelTraining.swift
Normal file
@@ -0,0 +1,113 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
import ArgumentParser
|
||||
import Foundation
|
||||
import MLX
|
||||
import MLXNN
|
||||
import MLXOptimizers
|
||||
import MLXRandom
|
||||
|
||||
extension MLX.DeviceType: ExpressibleByArgument {
|
||||
public init?(argument: String) {
|
||||
self.init(rawValue: argument)
|
||||
}
|
||||
}
|
||||
|
||||
@main
|
||||
struct Train: AsyncParsableCommand {
|
||||
|
||||
@Option var epochs = 20
|
||||
@Option var batchSize = 8
|
||||
|
||||
@Option var m: Float = 0.25
|
||||
@Option var b: Float = 7
|
||||
|
||||
@Flag var compile = false
|
||||
|
||||
@Option var device = DeviceType.cpu
|
||||
|
||||
func run() async throws {
|
||||
Device.setDefault(device: Device(device))
|
||||
|
||||
// A very simple model that implements the equation
|
||||
// for a linear function: y = mx + b. This can be trained
|
||||
// to match data -- in this case an unknown (to the model)
|
||||
// linear function.
|
||||
//
|
||||
// This is a nice example because most people know how
|
||||
// linear functions work and we can see how the slope
|
||||
// and intercept converge.
|
||||
class LinearFunctionModel: Module, UnaryLayer {
|
||||
let m = MLXRandom.uniform(low: -5.0, high: 5.0)
|
||||
let b = MLXRandom.uniform(low: -5.0, high: 5.0)
|
||||
|
||||
func callAsFunction(_ x: MLXArray) -> MLXArray {
|
||||
m * x + b
|
||||
}
|
||||
}
|
||||
|
||||
// measure the distance from the prediction (model(x)) and the
|
||||
// ground truth (y). this gives feedback on how close the
|
||||
// prediction is from matching the truth
|
||||
func loss(model: LinearFunctionModel, x: MLXArray, y: MLXArray) -> MLXArray {
|
||||
mseLoss(predictions: model(x), targets: y, reduction: .mean)
|
||||
}
|
||||
|
||||
let model = LinearFunctionModel()
|
||||
eval(model.parameters())
|
||||
|
||||
let lg = valueAndGrad(model: model, loss)
|
||||
|
||||
// the optimizer will use the gradients update the model parameters
|
||||
let optimizer = SGD(learningRate: 1e-1)
|
||||
|
||||
// the function to train our model against -- it doesn't have
|
||||
// to be linear, but matching what the model models is easy
|
||||
// to understand
|
||||
func f(_ x: MLXArray) -> MLXArray {
|
||||
// these are the target parameters
|
||||
let m = self.m
|
||||
let b = self.b
|
||||
|
||||
// our actual function
|
||||
return m * x + b
|
||||
}
|
||||
|
||||
func step(_ x: MLXArray, _ y: MLXArray) -> MLXArray {
|
||||
let (loss, grads) = lg(model, x, y)
|
||||
optimizer.update(model: model, gradients: grads)
|
||||
return loss
|
||||
}
|
||||
|
||||
let resolvedStep =
|
||||
self.compile
|
||||
? MLX.compile(inputs: [model, optimizer], outputs: [model, optimizer], step) : step
|
||||
|
||||
for _ in 0 ..< epochs {
|
||||
// we expect that the parameters will approach the targets
|
||||
print("target: b = \(b), m = \(m)")
|
||||
print("parameters: \(model.parameters())")
|
||||
|
||||
// generate random training data along with the ground truth.
|
||||
// notice that the shape is [B, 1] where B is the batch
|
||||
// dimension -- this allows us to train on several samples simultaneously
|
||||
//
|
||||
// note: a very large batch size will take longer to converge because
|
||||
// the gradient will be representing too many samples down into
|
||||
// a single float parameter.
|
||||
let x = MLXRandom.uniform(low: -5.0, high: 5.0, [batchSize, 1])
|
||||
let y = f(x)
|
||||
eval(x, y)
|
||||
|
||||
// compute the loss and gradients. use the optimizer
|
||||
// to adjust the parameters closer to the target
|
||||
let loss = resolvedStep(x, y)
|
||||
|
||||
eval(model, optimizer)
|
||||
|
||||
// we should see this converge toward 0
|
||||
print("loss: \(loss)")
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
14
Tools/LinearModelTraining/README.md
Normal file
14
Tools/LinearModelTraining/README.md
Normal file
@@ -0,0 +1,14 @@
|
||||
# LinearModelTraining
|
||||
|
||||
A command line tool that creates a Model that represents:
|
||||
|
||||
f(x) = mx + b
|
||||
|
||||
and trains it against an unknown linear function. Very
|
||||
simple but illustrates:
|
||||
|
||||
- a very simple model with parameters
|
||||
- a loss function
|
||||
- the gradient
|
||||
- use of an optimizers
|
||||
- the training loop
|
||||
102
Tools/Tutorial/Tutorial.swift
Normal file
102
Tools/Tutorial/Tutorial.swift
Normal file
@@ -0,0 +1,102 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
import Foundation
|
||||
import MLX
|
||||
|
||||
/// mlx-swift tutorial based on:
|
||||
/// https://github.com/ml-explore/mlx/blob/main/examples/cpp/tutorial.cpp
|
||||
@main
|
||||
struct Tutorial {
|
||||
|
||||
static func scalarBasics() {
|
||||
// create a scalar array
|
||||
let x = MLXArray(1.0)
|
||||
|
||||
// the datatype is .float32
|
||||
let dtype = x.dtype
|
||||
assert(dtype == .float32)
|
||||
|
||||
// get the value
|
||||
let s = x.item(Float.self)
|
||||
assert(s == 1.0)
|
||||
|
||||
// reading the value with a different type is a fatal error
|
||||
// let i = x.item(Int.self)
|
||||
|
||||
// scalars have a size of 1
|
||||
let size = x.size
|
||||
assert(size == 1)
|
||||
|
||||
// scalars have 0 dimensions
|
||||
let ndim = x.ndim
|
||||
assert(ndim == 0)
|
||||
|
||||
// scalar shapes are empty arrays
|
||||
let shape = x.shape
|
||||
assert(shape == [])
|
||||
}
|
||||
|
||||
static func arrayBasics() {
|
||||
// make a multidimensional array.
|
||||
//
|
||||
// Note: the argument is a [Double] array literal, which is not
|
||||
// a supported type, but we can explicitly convert it to [Float]
|
||||
// when we create the MLXArray.
|
||||
let x = MLXArray(converting: [1.0, 2.0, 3.0, 4.0], [2, 2])
|
||||
|
||||
// mlx is row-major by default so the first row of this array
|
||||
// is [1.0, 2.0] and the second row is [3.0, 4.0]
|
||||
print(x[0])
|
||||
print(x[1])
|
||||
|
||||
// make an array of shape [2, 2] filled with ones
|
||||
let y = MLXArray.ones([2, 2])
|
||||
|
||||
// pointwise add x and y
|
||||
let z = x + y
|
||||
|
||||
// mlx is lazy by default. At this point `z` only
|
||||
// has a shape and a type but no actual data
|
||||
assert(z.dtype == .float32)
|
||||
assert(z.shape == [2, 2])
|
||||
|
||||
// To actually run the computation you must evaluate `z`.
|
||||
// Under the hood, mlx records operations in a graph.
|
||||
// The variable `z` is a node in the graph which points to its operation
|
||||
// and inputs. When `eval` is called on an array (or arrays), the array and
|
||||
// all of its dependencies are recursively evaluated to produce the result.
|
||||
// Once an array is evaluated, it has data and is detached from its inputs.
|
||||
|
||||
// Note: this is being called for demonstration purposes -- all reads
|
||||
// ensure the array is evaluated.
|
||||
z.eval()
|
||||
|
||||
// this implicitly evaluates z before converting to a description
|
||||
print(z)
|
||||
}
|
||||
|
||||
static func automaticDifferentiation() {
|
||||
func fn(_ x: MLXArray) -> MLXArray {
|
||||
x.square()
|
||||
}
|
||||
|
||||
let gradFn = grad(fn)
|
||||
|
||||
let x = MLXArray(1.5)
|
||||
let dfdx = gradFn(x)
|
||||
print(dfdx)
|
||||
|
||||
assert(dfdx.item() == Float(2 * 1.5))
|
||||
|
||||
let df2dx2 = grad(grad(fn))(x)
|
||||
print(df2dx2)
|
||||
|
||||
assert(df2dx2.item() == Float(2))
|
||||
}
|
||||
|
||||
static func main() {
|
||||
scalarBasics()
|
||||
arrayBasics()
|
||||
automaticDifferentiation()
|
||||
}
|
||||
}
|
||||
190
Tools/llm-tool/LLMTool.swift
Normal file
190
Tools/llm-tool/LLMTool.swift
Normal file
@@ -0,0 +1,190 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
import ArgumentParser
|
||||
import Foundation
|
||||
import LLM
|
||||
import MLX
|
||||
import MLXRandom
|
||||
|
||||
struct LLMTool: AsyncParsableCommand {
|
||||
static var configuration = CommandConfiguration(
|
||||
abstract: "Command line tool for generating text using Llama models",
|
||||
subcommands: [SyncGenerator.self, AsyncGenerator.self],
|
||||
defaultSubcommand: SyncGenerator.self)
|
||||
}
|
||||
|
||||
@main
|
||||
struct SyncGenerator: AsyncParsableCommand {
|
||||
|
||||
static var configuration = CommandConfiguration(
|
||||
commandName: "sync",
|
||||
abstract: "Synchronous generator"
|
||||
)
|
||||
|
||||
@Option(name: .long, help: "Name of the huggingface model")
|
||||
var model: String = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"
|
||||
|
||||
@Option(name: .shortAndLong, help: "The message to be processed by the model")
|
||||
var prompt = "compare swift and python"
|
||||
|
||||
@Option(name: .shortAndLong, help: "Maximum number of tokens to generate")
|
||||
var maxTokens = 100
|
||||
|
||||
@Option(name: .shortAndLong, help: "The sampling temperature")
|
||||
var temperature: Float = 0.0
|
||||
|
||||
@Option(name: .long, help: "The PRNG seed")
|
||||
var seed: UInt64 = 0
|
||||
|
||||
@MainActor
|
||||
func run() async throws {
|
||||
MLXRandom.seed(seed)
|
||||
|
||||
let (model, tokenizer) = try await load(name: model)
|
||||
|
||||
print("Starting generation ...")
|
||||
print(prompt, terminator: "")
|
||||
|
||||
var start = Date.timeIntervalSinceReferenceDate
|
||||
var promptTime: TimeInterval = 0
|
||||
|
||||
let prompt = MLXArray(tokenizer.encode(text: prompt))
|
||||
|
||||
// collect the tokens and keep track of how much of the string
|
||||
// we have printed already
|
||||
var tokens = [Int]()
|
||||
var printed = 0
|
||||
|
||||
for token in TokenIterator(prompt: prompt, model: model, temp: temperature) {
|
||||
if tokens.isEmpty {
|
||||
eval(token)
|
||||
let now = Date.timeIntervalSinceReferenceDate
|
||||
promptTime = now - start
|
||||
start = now
|
||||
}
|
||||
|
||||
let t = token.item(Int.self)
|
||||
if t == tokenizer.unknownTokenId {
|
||||
break
|
||||
}
|
||||
tokens.append(t)
|
||||
|
||||
// print any new parts of the string
|
||||
let fullOutput = tokenizer.decode(tokens: tokens)
|
||||
let emitLength = fullOutput.count - printed
|
||||
let suffix = fullOutput.suffix(emitLength)
|
||||
print(suffix, terminator: "")
|
||||
fflush(stdout)
|
||||
|
||||
printed = fullOutput.count
|
||||
|
||||
if tokens.count == maxTokens {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
print()
|
||||
print("------")
|
||||
let now = Date.timeIntervalSinceReferenceDate
|
||||
let generateTime = now - start
|
||||
|
||||
print(
|
||||
"""
|
||||
Prompt Tokens per second: \((Double(prompt.size) / promptTime).formatted())
|
||||
Generation tokens per second: \((Double(tokens.count - 1) / generateTime).formatted())
|
||||
""")
|
||||
}
|
||||
}
|
||||
|
||||
/// Example of an async generator.
|
||||
///
|
||||
/// Note that all of the computation is done on another thread and TokenId (Int32) are sent
|
||||
/// rather than MLXArray.
|
||||
struct AsyncGenerator: AsyncParsableCommand {
|
||||
|
||||
static var configuration = CommandConfiguration(
|
||||
commandName: "async",
|
||||
abstract: "async generator"
|
||||
)
|
||||
|
||||
@Option(name: .long, help: "Name of the huggingface model")
|
||||
var model: String = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"
|
||||
|
||||
@Option(name: .shortAndLong, help: "The message to be processed by the model")
|
||||
var prompt = "compare swift and python"
|
||||
|
||||
@Option(name: .shortAndLong, help: "Maximum number of tokens to generate")
|
||||
var maxTokens = 100
|
||||
|
||||
@Option(name: .shortAndLong, help: "The sampling temperature")
|
||||
var temperature: Float = 0.0
|
||||
|
||||
@Option(name: .long, help: "The PRNG seed")
|
||||
var seed: UInt64 = 0
|
||||
|
||||
@MainActor
|
||||
func run() async throws {
|
||||
MLXRandom.seed(seed)
|
||||
|
||||
let (model, tokenizer) = try await load(name: model)
|
||||
|
||||
print("Starting generation ...")
|
||||
print(prompt, terminator: "")
|
||||
|
||||
var start = Date.timeIntervalSinceReferenceDate
|
||||
var promptTime: TimeInterval = 0
|
||||
|
||||
let prompt = MLXArray(tokenizer.encode(text: prompt))
|
||||
|
||||
// collect the tokens and keep track of how much of the string
|
||||
// we have printed already
|
||||
var tokens = [Int]()
|
||||
var printed = 0
|
||||
|
||||
let (task, channel) = generate(prompt: prompt, model: model, temp: temperature)
|
||||
|
||||
for await token in channel {
|
||||
if tokens.isEmpty {
|
||||
let now = Date.timeIntervalSinceReferenceDate
|
||||
promptTime = now - start
|
||||
start = now
|
||||
}
|
||||
|
||||
if token == tokenizer.unknownTokenId {
|
||||
break
|
||||
}
|
||||
tokens.append(token)
|
||||
|
||||
// print any new parts of the string
|
||||
let fullOutput = tokenizer.decode(tokens: tokens)
|
||||
let emitLength = fullOutput.count - printed
|
||||
let suffix = fullOutput.suffix(emitLength)
|
||||
print(suffix, terminator: "")
|
||||
fflush(stdout)
|
||||
|
||||
printed = fullOutput.count
|
||||
|
||||
if tokens.count == maxTokens {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// tell the task to stop
|
||||
task.cancel()
|
||||
|
||||
print()
|
||||
print("------")
|
||||
let now = Date.timeIntervalSinceReferenceDate
|
||||
let generateTime = now - start
|
||||
|
||||
print(
|
||||
"""
|
||||
Prompt Tokens per second: \((Double(prompt.size) / promptTime).formatted())
|
||||
Generation tokens per second: \((Double(tokens.count - 1) / generateTime).formatted())
|
||||
""")
|
||||
|
||||
// wait for the task to complete -- since it is running async, it might
|
||||
// be in the middle of running the model
|
||||
try? await Task.sleep(for: .milliseconds(500))
|
||||
}
|
||||
}
|
||||
38
Tools/llm-tool/README.md
Normal file
38
Tools/llm-tool/README.md
Normal file
@@ -0,0 +1,38 @@
|
||||
# llm-tool
|
||||
|
||||
See various READMEs:
|
||||
|
||||
- [Llama](../../Libraries/Llama/README.md)
|
||||
|
||||
### Building
|
||||
|
||||
Build the `llm-tool` scheme in Xcode.
|
||||
|
||||
### Running (Xcode)
|
||||
|
||||
To run this in Xcode simply press cmd-opt-r to set the scheme arguments. For example:
|
||||
|
||||
```
|
||||
--model mlx-community/Mistral-7B-v0.1-hf-4bit-mlx
|
||||
--prompt "swift programming language"
|
||||
--max-tokens 50
|
||||
```
|
||||
|
||||
Then cmd-r to run.
|
||||
|
||||
> Note: you may be prompted for access to your Documents directory -- this is where
|
||||
the huggingface HubApi stores the downloaded files.
|
||||
|
||||
### Running (Command Line)
|
||||
|
||||
`llm-tool` can also be run from the command line if built from Xcode, but
|
||||
the `DYLD_FRAMEWORK_PATH` must be set so that the frameworks and bundles can be found:
|
||||
|
||||
- [MLX troubleshooting](https://ml-explore.github.io/mlx-swift/MLX/documentation/mlx/troubleshooting)
|
||||
|
||||
The easiest way to do this is drag the Products/llm-tool into Terminal to get the path:
|
||||
|
||||
```
|
||||
DYLD_FRAMEWORK_PATH=~/Library/Developer/Xcode/DerivedData/mlx-examples-swift-ceuohnhzsownvsbbleukxoksddja/Build/Products/Debug ~/Library/Developer/Xcode/DerivedData/mlx-examples-swift-ceuohnhzsownvsbbleukxoksddja/Build/Products/Debug/llm-tool --prompt "swift programming language"
|
||||
```
|
||||
|
||||
108
Tools/mnist-tool/MNISTTool.swift
Normal file
108
Tools/mnist-tool/MNISTTool.swift
Normal file
@@ -0,0 +1,108 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
import ArgumentParser
|
||||
import Foundation
|
||||
import MLX
|
||||
import MLXNN
|
||||
import MLXOptimizers
|
||||
import MLXRandom
|
||||
import MNIST
|
||||
|
||||
@main
|
||||
struct MNISTTool: AsyncParsableCommand {
|
||||
static var configuration = CommandConfiguration(
|
||||
abstract: "Command line tool for training mnist models",
|
||||
subcommands: [Train.self],
|
||||
defaultSubcommand: Train.self)
|
||||
}
|
||||
|
||||
extension MLX.DeviceType: ExpressibleByArgument {
|
||||
public init?(argument: String) {
|
||||
self.init(rawValue: argument)
|
||||
}
|
||||
}
|
||||
|
||||
struct Train: AsyncParsableCommand {
|
||||
|
||||
@Option(name: .long, help: "Directory with the training data")
|
||||
var data: String
|
||||
|
||||
@Option(name: .long, help: "The PRNG seed")
|
||||
var seed: UInt64 = 0
|
||||
|
||||
@Option var layers = 2
|
||||
@Option var hidden = 32
|
||||
@Option var batchSize = 256
|
||||
@Option var epochs = 20
|
||||
@Option var learningRate: Float = 1e-1
|
||||
|
||||
@Option var classes = 10
|
||||
|
||||
@Option var device = DeviceType.cpu
|
||||
|
||||
@Flag var compile = false
|
||||
|
||||
func run() async throws {
|
||||
Device.setDefault(device: Device(device))
|
||||
|
||||
MLXRandom.seed(seed)
|
||||
var generator: RandomNumberGenerator = SplitMix64(seed: seed)
|
||||
|
||||
// load the data
|
||||
let url = URL(filePath: data)
|
||||
|
||||
try FileManager.default.createDirectory(at: url, withIntermediateDirectories: true)
|
||||
try await download(into: url)
|
||||
|
||||
let data = try load(from: url)
|
||||
|
||||
let trainImages = data[.init(.training, .images)]!
|
||||
let trainLabels = data[.init(.training, .labels)]!
|
||||
let testImages = data[.init(.test, .images)]!
|
||||
let testLabels = data[.init(.test, .labels)]!
|
||||
|
||||
// create the model
|
||||
let model = MLP(
|
||||
layers: layers, inputDimensions: trainImages.dim(-1), hiddenDimensions: hidden,
|
||||
outputDimensions: classes)
|
||||
eval(model.parameters())
|
||||
|
||||
let lg = valueAndGrad(model: model, loss)
|
||||
let optimizer = SGD(learningRate: learningRate)
|
||||
|
||||
func step(_ x: MLXArray, _ y: MLXArray) -> MLXArray {
|
||||
let (loss, grads) = lg(model, x, y)
|
||||
optimizer.update(model: model, gradients: grads)
|
||||
return loss
|
||||
}
|
||||
|
||||
let resolvedStep =
|
||||
compile
|
||||
? MLX.compile(inputs: [model, optimizer], outputs: [model, optimizer], step) : step
|
||||
|
||||
for e in 0 ..< epochs {
|
||||
let start = Date.timeIntervalSinceReferenceDate
|
||||
|
||||
for (x, y) in iterateBatches(
|
||||
batchSize: batchSize, x: trainImages, y: trainLabels, using: &generator)
|
||||
{
|
||||
_ = resolvedStep(x, y)
|
||||
|
||||
// eval the parameters so the next iteration is independent
|
||||
eval(model, optimizer)
|
||||
}
|
||||
|
||||
let accuracy = eval(model: model, x: testImages, y: testLabels)
|
||||
|
||||
let end = Date.timeIntervalSinceReferenceDate
|
||||
|
||||
print(
|
||||
"""
|
||||
Epoch \(e): test accuracy \(accuracy.item(Float.self).formatted())
|
||||
Time: \((end - start).formatted())
|
||||
|
||||
"""
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
36
Tools/mnist-tool/README.md
Normal file
36
Tools/mnist-tool/README.md
Normal file
@@ -0,0 +1,36 @@
|
||||
# mnist-tool
|
||||
|
||||
See other README:
|
||||
|
||||
- [MNIST](../../Libraries/MNIST/README.md)
|
||||
|
||||
### Building
|
||||
|
||||
`mnist-tool` has no dependencies outside of the package dependencies
|
||||
represented in xcode.
|
||||
|
||||
When you run the tool it will download the test/train datasets and
|
||||
store them in a specified directory (see run arguments -- default is /tmp).
|
||||
|
||||
Simply build the project in xcode.
|
||||
|
||||
### Running (Xcode)
|
||||
|
||||
To run this in Xcode simply press cmd-opt-r to set the scheme arguments. For example:
|
||||
|
||||
```
|
||||
--data /tmp
|
||||
```
|
||||
|
||||
Then cmd-r to run.
|
||||
|
||||
### Running (CommandLine)
|
||||
|
||||
`mnist-tool` can also be run from the command line if built from Xcode, but
|
||||
the `DYLD_FRAMEWORK_PATH` must be set so that the frameworks and bundles can be found:
|
||||
|
||||
- [MLX troubleshooting](https://ml-explore.github.io/mlx-swift/MLX/documentation/mlx/troubleshooting)
|
||||
|
||||
```
|
||||
DYLD_FRAMEWORK_PATH=~/Library/Developer/Xcode/DerivedData/mlx-examples-swift-ceuohnhzsownvsbbleukxoksddja/Build/Products/Debug ~/Library/Developer/Xcode/DerivedData/mlx-examples-swift-ceuohnhzsownvsbbleukxoksddja/Build/Products/Debug/mnist-tool --data /tmp
|
||||
```
|
||||
Reference in New Issue
Block a user