From c7dfd4725d8e0832d7fbe8868fd3002402a3dafe Mon Sep 17 00:00:00 2001 From: haritabh-z01 Date: Thu, 4 Dec 2025 00:23:22 +0530 Subject: [PATCH] chore: update model download recommendation --- .../components/screens/ModelSetupModal.tsx | 14 +++- .../src/services/onboarding-service.ts | 74 +++++++++++++++++++ .../src/services/transcription-service.ts | 10 +++ apps/desktop/src/trpc/routers/onboarding.ts | 15 ++++ 4 files changed, 109 insertions(+), 4 deletions(-) diff --git a/apps/desktop/src/renderer/onboarding/components/screens/ModelSetupModal.tsx b/apps/desktop/src/renderer/onboarding/components/screens/ModelSetupModal.tsx index 21dfe2e..48ae369 100644 --- a/apps/desktop/src/renderer/onboarding/components/screens/ModelSetupModal.tsx +++ b/apps/desktop/src/renderer/onboarding/components/screens/ModelSetupModal.tsx @@ -44,6 +44,12 @@ export function ModelSetupModal({ const [installedModelName, setInstalledModelName] = useState(""); const [downloadComplete, setDownloadComplete] = useState(false); + // Get recommended local model based on hardware + const { data: recommendedModelId = "whisper-base" } = + api.onboarding.getRecommendedLocalModel.useQuery(undefined, { + enabled: modelType === ModelType.Local && isOpen, + }); + // tRPC mutations const loginMutation = api.auth.login.useMutation({ onSuccess: () => { @@ -84,7 +90,7 @@ export function ModelSetupModal({ // Subscribe to download progress api.models.onDownloadProgress.useSubscription(undefined, { onData: (data) => { - if (data.modelId === "whisper-tiny") { + if (data.modelId === recommendedModelId) { setDownloadProgress(data.progress.progress); setDownloadInfo({ downloaded: data.progress.bytesDownloaded || 0, @@ -116,7 +122,7 @@ export function ModelSetupModal({ try { await downloadModelMutation.mutateAsync({ - modelId: "whisper-tiny", + modelId: recommendedModelId, }); // Progress will be handled by subscription } catch (err) { @@ -188,7 +194,7 @@ export function ModelSetupModal({ {modelAlreadyInstalled || downloadComplete ? "Ready for private, offline transcription." - : "Setting up Whisper Tiny for private, offline transcription"} + : "Setting up local model for private, offline transcription"} @@ -206,7 +212,7 @@ export function ModelSetupModal({ : "Download Complete"}

- Using: {installedModelName || "whisper-tiny"} + Using: {installedModelName || recommendedModelId}

diff --git a/apps/desktop/src/services/onboarding-service.ts b/apps/desktop/src/services/onboarding-service.ts index 3986fe6..e4e0f86 100644 --- a/apps/desktop/src/services/onboarding-service.ts +++ b/apps/desktop/src/services/onboarding-service.ts @@ -431,6 +431,80 @@ export class OnboardingService extends EventEmitter { ); } + /** + * Check for high-end hardware (RTX 50 series or M3 Pro/Max/M4) + */ + private hasHighEndHardware(gpuModel: string, cpuModel: string): boolean { + const upperGpu = gpuModel.toUpperCase(); + const upperCpu = cpuModel.toUpperCase(); + + // RTX 50 series + const hasRtx50 = ["RTX 5060", "RTX 5070", "RTX 5080", "RTX 5090"].some( + (m) => upperGpu.includes(m), + ); + + // M3 Pro/Max or M4+ + const hasM3ProMax = + upperCpu.includes("M3 PRO") || upperCpu.includes("M3 MAX"); + const hasM4Plus = ["M4", "M5", "M6"].some((chip) => + upperCpu.includes(`APPLE ${chip}`), + ); + + return hasRtx50 || hasM3ProMax || hasM4Plus; + } + + /** + * Check for NVIDIA RTX 20 series + */ + private hasNvidia20Series(gpuModel: string): boolean { + if (!gpuModel) return false; + const upperGpu = gpuModel.toUpperCase(); + return ["RTX 2060", "RTX 2070", "RTX 2080"].some((m) => + upperGpu.includes(m), + ); + } + + /** + * Check for Apple Silicon M1 + */ + private hasAppleSiliconM1(cpuModel: string): boolean { + if (process.platform !== "darwin" || process.arch !== "arm64") return false; + return cpuModel.toUpperCase().includes("APPLE M1"); + } + + /** + * Get recommended local model ID based on hardware + * - High-end (RTX 50, M3 Pro/Max, M4+) → whisper-large-v3-turbo + * - Mid-tier (RTX 30/40, M2/M3 base) → whisper-medium + * - Entry (RTX 20, M1) → whisper-small + * - Default → whisper-base + */ + getRecommendedLocalModelId(): string { + const systemInfo = this.telemetryService.getSystemInfo(); + const gpuModel = systemInfo?.gpu_model || ""; + const cpuModel = systemInfo?.cpu_model || ""; + + // High-end: RTX 50 series or M3 Pro/Max/M4+ + if (this.hasHighEndHardware(gpuModel, cpuModel)) { + return "whisper-large-v3-turbo"; + } + + // Mid-tier: RTX 30/40 series or M2/M3 base + if ( + this.hasNvidia30SeriesOrBetter(gpuModel) || + this.hasAppleSiliconM2OrBetter(cpuModel) + ) { + return "whisper-medium"; + } + + // Entry: RTX 20 series or M1 + if (this.hasNvidia20Series(gpuModel) || this.hasAppleSiliconM1(cpuModel)) { + return "whisper-small"; + } + + return "whisper-base"; + } + /** * Calculate model recommendation based on system specs */ diff --git a/apps/desktop/src/services/transcription-service.ts b/apps/desktop/src/services/transcription-service.ts index e59633b..bca80cf 100644 --- a/apps/desktop/src/services/transcription-service.ts +++ b/apps/desktop/src/services/transcription-service.ts @@ -157,6 +157,16 @@ export class TranscriptionService { */ public async isModelAvailable(): Promise { try { + // Check if selected model is a cloud model (doesn't need download) + const selectedModelId = await this.modelService.getSelectedModel(); + if (selectedModelId) { + const model = AVAILABLE_MODELS.find((m) => m.id === selectedModelId); + if (model?.provider === "Amical Cloud") { + return true; + } + } + + // For local models, check if any are downloaded const modelService = this.whisperProvider["modelService"]; const availableModels = await modelService.getValidDownloadedModels(); return Object.keys(availableModels).length > 0; diff --git a/apps/desktop/src/trpc/routers/onboarding.ts b/apps/desktop/src/trpc/routers/onboarding.ts index 0aa39ba..9c1460b 100644 --- a/apps/desktop/src/trpc/routers/onboarding.ts +++ b/apps/desktop/src/trpc/routers/onboarding.ts @@ -65,6 +65,21 @@ export const onboardingRouter = createRouter({ }, ), + /** + * Get recommended local model ID based on hardware + */ + getRecommendedLocalModel: procedure.query(({ ctx }): string => { + const { serviceManager } = ctx; + if (!serviceManager) { + return "whisper-base"; + } + const onboardingService = serviceManager.getOnboardingService(); + if (!onboardingService) { + return "whisper-base"; + } + return onboardingService.getRecommendedLocalModelId(); + }), + /** * Check if onboarding is needed */