// Copyright © 2024 Apple Inc.

import LLM
import MLX
import MLXRandom
import MarkdownUI
import Metal
import SwiftUI
import Tokenizers

struct ContentView: View {

    @State var prompt = ""
    @State var llm = LLMEvaluator()
    @Environment(DeviceStat.self) private var deviceStat

    enum displayStyle: String, CaseIterable, Identifiable {
        case plain, markdown
        var id: Self { self }
    }

    @State private var selectedDisplayStyle = displayStyle.markdown

    var body: some View {
        VStack(alignment: .leading) {
            VStack {
                HStack {
                    Text(llm.modelInfo)
                        .textFieldStyle(.roundedBorder)
                    Spacer()
                    Text(llm.stat)
                }
                HStack {
                    Spacer()
                    if llm.running {
                        ProgressView()
                            .frame(maxHeight: 20)
                        Spacer()
                    }
                    Picker("", selection: $selectedDisplayStyle) {
                        ForEach(displayStyle.allCases, id: \.self) { option in
                            Text(option.rawValue.capitalized)
                                .tag(option)
                        }
                    }
                    .pickerStyle(.segmented)
                    #if os(visionOS)
                        .frame(maxWidth: 250)
                    #else
                        .frame(maxWidth: 150)
                    #endif
                }
            }

            // show the model output
            ScrollView(.vertical) {
                ScrollViewReader { sp in
                    Group {
                        if selectedDisplayStyle == .plain {
                            Text(llm.output)
                                .textSelection(.enabled)
                        } else {
                            Markdown(llm.output)
                                .textSelection(.enabled)
                        }
                    }
                    .onChange(of: llm.output) { _, _ in
                        sp.scrollTo("bottom")
                    }

                    Spacer()
                        .frame(width: 1, height: 1)
                        .id("bottom")
                }
            }

            HStack {
                TextField("prompt", text: $prompt)
                    .onSubmit(generate)
                    .disabled(llm.running)
                    #if os(visionOS)
                        .textFieldStyle(.roundedBorder)
                    #endif
                Button("generate", action: generate)
                    .disabled(llm.running)
            }
        }
        #if os(visionOS)
            .padding(40)
        #else
            .padding()
        #endif
        .toolbar {
            ToolbarItem {
                Label(
                    "Memory Usage: \(deviceStat.gpuUsage.activeMemory.formatted(.byteCount(style: .memory)))",
                    systemImage: "info.circle.fill"
                )
                .labelStyle(.titleAndIcon)
                .padding(.horizontal)
                .help(
                    Text(
                        """
                        Active Memory: \(deviceStat.gpuUsage.activeMemory.formatted(.byteCount(style: .memory)))/\(GPU.memoryLimit.formatted(.byteCount(style: .memory)))
                        Cache Memory: \(deviceStat.gpuUsage.cacheMemory.formatted(.byteCount(style: .memory)))/\(GPU.cacheLimit.formatted(.byteCount(style: .memory)))
                        Peak Memory: \(deviceStat.gpuUsage.peakMemory.formatted(.byteCount(style: .memory)))
                        """
                    )
                )
            }
            ToolbarItem(placement: .primaryAction) {
                Button {
                    Task {
                        copyToClipboard(llm.output)
                    }
                } label: {
                    Label("Copy Output", systemImage: "doc.on.doc.fill")
                }
                .disabled(llm.output == "")
                .labelStyle(.titleAndIcon)
            }
        }
        .task {
            self.prompt = llm.modelConfiguration.defaultPrompt

            // pre-load the weights on launch to speed up the first generation
            _ = try? await llm.load()
        }
    }

    private func generate() {
        Task {
            await llm.generate(prompt: prompt)
        }
    }

    private func copyToClipboard(_ string: String) {
        #if os(macOS)
            NSPasteboard.general.clearContents()
            NSPasteboard.general.setString(string, forType: .string)
        #else
            UIPasteboard.general.string = string
        #endif
    }
}

@Observable
class LLMEvaluator {

    @MainActor
    var running = false

    var output = ""
    var modelInfo = ""
    var stat = ""

    /// this controls which model loads -- phi3_4bit is one of the smaller ones so this will fit on
    /// more devices
    let modelConfiguration = ModelConfiguration.phi3_4bit

    /// parameters controlling the output
    let generateParameters = GenerateParameters(temperature: 0.6)
    let maxTokens = 240

    /// update the display every N tokens -- 4 looks like it updates continuously
    /// and is low overhead. observed ~15% reduction in tokens/s when updating
    /// on every token
    let displayEveryNTokens = 4

    enum LoadState {
        case idle
        case loaded(LLMModel, Tokenizers.Tokenizer)
    }

    var loadState = LoadState.idle

    /// load and return the model -- can be called multiple times, subsequent calls will
    /// just return the loaded model
    func load() async throws -> (LLMModel, Tokenizers.Tokenizer) {
        switch loadState {
        case .idle:
            // limit the buffer cache
            MLX.GPU.set(cacheLimit: 20 * 1024 * 1024)

            let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration) {
                [modelConfiguration] progress in
                DispatchQueue.main.sync {
                    self.modelInfo =
                        "Downloading \(modelConfiguration.name): \(Int(progress.fractionCompleted * 100))%"
                }
            }
            self.modelInfo =
                "Loaded \(modelConfiguration.id). Weights: \(MLX.GPU.activeMemory / 1024 / 1024)M"
            loadState = .loaded(model, tokenizer)
            return (model, tokenizer)

        case .loaded(let model, let tokenizer):
            return (model, tokenizer)
        }
    }

    func generate(prompt: String) async {
        let canGenerate = await MainActor.run {
            if running {
                return false
            } else {
                running = true
                self.output = ""
                return true
            }
        }

        guard canGenerate else { return }

        do {
            let (model, tokenizer) = try await load()

            // augment the prompt as needed
            let prompt = modelConfiguration.prepare(prompt: prompt)
            let promptTokens = tokenizer.encode(text: prompt)

            // each time you generate you will get something new
            MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))

            let result = await LLM.generate(
                promptTokens: promptTokens, parameters: generateParameters, model: model,
                tokenizer: tokenizer, extraEOSTokens: modelConfiguration.extraEOSTokens
            ) { tokens in
                // update the output -- this will make the view show the text as it generates
                if tokens.count % displayEveryNTokens == 0 {
                    let text = tokenizer.decode(tokens: tokens)
                    await MainActor.run {
                        self.output = text
                    }
                }

                if tokens.count >= maxTokens {
                    return .stop
                } else {
                    return .more
                }
            }

            // update the text if needed, e.g. we haven't displayed because of displayEveryNTokens
            await MainActor.run {
                if result.output != self.output {
                    self.output = result.output
                }
                running = false
                self.stat = " Tokens/second: \(String(format: "%.3f", result.tokensPerSecond))"
            }

        } catch {
            await MainActor.run {
                running = false
                output = "Failed: \(error)"
            }
        }
    }
}
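
// MARK: - Usage sketch (not part of the app)
//
// A minimal, hedged sketch of driving `LLMEvaluator` outside the SwiftUI view above,
// e.g. from a test or a small command line target. It relies only on the APIs declared
// in this file (`load()`, `generate(prompt:)`, `output`, `stat`); the function name and
// prompt text are illustrative assumptions, not part of the example app.
//
//     func runEvaluatorExample() async throws {
//         let evaluator = LLMEvaluator()
//
//         // download (if needed) and load the weights; later calls reuse the loaded model
//         _ = try await evaluator.load()
//
//         // generate up to `maxTokens` tokens; `output` is updated as tokens stream in
//         await evaluator.generate(prompt: "Explain what MLX is in one sentence.")
//         print(evaluator.output)
//         print(evaluator.stat)
//     }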