LLMEval UI Improvements (#27)

* Feat: LLMEval UI Improvements

1. adds Markdown rendering in the UI
2. Adds init time and token/second stat
3. Minor UI enhancements

* feat: adds a copy to clipboard button for llm outputs

* adds scrollviewreader to sync with main

* ran pre-format to resolve formatting issues

* updates the missing dependency in project definition

* feat: switch between plain text and markdown

adds a segemented picker to switch between plain text and markdown
This commit is contained in:
Ashraful Islam
2024-03-18 11:15:50 -05:00
committed by GitHub
parent 6c270a9d12
commit a7b2b54f18
3 changed files with 123 additions and 12 deletions

View File

@@ -3,6 +3,7 @@
import LLM import LLM
import MLX import MLX
import MLXRandom import MLXRandom
import MarkdownUI
import Metal import Metal
import SwiftUI import SwiftUI
import Tokenizers import Tokenizers
@@ -12,17 +13,54 @@ struct ContentView: View {
@State var prompt = "compare python and swift" @State var prompt = "compare python and swift"
@State var llm = LLMEvaluator() @State var llm = LLMEvaluator()
enum displayStyle: String, CaseIterable, Identifiable {
case plain, markdown
var id: Self { self }
}
@State private var selectedDisplayStyle = displayStyle.markdown
var body: some View { var body: some View {
VStack(alignment: .leading) {
VStack { VStack {
HStack {
Text(llm.modelInfo)
.textFieldStyle(.roundedBorder)
Spacer()
Text(llm.stat)
}
HStack {
Spacer()
if llm.running {
ProgressView()
Spacer()
}
Picker("", selection: $selectedDisplayStyle) {
ForEach(displayStyle.allCases, id: \.self) { option in
Text(option.rawValue.capitalized)
.tag(option)
}
}
.pickerStyle(.segmented)
.frame(maxWidth: 150)
}
}
// show the model output // show the model output
ScrollView(.vertical) { ScrollView(.vertical) {
ScrollViewReader { sp in ScrollViewReader { sp in
if llm.running { Group {
ProgressView() if selectedDisplayStyle == .plain {
}
Text(llm.output) Text(llm.output)
.textSelection(.enabled) .textSelection(.enabled)
} else {
Markdown(llm.output)
.textSelection(.enabled)
}
}
.onChange(of: llm.output) { _, _ in .onChange(of: llm.output) { _, _ in
sp.scrollTo("bottom") sp.scrollTo("bottom")
} }
@@ -42,6 +80,20 @@ struct ContentView: View {
} }
} }
.padding() .padding()
.toolbar {
ToolbarItem(placement: .primaryAction) {
Button {
Task {
copyToClipboard(llm.output)
}
} label: {
Label("Copy Output", systemImage: "doc.on.doc.fill")
}
.disabled(llm.output == "")
.labelStyle(.titleAndIcon)
}
}
.task { .task {
// pre-load the weights on launch to speed up the first generation // pre-load the weights on launch to speed up the first generation
_ = try? await llm.load() _ = try? await llm.load()
@@ -53,6 +105,14 @@ struct ContentView: View {
await llm.generate(prompt: prompt) await llm.generate(prompt: prompt)
} }
} }
private func copyToClipboard(_ string: String) {
#if os(macOS)
NSPasteboard.general.clearContents()
NSPasteboard.general.setString(string, forType: .string)
#else
UIPasteboard.general.string = string
#endif
}
} }
@Observable @Observable
@@ -62,6 +122,8 @@ class LLMEvaluator {
var running = false var running = false
var output = "" var output = ""
var modelInfo = ""
var stat = ""
/// this controls which model loads -- phi4bit is one of the smaller ones so this will fit on /// this controls which model loads -- phi4bit is one of the smaller ones so this will fit on
/// more devices /// more devices
@@ -89,11 +151,11 @@ class LLMEvaluator {
let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration) { let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration) {
[modelConfiguration] progress in [modelConfiguration] progress in
DispatchQueue.main.sync { DispatchQueue.main.sync {
self.output = self.modelInfo =
"Downloading \(modelConfiguration.id): \(Int(progress.fractionCompleted * 100))%" "Downloading \(modelConfiguration.id): \(Int(progress.fractionCompleted * 100))%"
} }
} }
self.output = self.modelInfo =
"Loaded \(modelConfiguration.id). Weights: \(MLX.GPU.activeMemory / 1024 / 1024)M" "Loaded \(modelConfiguration.id). Weights: \(MLX.GPU.activeMemory / 1024 / 1024)M"
loadState = .loaded(model, tokenizer) loadState = .loaded(model, tokenizer)
return (model, tokenizer) return (model, tokenizer)
@@ -104,6 +166,7 @@ class LLMEvaluator {
} }
func generate(prompt: String) async { func generate(prompt: String) async {
let startTime = Date()
do { do {
let (model, tokenizer) = try await load() let (model, tokenizer) = try await load()
@@ -116,6 +179,12 @@ class LLMEvaluator {
let prompt = modelConfiguration.prepare(prompt: prompt) let prompt = modelConfiguration.prepare(prompt: prompt)
let promptTokens = MLXArray(tokenizer.encode(text: prompt)) let promptTokens = MLXArray(tokenizer.encode(text: prompt))
let initTime = Date()
let initDuration = initTime.timeIntervalSince(startTime)
await MainActor.run {
self.stat = "Init: \(String(format: "%.3f", initDuration))s"
}
// each time you generate you will get something new // each time you generate you will get something new
MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000)) MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))
@@ -141,8 +210,12 @@ class LLMEvaluator {
} }
} }
let tokenDuration = Date().timeIntervalSince(initTime)
let tokensPerSecond = Double(outputTokens.count) / tokenDuration
await MainActor.run { await MainActor.run {
running = false running = false
self.stat += " Token/second: \(String(format: "%.3f", tokensPerSecond))"
} }
} catch { } catch {

View File

@@ -9,6 +9,7 @@
/* Begin PBXBuildFile section */ /* Begin PBXBuildFile section */
525C1E9D2B9A011000B5C356 /* Starcoder2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */; }; 525C1E9D2B9A011000B5C356 /* Starcoder2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */; };
52A776182B94B5EE00AA6E80 /* Qwen2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 52A776172B94B5EE00AA6E80 /* Qwen2.swift */; }; 52A776182B94B5EE00AA6E80 /* Qwen2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 52A776172B94B5EE00AA6E80 /* Qwen2.swift */; };
81695B412BA373D300F260D8 /* MarkdownUI in Frameworks */ = {isa = PBXBuildFile; productRef = 81695B402BA373D300F260D8 /* MarkdownUI */; };
C3288D762B6D9313009FF608 /* LinearModelTraining.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3288D752B6D9313009FF608 /* LinearModelTraining.swift */; }; C3288D762B6D9313009FF608 /* LinearModelTraining.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3288D752B6D9313009FF608 /* LinearModelTraining.swift */; };
C3288D7B2B6D9339009FF608 /* ArgumentParser in Frameworks */ = {isa = PBXBuildFile; productRef = C3288D7A2B6D9339009FF608 /* ArgumentParser */; }; C3288D7B2B6D9339009FF608 /* ArgumentParser in Frameworks */ = {isa = PBXBuildFile; productRef = C3288D7A2B6D9339009FF608 /* ArgumentParser */; };
C34E48F52B696F0B00FCB841 /* LLMTool.swift in Sources */ = {isa = PBXBuildFile; fileRef = C34E48F42B696F0B00FCB841 /* LLMTool.swift */; }; C34E48F52B696F0B00FCB841 /* LLMTool.swift in Sources */ = {isa = PBXBuildFile; fileRef = C34E48F42B696F0B00FCB841 /* LLMTool.swift */; };
@@ -308,6 +309,7 @@
buildActionMask = 2147483647; buildActionMask = 2147483647;
files = ( files = (
C3A8B3F82B92A3360002EFB8 /* LLM.framework in Frameworks */, C3A8B3F82B92A3360002EFB8 /* LLM.framework in Frameworks */,
81695B412BA373D300F260D8 /* MarkdownUI in Frameworks */,
); );
runOnlyForDeploymentPostprocessing = 0; runOnlyForDeploymentPostprocessing = 0;
}; };
@@ -684,6 +686,9 @@
C3A8B3FB2B92A3360002EFB8 /* PBXTargetDependency */, C3A8B3FB2B92A3360002EFB8 /* PBXTargetDependency */,
); );
name = LLMEval; name = LLMEval;
packageProductDependencies = (
81695B402BA373D300F260D8 /* MarkdownUI */,
);
productName = LLMEval; productName = LLMEval;
productReference = C3A8B3DC2B92A29E0002EFB8 /* LLMEval.app */; productReference = C3A8B3DC2B92A29E0002EFB8 /* LLMEval.app */;
productType = "com.apple.product-type.application"; productType = "com.apple.product-type.application";
@@ -740,6 +745,7 @@
C34E491A2B69C43600FCB841 /* XCRemoteSwiftPackageReference "GzipSwift" */, C34E491A2B69C43600FCB841 /* XCRemoteSwiftPackageReference "GzipSwift" */,
C3FBCB1F2B8520B00007E490 /* XCRemoteSwiftPackageReference "mlx-swift" */, C3FBCB1F2B8520B00007E490 /* XCRemoteSwiftPackageReference "mlx-swift" */,
C38935BB2B866BFA0037B833 /* XCRemoteSwiftPackageReference "swift-transformers" */, C38935BB2B866BFA0037B833 /* XCRemoteSwiftPackageReference "swift-transformers" */,
81695B3F2BA373D300F260D8 /* XCRemoteSwiftPackageReference "swift-markdown-ui" */,
); );
productRefGroup = C39273752B606A0A00368D5D /* Products */; productRefGroup = C39273752B606A0A00368D5D /* Products */;
projectDirPath = ""; projectDirPath = "";
@@ -2200,6 +2206,14 @@
/* End XCConfigurationList section */ /* End XCConfigurationList section */
/* Begin XCRemoteSwiftPackageReference section */ /* Begin XCRemoteSwiftPackageReference section */
81695B3F2BA373D300F260D8 /* XCRemoteSwiftPackageReference "swift-markdown-ui" */ = {
isa = XCRemoteSwiftPackageReference;
repositoryURL = "https://github.com/gonzalezreal/swift-markdown-ui";
requirement = {
branch = main;
kind = branch;
};
};
C34E491A2B69C43600FCB841 /* XCRemoteSwiftPackageReference "GzipSwift" */ = { C34E491A2B69C43600FCB841 /* XCRemoteSwiftPackageReference "GzipSwift" */ = {
isa = XCRemoteSwiftPackageReference; isa = XCRemoteSwiftPackageReference;
repositoryURL = "https://github.com/1024jp/GzipSwift"; repositoryURL = "https://github.com/1024jp/GzipSwift";
@@ -2243,6 +2257,11 @@
/* End XCRemoteSwiftPackageReference section */ /* End XCRemoteSwiftPackageReference section */
/* Begin XCSwiftPackageProductDependency section */ /* Begin XCSwiftPackageProductDependency section */
81695B402BA373D300F260D8 /* MarkdownUI */ = {
isa = XCSwiftPackageProductDependency;
package = 81695B3F2BA373D300F260D8 /* XCRemoteSwiftPackageReference "swift-markdown-ui" */;
productName = MarkdownUI;
};
C3288D7A2B6D9339009FF608 /* ArgumentParser */ = { C3288D7A2B6D9339009FF608 /* ArgumentParser */ = {
isa = XCSwiftPackageProductDependency; isa = XCSwiftPackageProductDependency;
package = C392736E2B60699100368D5D /* XCRemoteSwiftPackageReference "swift-argument-parser" */; package = C392736E2B60699100368D5D /* XCRemoteSwiftPackageReference "swift-argument-parser" */;

View File

@@ -1,4 +1,5 @@
{ {
"originHash" : "73a943caf561dd1482e57053ba01456f25dea3267b2a84a996e42284e17aa6fc",
"pins" : [ "pins" : [
{ {
"identity" : "gzipswift", "identity" : "gzipswift",
@@ -15,7 +16,16 @@
"location" : "https://github.com/ml-explore/mlx-swift", "location" : "https://github.com/ml-explore/mlx-swift",
"state" : { "state" : {
"branch" : "main", "branch" : "main",
"revision" : "948000ceaa27c343f4dd5ce40f727f221bf45c6e" "revision" : "24e71937e12efe01a0d28a429a703036fae2ff8a"
}
},
{
"identity" : "networkimage",
"kind" : "remoteSourceControl",
"location" : "https://github.com/gonzalezreal/NetworkImage",
"state" : {
"revision" : "7aff8d1b31148d32c5933d75557d42f6323ee3d1",
"version" : "6.0.0"
} }
}, },
{ {
@@ -45,6 +55,15 @@
"version" : "1.1.0" "version" : "1.1.0"
} }
}, },
{
"identity" : "swift-markdown-ui",
"kind" : "remoteSourceControl",
"location" : "https://github.com/gonzalezreal/swift-markdown-ui",
"state" : {
"branch" : "main",
"revision" : "ae799d015a5374708f7b4c85f3294c05f2a564e2"
}
},
{ {
"identity" : "swift-numerics", "identity" : "swift-numerics",
"kind" : "remoteSourceControl", "kind" : "remoteSourceControl",