use memory limit API (#13)

* add buffer cache limit
* swift-format
* a more reasonable size
* add memory stats to command line tool, update to final api
* add note about changing models
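For orientation, here is a minimal sketch of the memory limit API this change adopts. It uses only the calls that appear in the hunks below (`GPU.set(cacheLimit:)`, `GPU.set(memoryLimit:)`, `GPU.snapshot()`, `Snapshot.delta(_:)`); the memory limit value is illustrative, not a default from the commit.

```swift
import MLX

// Sketch of the memory limit API exercised by this commit; only calls that
// appear in the diff below are used. The memory limit value is illustrative --
// in the command line tool it comes from an option.
MLX.GPU.set(cacheLimit: 20 * 1024 * 1024)     // cap the buffer cache at 20M
MLX.GPU.set(memoryLimit: 1024 * 1024 * 1024)  // cap overall GPU memory at 1G (illustrative)

let before = MLX.GPU.snapshot()
// ... load a model and run generation here ...
let after = MLX.GPU.snapshot()

// Snapshot descriptions and the delta between two snapshots are what the
// new memory-stats report prints.
print(after.description)
print(before.delta(after).description)
```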
@@ -73,6 +73,9 @@ class LLMEvaluator {
     func load() async throws -> (LLMModel, LLM.Tokenizer) {
         switch loadState {
         case .idle:
+            // limit the buffer cache
+            MLX.GPU.set(cacheLimit: 20 * 1024 * 1024)
+
             let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration) {
                 [modelConfiguration] progress in
                 DispatchQueue.main.sync {
@@ -80,6 +83,8 @@ class LLMEvaluator {
                     "Downloading \(modelConfiguration.id): \(Int(progress.fractionCompleted * 100))%"
                 }
             }
+            self.output =
+                "Loaded \(modelConfiguration.id). Weights: \(MLX.GPU.activeMemory / 1024 / 1024)M"
             loadState = .loaded(model, tokenizer)
             return (model, tokenizer)
 
@@ -16,9 +16,25 @@ Some notes about the setup:
 
 - this downloads models from hugging face so LLMEval -> Signing & Capabilities has the "Outgoing Connections (Client)" set in the App Sandbox
 - LLM models are large so this uses the Increased Memory Limit entitlement on iOS to allow ... increased memory limits for devices that have more memory
+- `MLX.GPU.set(cacheLimit: 20 * 1024 * 1024)` is used to limit the buffer cache size
 - The Phi2 4 bit model is small enough to run on some iPhone models
 - this can be changed by editing `let modelConfiguration = ModelConfiguration.phi4bit`
 
+### Trying Different Models
+
+The example application uses Phi2 model by default, see [ContentView.swift](ContentView.swift#L58):
+
+```
+/// this controls which model loads -- phi4bit is one of the smaller ones so this will fit on
+/// more devices
+let modelConfiguration = ModelConfiguration.phi4bit
+```
+
+There are some pre-configured models in [LLM/Models.swift](../../Libraries/LLM/Models.swift#L62)
+and you can load any weights from Hugging Face where there
+is a model architecture defined and you have enough
+memory.
+
 ### Troubleshooting
 
 If the program crashes with a very deep stack trace you may need to build
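A small aside on the model selection that the README addition describes: both the app and the command line tool ultimately just pick a `ModelConfiguration`. A sketch follows; the `import LLM` module name is an assumption, while the two calls themselves are taken from this diff.

```swift
import LLM  // module name assumed; ModelConfiguration and its presets come from the example's LLM library

// The app hard-codes a preset (the README snippet above shows this line in ContentView.swift).
let appConfiguration = ModelConfiguration.phi4bit

// The command line tool resolves a Hugging Face repo id to a configuration instead.
let cliConfiguration = ModelConfiguration.configuration(id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx")
```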
@@ -14,12 +14,7 @@ struct LLMTool: AsyncParsableCommand {
         defaultSubcommand: SyncGenerator.self)
 }
 
-struct SyncGenerator: AsyncParsableCommand {
-
-    static var configuration = CommandConfiguration(
-        commandName: "sync",
-        abstract: "Synchronous generator"
-    )
+struct LLMArguments: ParsableArguments {
 
     @Option(name: .long, help: "Name of the huggingface model")
     var model: String = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"
@@ -36,20 +31,91 @@ struct SyncGenerator: AsyncParsableCommand {
     @Option(name: .long, help: "The PRNG seed")
     var seed: UInt64 = 0
 
-    @MainActor
-    func run() async throws {
+    @Flag(help: "Show memory stats")
+    var memoryStats = false
+
+    @Option(name: .long, help: "Maximum cache size in M")
+    var cacheSize: Int?
+
+    @Option(name: .long, help: "Maximum memory size in M")
+    var memorySize: Int?
+
+    var startMemory: GPU.Snapshot?
+
+    mutating func load() async throws -> (LLMModel, Tokenizer, ModelConfiguration) {
         MLXRandom.seed(seed)
 
+        if let cacheSize {
+            GPU.set(cacheLimit: cacheSize * 1024 * 1024)
+        }
+
+        if let memorySize {
+            GPU.set(memoryLimit: memorySize * 1024 * 1024)
+        }
+
         let modelConfiguration = ModelConfiguration.configuration(id: model)
-        let (model, tokenizer) = try await load(configuration: modelConfiguration)
+        let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration)
 
-        print("Model loaded -> \(self.model)")
+        startMemory = GPU.snapshot()
 
-        let prompt = modelConfiguration.prepare(prompt: self.prompt)
+        return (model, tokenizer, modelConfiguration)
+    }
+
+    func tokenizePropmpt(configuration: ModelConfiguration, tokenizer: Tokenizer) -> (String, [Int])
+    {
+        let prompt = configuration.prepare(prompt: self.prompt)
         let promptTokens = tokenizer.encode(text: prompt)
 
+        return (prompt, promptTokens)
+    }
+
+    func reportMemoryStatistics() {
+        if memoryStats, let startMemory {
+            let endMemory = GPU.snapshot()
+
+            print("=======")
+            print("Memory size: \(GPU.memoryLimit / 1024)K")
+            print("Cache size: \(GPU.cacheLimit / 1024)K")
+
+            print("")
+            print("=======")
+            print("Starting memory")
+            print(startMemory.description)
+
+            print("")
+            print("=======")
+            print("Ending memory")
+            print(endMemory.description)
+
+            print("")
+            print("=======")
+            print("Growth")
+            print(startMemory.delta(endMemory).description)
+
+        }
+    }
+}
+
+struct SyncGenerator: AsyncParsableCommand {
+
+    static var configuration = CommandConfiguration(
+        commandName: "sync",
+        abstract: "Synchronous generator"
+    )
+
+    @OptionGroup var args: LLMArguments
+
+    @MainActor
+    mutating func run() async throws {
+        let (model, tokenizer, modelConfiguration) = try await args.load()
+
+        print("Model loaded -> \(modelConfiguration.id)")
+
+        let (prompt, promptTokens) = args.tokenizePropmpt(
+            configuration: modelConfiguration, tokenizer: tokenizer)
+
         print("Starting generation ...")
-        print(self.prompt, terminator: "")
+        print(prompt, terminator: "")
 
         var start = Date.timeIntervalSinceReferenceDate
         var promptTime: TimeInterval = 0
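The `LLMArguments` refactor above follows the standard swift-argument-parser pattern: shared options live in a `ParsableArguments` type and each subcommand pulls them in with `@OptionGroup`. Here is a stripped-down sketch with illustrative names, not the tool's real types:

```swift
import ArgumentParser

// Shared options are declared once in a ParsableArguments type...
struct SharedArguments: ParsableArguments {
    @Option(name: .long, help: "Maximum cache size in M")
    var cacheSize: Int?

    @Flag(help: "Show memory stats")
    var memoryStats = false
}

// ...and composed into each subcommand, so every generator parses the
// same flags without duplicating the declarations.
struct ExampleGenerator: AsyncParsableCommand {
    @OptionGroup var args: SharedArguments

    mutating func run() async throws {
        if args.memoryStats {
            print("cache size: \(args.cacheSize ?? 0)M")
        }
    }
}
```

Declaring the options once is what lets both the `sync` and `async` subcommands gain the new memory options in this commit without repeating them.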
@@ -59,7 +125,8 @@ struct SyncGenerator: AsyncParsableCommand {
         var tokens = [Int]()
         var printed = 0
 
-        for token in TokenIterator(prompt: MLXArray(promptTokens), model: model, temp: temperature)
+        for token in TokenIterator(
+            prompt: MLXArray(promptTokens), model: model, temp: args.temperature)
         {
             if tokens.isEmpty {
                 eval(token)
@@ -83,7 +150,7 @@ struct SyncGenerator: AsyncParsableCommand {
 
             printed = fullOutput.count
 
-            if tokens.count == maxTokens {
+            if tokens.count == args.maxTokens {
                 break
             }
         }
@@ -98,6 +165,8 @@ struct SyncGenerator: AsyncParsableCommand {
             Prompt Tokens per second: \((Double(promptTokens.count) / promptTime).formatted())
             Generation tokens per second: \((Double(tokens.count - 1) / generateTime).formatted())
             """)
+
+        args.reportMemoryStatistics()
     }
 }
 
@@ -112,35 +181,19 @@ struct AsyncGenerator: AsyncParsableCommand {
         abstract: "async generator"
     )
 
-    @Option(name: .long, help: "Name of the huggingface model")
-    var model: String = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"
-
-    @Option(name: .shortAndLong, help: "The message to be processed by the model")
-    var prompt = "compare python and swift"
-
-    @Option(name: .shortAndLong, help: "Maximum number of tokens to generate")
-    var maxTokens = 100
-
-    @Option(name: .shortAndLong, help: "The sampling temperature")
-    var temperature: Float = 0.6
-
-    @Option(name: .long, help: "The PRNG seed")
-    var seed: UInt64 = 0
+    @OptionGroup var args: LLMArguments
 
     @MainActor
-    func run() async throws {
-        MLXRandom.seed(seed)
+    mutating func run() async throws {
+        let (model, tokenizer, modelConfiguration) = try await args.load()
 
-        let modelConfiguration = ModelConfiguration.configuration(id: model)
-        let (model, tokenizer) = try await load(configuration: modelConfiguration)
+        print("Model loaded -> \(modelConfiguration.id)")
 
-        print("Model loaded -> \(self.model)")
-
-        let prompt = modelConfiguration.prepare(prompt: self.prompt)
-        let promptTokens = tokenizer.encode(text: prompt)
+        let (prompt, promptTokens) = args.tokenizePropmpt(
+            configuration: modelConfiguration, tokenizer: tokenizer)
 
         print("Starting generation ...")
-        print(self.prompt, terminator: "")
+        print(prompt, terminator: "")
 
         var start = Date.timeIntervalSinceReferenceDate
         var promptTime: TimeInterval = 0
@@ -151,7 +204,7 @@ struct AsyncGenerator: AsyncParsableCommand {
         var printed = 0
 
         let (task, channel) = generate(
-            prompt: MLXArray(promptTokens), model: model, temp: temperature)
+            prompt: MLXArray(promptTokens), model: model, temp: args.temperature)
 
         for await token in channel {
             if tokens.isEmpty {
@@ -174,7 +227,7 @@ struct AsyncGenerator: AsyncParsableCommand {
 
             printed = fullOutput.count
 
-            if tokens.count == maxTokens {
+            if tokens.count == args.maxTokens {
                 break
             }
         }
@@ -193,6 +246,8 @@ struct AsyncGenerator: AsyncParsableCommand {
             Generation tokens per second: \((Double(tokens.count - 1) / generateTime).formatted())
             """)
 
+        args.reportMemoryStatistics()
+
         // wait for the task to complete -- since it is running async, it might
         // be in the middle of running the model
         try? await Task.sleep(for: .milliseconds(500))
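Assuming swift-argument-parser's default kebab-case naming, the shared `LLMArguments` options should surface on both subcommands as `--memory-stats`, `--cache-size`, and `--memory-size`, with the sizes interpreted as megabytes per the help strings above.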