use memory limit API (#13)

* add buffer cache limit

* swift-format

* a more reasonable size

* add memory stats to command line tool, update to final api

* add note about changing models
David Koski
2024-03-05 15:22:12 -08:00
committed by GitHub
parent 430b464c8d
commit 61105bf0c4
3 changed files with 115 additions and 39 deletions

View File

@@ -73,6 +73,9 @@ class LLMEvaluator {
     func load() async throws -> (LLMModel, LLM.Tokenizer) {
         switch loadState {
         case .idle:
+            // limit the buffer cache
+            MLX.GPU.set(cacheLimit: 20 * 1024 * 1024)
+
             let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration) {
                 [modelConfiguration] progress in
                 DispatchQueue.main.sync {
@@ -80,6 +83,8 @@ class LLMEvaluator {
                         "Downloading \(modelConfiguration.id): \(Int(progress.fractionCompleted * 100))%"
                 }
             }
+            self.output =
+                "Loaded \(modelConfiguration.id). Weights: \(MLX.GPU.activeMemory / 1024 / 1024)M"
             loadState = .loaded(model, tokenizer)
             return (model, tokenizer)
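
For orientation, here is a minimal sketch of how the two `MLX.GPU` values used in this diff fit together. The calls are the ones shown above; the surrounding scaffolding is illustrative, not part of the commit:

```swift
import MLX

// Cap the buffer cache at 20 MB. The cache holds freed buffers for reuse;
// limiting it bounds how much memory can sit idle between evaluations.
MLX.GPU.set(cacheLimit: 20 * 1024 * 1024)

// ... load the model weights here ...

// activeMemory reports the memory held by live arrays. Right after loading,
// that is dominated by the weights -- which is what the UI string reports.
print("Weights: \(MLX.GPU.activeMemory / 1024 / 1024)M")
```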

View File

@@ -16,9 +16,25 @@ Some notes about the setup:
 - this downloads models from hugging face so LLMEval -> Signing & Capabilities has the "Outgoing Connections (Client)" set in the App Sandbox
 - LLM models are large so this uses the Increased Memory Limit entitlement on iOS to allow ... increased memory limits for devices that have more memory
+- `MLX.GPU.set(cacheLimit: 20 * 1024 * 1024)` is used to limit the buffer cache size
 - The Phi2 4 bit model is small enough to run on some iPhone models
     - this can be changed by editing `let modelConfiguration = ModelConfiguration.phi4bit`

+### Trying Different Models
+
+The example application uses Phi2 model by default, see [ContentView.swift](ContentView.swift#L58):
+
+```
+/// this controls which model loads -- phi4bit is one of the smaller ones so this will fit on
+/// more devices
+let modelConfiguration = ModelConfiguration.phi4bit
+```
+
+There are some pre-configured models in [LLM/Models.swift](../../Libraries/LLM/Models.swift#L62)
+and you can load any weights from Hugging Face where there
+is a model architecture defined and you have enough
+memory.
+
 ### Troubleshooting

 If the program crashes with a very deep stack trace you may need to build
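
One hedged illustration of the "any weights from Hugging Face" note in this README change: the command line tool updated below resolves arbitrary model ids through `ModelConfiguration.configuration(id:)`, so the app's model line could be swapped along these lines. The Mistral id is llm-tool's default model, not a tested app configuration, and a 7B model needs far more memory than Phi2:

```swift
// The pre-configured default -- small enough for some iPhones:
let modelConfiguration = ModelConfiguration.phi4bit

// Alternatively, resolve a configuration from a Hugging Face model id,
// the way llm-tool does (needs considerably more memory):
// let modelConfiguration = ModelConfiguration.configuration(
//     id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx")
```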

View File

@@ -14,12 +14,7 @@ struct LLMTool: AsyncParsableCommand {
         defaultSubcommand: SyncGenerator.self)
 }

-struct SyncGenerator: AsyncParsableCommand {
-    static var configuration = CommandConfiguration(
-        commandName: "sync",
-        abstract: "Synchronous generator"
-    )
-
+struct LLMArguments: ParsableArguments {
     @Option(name: .long, help: "Name of the huggingface model")
     var model: String = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"
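
The pattern introduced here is swift-argument-parser's standard way of sharing options across subcommands: declare them once on a `ParsableArguments` type, then embed that type in each command with `@OptionGroup`. A generic, self-contained sketch of the pattern (names are placeholders, not from this repo):

```swift
import ArgumentParser

// Shared options live in a ParsableArguments type...
struct CommonArgs: ParsableArguments {
    @Option(help: "Model to use")
    var model: String = "default-model"
}

// ...and each subcommand pulls them in with @OptionGroup.
struct Run: ParsableCommand {
    @OptionGroup var args: CommonArgs

    func run() {
        print("running with \(args.model)")
    }
}
```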
@@ -36,20 +31,91 @@ struct SyncGenerator: AsyncParsableCommand {
@Option(name: .long, help: "The PRNG seed") @Option(name: .long, help: "The PRNG seed")
var seed: UInt64 = 0 var seed: UInt64 = 0
@MainActor @Flag(help: "Show memory stats")
func run() async throws { var memoryStats = false
@Option(name: .long, help: "Maximum cache size in M")
var cacheSize: Int?
@Option(name: .long, help: "Maximum memory size in M")
var memorySize: Int?
var startMemory: GPU.Snapshot?
mutating func load() async throws -> (LLMModel, Tokenizer, ModelConfiguration) {
MLXRandom.seed(seed) MLXRandom.seed(seed)
if let cacheSize {
GPU.set(cacheLimit: cacheSize * 1024 * 1024)
}
if let memorySize {
GPU.set(memoryLimit: memorySize * 1024 * 1024)
}
let modelConfiguration = ModelConfiguration.configuration(id: model) let modelConfiguration = ModelConfiguration.configuration(id: model)
let (model, tokenizer) = try await load(configuration: modelConfiguration) let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration)
print("Model loaded -> \(self.model)") startMemory = GPU.snapshot()
let prompt = modelConfiguration.prepare(prompt: self.prompt) return (model, tokenizer, modelConfiguration)
}
func tokenizePropmpt(configuration: ModelConfiguration, tokenizer: Tokenizer) -> (String, [Int])
{
let prompt = configuration.prepare(prompt: self.prompt)
let promptTokens = tokenizer.encode(text: prompt) let promptTokens = tokenizer.encode(text: prompt)
return (prompt, promptTokens)
}
func reportMemoryStatistics() {
if memoryStats, let startMemory {
let endMemory = GPU.snapshot()
print("=======")
print("Memory size: \(GPU.memoryLimit / 1024)K")
print("Cache size: \(GPU.cacheLimit / 1024)K")
print("")
print("=======")
print("Starting memory")
print(startMemory.description)
print("")
print("=======")
print("Ending memory")
print(endMemory.description)
print("")
print("=======")
print("Growth")
print(startMemory.delta(endMemory).description)
}
}
}
struct SyncGenerator: AsyncParsableCommand {
static var configuration = CommandConfiguration(
commandName: "sync",
abstract: "Synchronous generator"
)
@OptionGroup var args: LLMArguments
@MainActor
mutating func run() async throws {
let (model, tokenizer, modelConfiguration) = try await args.load()
print("Model loaded -> \(modelConfiguration.id)")
let (prompt, promptTokens) = args.tokenizePropmpt(
configuration: modelConfiguration, tokenizer: tokenizer)
print("Starting generation ...") print("Starting generation ...")
print(self.prompt, terminator: "") print(prompt, terminator: "")
var start = Date.timeIntervalSinceReferenceDate var start = Date.timeIntervalSinceReferenceDate
var promptTime: TimeInterval = 0 var promptTime: TimeInterval = 0
@@ -59,7 +125,8 @@ struct SyncGenerator: AsyncParsableCommand {
         var tokens = [Int]()
         var printed = 0

-        for token in TokenIterator(prompt: MLXArray(promptTokens), model: model, temp: temperature)
+        for token in TokenIterator(
+            prompt: MLXArray(promptTokens), model: model, temp: args.temperature)
         {
             if tokens.isEmpty {
                 eval(token)
@@ -83,7 +150,7 @@ struct SyncGenerator: AsyncParsableCommand {
                 printed = fullOutput.count
             }

-            if tokens.count == maxTokens {
+            if tokens.count == args.maxTokens {
                 break
             }
         }
@@ -98,6 +165,8 @@ struct SyncGenerator: AsyncParsableCommand {
             Prompt Tokens per second: \((Double(promptTokens.count) / promptTime).formatted())
             Generation tokens per second: \((Double(tokens.count - 1) / generateTime).formatted())
             """)
+
+        args.reportMemoryStatistics()
     }
 }
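
The statistics printed above hinge on `GPU.snapshot()` being taken once before loading (in `load()`) and again after generation; `delta(_:)` then reports the growth between the two points. A condensed sketch of that pattern, using only the names that appear in this diff:

```swift
import MLX

let before = GPU.snapshot()        // counters before the work
// ... load the model and run generation ...
let after = GPU.snapshot()         // counters after the work

// Growth between the two snapshots, e.g. memory retained by the run.
print(before.delta(after).description)
```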
@@ -112,35 +181,19 @@ struct AsyncGenerator: AsyncParsableCommand {
abstract: "async generator" abstract: "async generator"
) )
@Option(name: .long, help: "Name of the huggingface model") @OptionGroup var args: LLMArguments
var model: String = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"
@Option(name: .shortAndLong, help: "The message to be processed by the model")
var prompt = "compare python and swift"
@Option(name: .shortAndLong, help: "Maximum number of tokens to generate")
var maxTokens = 100
@Option(name: .shortAndLong, help: "The sampling temperature")
var temperature: Float = 0.6
@Option(name: .long, help: "The PRNG seed")
var seed: UInt64 = 0
@MainActor @MainActor
func run() async throws { mutating func run() async throws {
MLXRandom.seed(seed) let (model, tokenizer, modelConfiguration) = try await args.load()
let modelConfiguration = ModelConfiguration.configuration(id: model) print("Model loaded -> \(modelConfiguration.id)")
let (model, tokenizer) = try await load(configuration: modelConfiguration)
print("Model loaded -> \(self.model)") let (prompt, promptTokens) = args.tokenizePropmpt(
configuration: modelConfiguration, tokenizer: tokenizer)
let prompt = modelConfiguration.prepare(prompt: self.prompt)
let promptTokens = tokenizer.encode(text: prompt)
print("Starting generation ...") print("Starting generation ...")
print(self.prompt, terminator: "") print(prompt, terminator: "")
var start = Date.timeIntervalSinceReferenceDate var start = Date.timeIntervalSinceReferenceDate
var promptTime: TimeInterval = 0 var promptTime: TimeInterval = 0
@@ -151,7 +204,7 @@ struct AsyncGenerator: AsyncParsableCommand {
         var printed = 0

         let (task, channel) = generate(
-            prompt: MLXArray(promptTokens), model: model, temp: temperature)
+            prompt: MLXArray(promptTokens), model: model, temp: args.temperature)

         for await token in channel {
             if tokens.isEmpty {
@@ -174,7 +227,7 @@ struct AsyncGenerator: AsyncParsableCommand {
                 printed = fullOutput.count
             }

-            if tokens.count == maxTokens {
+            if tokens.count == args.maxTokens {
                 break
             }
         }
@@ -193,6 +246,8 @@ struct AsyncGenerator: AsyncParsableCommand {
             Generation tokens per second: \((Double(tokens.count - 1) / generateTime).formatted())
             """)
+
+        args.reportMemoryStatistics()

         // wait for the task to complete -- since it is running async, it might
         // be in the middle of running the model
         try? await Task.sleep(for: .milliseconds(500))
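
Taken together, `LLMArguments` gives both the `sync` and async subcommands the same memory controls. Assuming swift-argument-parser's usual kebab-case naming for the new `memoryStats`, `cacheSize`, and `memorySize` properties, an invocation exercising them might look like the following -- illustrative only, check `llm-tool sync --help` for the actual spellings:

```
llm-tool sync --model mlx-community/Mistral-7B-v0.1-hf-4bit-mlx \
    --prompt "compare python and swift" --max-tokens 100 \
    --memory-stats --cache-size 1024 --memory-size 4096
```

Per the diff, `--cache-size` and `--memory-size` are interpreted in megabytes and applied via `GPU.set(cacheLimit:)` and `GPU.set(memoryLimit:)` before the model loads, and `--memory-stats` prints the snapshot report after generation.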