swift-format, circleci setup

This commit is contained in:
David Koski
2024-03-01 16:10:34 -08:00
parent b41f14fba7
commit c49dd73c28
3 changed files with 10 additions and 5 deletions

View File

@@ -38,6 +38,7 @@ jobs:
xcodebuild -skipPackagePluginValidation -scheme llm-tool xcodebuild -skipPackagePluginValidation -scheme llm-tool
xcodebuild -skipPackagePluginValidation -scheme mnist-tool xcodebuild -skipPackagePluginValidation -scheme mnist-tool
xcodebuild -skipPackagePluginValidation -scheme MNISTTrainer xcodebuild -skipPackagePluginValidation -scheme MNISTTrainer
xcodebuild -skipPackagePluginValidation -scheme LLMEval
workflows: workflows:
build_and_test: build_and_test:

View File

@@ -2,9 +2,9 @@
import LLM import LLM
import MLX import MLX
import Metal
import SwiftUI import SwiftUI
import Tokenizers import Tokenizers
import Metal
struct ContentView: View { struct ContentView: View {
@@ -62,9 +62,11 @@ class LLMEvaluator {
func load() async throws -> (LLMModel, LLM.Tokenizer) { func load() async throws -> (LLMModel, LLM.Tokenizer) {
switch loadState { switch loadState {
case .idle: case .idle:
let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration) { [modelConfiguration] progress in let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration) {
[modelConfiguration] progress in
DispatchQueue.main.sync { DispatchQueue.main.sync {
self.output = "Downloading \(modelConfiguration.id): \(Int(progress.fractionCompleted * 100))%" self.output =
"Downloading \(modelConfiguration.id): \(Int(progress.fractionCompleted * 100))%"
} }
} }
loadState = .loaded(model, tokenizer) loadState = .loaded(model, tokenizer)

View File

@@ -54,7 +54,8 @@ public struct ModelConfiguration {
extension ModelConfiguration { extension ModelConfiguration {
public static let mistral7B4bit = ModelConfiguration(id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx") public static let mistral7B4bit = ModelConfiguration(
id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx")
public static let codeLlama13b4bit = ModelConfiguration( public static let codeLlama13b4bit = ModelConfiguration(
id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX", id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX",
@@ -67,7 +68,8 @@ extension ModelConfiguration {
"<PRE> " + prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") + " <MID>" "<PRE> " + prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") + " <MID>"
} }
public static let phi4bit = ModelConfiguration(id: "mlx-community/phi-2-hf-4bit-mlx") { prompt in public static let phi4bit = ModelConfiguration(id: "mlx-community/phi-2-hf-4bit-mlx") {
prompt in
"Instruct: \(prompt). Output: " "Instruct: \(prompt). Output: "
} }