initial commit
This commit is contained in:
102
Tools/Tutorial/Tutorial.swift
Normal file
102
Tools/Tutorial/Tutorial.swift
Normal file
@@ -0,0 +1,102 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
import Foundation
|
||||
import MLX
|
||||
|
||||
/// mlx-swift tutorial based on:
|
||||
/// https://github.com/ml-explore/mlx/blob/main/examples/cpp/tutorial.cpp
|
||||
@main
|
||||
struct Tutorial {
|
||||
|
||||
static func scalarBasics() {
|
||||
// create a scalar array
|
||||
let x = MLXArray(1.0)
|
||||
|
||||
// the datatype is .float32
|
||||
let dtype = x.dtype
|
||||
assert(dtype == .float32)
|
||||
|
||||
// get the value
|
||||
let s = x.item(Float.self)
|
||||
assert(s == 1.0)
|
||||
|
||||
// reading the value with a different type is a fatal error
|
||||
// let i = x.item(Int.self)
|
||||
|
||||
// scalars have a size of 1
|
||||
let size = x.size
|
||||
assert(size == 1)
|
||||
|
||||
// scalars have 0 dimensions
|
||||
let ndim = x.ndim
|
||||
assert(ndim == 0)
|
||||
|
||||
// scalar shapes are empty arrays
|
||||
let shape = x.shape
|
||||
assert(shape == [])
|
||||
}
|
||||
|
||||
static func arrayBasics() {
|
||||
// make a multidimensional array.
|
||||
//
|
||||
// Note: the argument is a [Double] array literal, which is not
|
||||
// a supported type, but we can explicitly convert it to [Float]
|
||||
// when we create the MLXArray.
|
||||
let x = MLXArray(converting: [1.0, 2.0, 3.0, 4.0], [2, 2])
|
||||
|
||||
// mlx is row-major by default so the first row of this array
|
||||
// is [1.0, 2.0] and the second row is [3.0, 4.0]
|
||||
print(x[0])
|
||||
print(x[1])
|
||||
|
||||
// make an array of shape [2, 2] filled with ones
|
||||
let y = MLXArray.ones([2, 2])
|
||||
|
||||
// pointwise add x and y
|
||||
let z = x + y
|
||||
|
||||
// mlx is lazy by default. At this point `z` only
|
||||
// has a shape and a type but no actual data
|
||||
assert(z.dtype == .float32)
|
||||
assert(z.shape == [2, 2])
|
||||
|
||||
// To actually run the computation you must evaluate `z`.
|
||||
// Under the hood, mlx records operations in a graph.
|
||||
// The variable `z` is a node in the graph which points to its operation
|
||||
// and inputs. When `eval` is called on an array (or arrays), the array and
|
||||
// all of its dependencies are recursively evaluated to produce the result.
|
||||
// Once an array is evaluated, it has data and is detached from its inputs.
|
||||
|
||||
// Note: this is being called for demonstration purposes -- all reads
|
||||
// ensure the array is evaluated.
|
||||
z.eval()
|
||||
|
||||
// this implicitly evaluates z before converting to a description
|
||||
print(z)
|
||||
}
|
||||
|
||||
static func automaticDifferentiation() {
|
||||
func fn(_ x: MLXArray) -> MLXArray {
|
||||
x.square()
|
||||
}
|
||||
|
||||
let gradFn = grad(fn)
|
||||
|
||||
let x = MLXArray(1.5)
|
||||
let dfdx = gradFn(x)
|
||||
print(dfdx)
|
||||
|
||||
assert(dfdx.item() == Float(2 * 1.5))
|
||||
|
||||
let df2dx2 = grad(grad(fn))(x)
|
||||
print(df2dx2)
|
||||
|
||||
assert(df2dx2.item() == Float(2))
|
||||
}
|
||||
|
||||
static func main() {
|
||||
scalarBasics()
|
||||
arrayBasics()
|
||||
automaticDifferentiation()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user