From c0cbbcdda98d3290fd8d0e89f7af11da337a04fd Mon Sep 17 00:00:00 2001 From: haritabh-z01 Date: Sun, 6 Jul 2025 16:02:51 +0530 Subject: [PATCH] chore: preload whisper models --- apps/desktop/src/db/schema.ts | 1 + apps/desktop/src/hooks/useRecordingState.ts | 2 +- .../src/main/managers/service-manager.ts | 1 + .../transcription/whisper-provider.ts | 15 +++-- .../settings/components/AdvancedSettings.tsx | 52 ++++++++++++++- .../src/services/transcription-service.ts | 65 ++++++++++++++++++- apps/desktop/src/services/vad-service.ts | 2 +- apps/desktop/src/trpc/routers/models.ts | 12 +++- apps/desktop/src/trpc/routers/settings.ts | 60 +++++++++++++++++ 9 files changed, 198 insertions(+), 12 deletions(-) diff --git a/apps/desktop/src/db/schema.ts b/apps/desktop/src/db/schema.ts index c2e996b..27a4d40 100644 --- a/apps/desktop/src/db/schema.ts +++ b/apps/desktop/src/db/schema.ts @@ -96,6 +96,7 @@ export interface AppSettingsData { confidenceThreshold: number; enablePunctuation: boolean; enableTimestamps: boolean; + preloadWhisperModel?: boolean; }; recording?: { defaultFormat: "wav" | "mp3" | "flac"; diff --git a/apps/desktop/src/hooks/useRecordingState.ts b/apps/desktop/src/hooks/useRecordingState.ts index bbddca7..ff386a3 100644 --- a/apps/desktop/src/hooks/useRecordingState.ts +++ b/apps/desktop/src/hooks/useRecordingState.ts @@ -1,4 +1,4 @@ -import { useState, useEffect } from "react"; +import { useState } from "react"; import { api } from "@/trpc/react"; import type { RecordingState } from "@/types/recording"; diff --git a/apps/desktop/src/main/managers/service-manager.ts b/apps/desktop/src/main/managers/service-manager.ts index cb4b693..9305c7f 100644 --- a/apps/desktop/src/main/managers/service-manager.ts +++ b/apps/desktop/src/main/managers/service-manager.ts @@ -87,6 +87,7 @@ export class ServiceManager { this.transcriptionService = new TranscriptionService( this.modelManagerService, this.vadService, + this.settingsService, ); await this.transcriptionService.initialize(); diff --git a/apps/desktop/src/pipeline/providers/transcription/whisper-provider.ts b/apps/desktop/src/pipeline/providers/transcription/whisper-provider.ts index a03cee3..711c1c3 100644 --- a/apps/desktop/src/pipeline/providers/transcription/whisper-provider.ts +++ b/apps/desktop/src/pipeline/providers/transcription/whisper-provider.ts @@ -21,7 +21,7 @@ export class WhisperProvider implements TranscriptionProvider { // Configuration private readonly FRAME_SIZE = 512; // 32ms at 16kHz private readonly MIN_SPEECH_DURATION_MS = 500; // Minimum speech duration to transcribe - private readonly MAX_SILENCE_DURATION_MS = 2000; // Max silence before cutting + private readonly MAX_SILENCE_DURATION_MS = 800; // Max silence before cutting private readonly SAMPLE_RATE = 16000; private readonly SPEECH_PROBABILITY_THRESHOLD = 0.2; // Threshold for speech detection @@ -29,6 +29,13 @@ export class WhisperProvider implements TranscriptionProvider { this.modelManager = modelManager; } + /** + * Preload the Whisper model into memory + */ + async preloadModel(): Promise { + await this.initializeWhisper(); + } + async transcribe(params: TranscribeParams): Promise { try { await this.initializeWhisper(); @@ -71,13 +78,13 @@ export class WhisperProvider implements TranscriptionProvider { const aggregatedAudio = this.aggregateFrames(); // Skip if too short or only silence - if (aggregatedAudio.length < this.FRAME_SIZE * 2) { + /* if (aggregatedAudio.length < this.FRAME_SIZE * 2) { logger.transcription.debug("Skipping transcription - audio too short"); this.frameBuffer = []; this.frameBufferSpeechProbabilities = []; this.silenceFrameCount = 0; return ""; - } + } */ logger.transcription.debug( `Starting transcription of ${aggregatedAudio.length} samples (${((aggregatedAudio.length / this.SAMPLE_RATE) * 1000).toFixed(0)}ms)`, @@ -266,7 +273,7 @@ export class WhisperProvider implements TranscriptionProvider { return prompt; } - private async initializeWhisper(): Promise { + async initializeWhisper(): Promise { if (this.whisperInstance) { return; // Already initialized } diff --git a/apps/desktop/src/renderer/main/pages/settings/components/AdvancedSettings.tsx b/apps/desktop/src/renderer/main/pages/settings/components/AdvancedSettings.tsx index c7f2a6e..37a7b3c 100644 --- a/apps/desktop/src/renderer/main/pages/settings/components/AdvancedSettings.tsx +++ b/apps/desktop/src/renderer/main/pages/settings/components/AdvancedSettings.tsx @@ -1,4 +1,4 @@ -import React from "react"; +import React, { useState, useEffect } from "react"; import { Card, CardHeader, @@ -9,8 +9,44 @@ import { import { Button } from "@/components/ui/button"; import { Label } from "@/components/ui/label"; import { Switch } from "@/components/ui/switch"; +import { api } from "@/trpc/react"; +import { toast } from "sonner"; export function AdvancedSettings() { + const [preloadWhisperModel, setPreloadWhisperModel] = useState(true); + + // tRPC queries and mutations + const settingsQuery = api.settings.getSettings.useQuery(); + const utils = api.useUtils(); + + const updateTranscriptionSettingsMutation = + api.settings.updateTranscriptionSettings.useMutation({ + onSuccess: () => { + utils.settings.getSettings.invalidate(); + toast.success("Settings updated"); + }, + onError: (error) => { + console.error("Failed to update transcription settings:", error); + toast.error("Failed to update settings. Please try again."); + }, + }); + + // Load settings when query data is available + useEffect(() => { + if (settingsQuery.data?.transcription) { + setPreloadWhisperModel( + settingsQuery.data.transcription.preloadWhisperModel !== false, + ); + } + }, [settingsQuery.data]); + + const handlePreloadWhisperModelChange = (checked: boolean) => { + setPreloadWhisperModel(checked); + updateTranscriptionSettingsMutation.mutate({ + preloadWhisperModel: checked, + }); + }; + return ( @@ -18,6 +54,20 @@ export function AdvancedSettings() { Advanced configuration options +
+
+ +

+ Load AI model at startup for faster transcription +

+
+ +
+
diff --git a/apps/desktop/src/services/transcription-service.ts b/apps/desktop/src/services/transcription-service.ts index 62490be..5f8853a 100644 --- a/apps/desktop/src/services/transcription-service.ts +++ b/apps/desktop/src/services/transcription-service.ts @@ -7,6 +7,7 @@ import { createDefaultContext } from "../pipeline/core/context"; import { WhisperProvider } from "../pipeline/providers/transcription/whisper-provider"; import { OpenRouterProvider } from "../pipeline/providers/formatting/openrouter-formatter"; import { ModelManagerService } from "../services/model-manager"; +import { SettingsService } from "../services/settings-service"; import { appContextStore } from "../stores/app-context"; import { createTranscription } from "../db/transcriptions"; import { logger } from "../main/logger"; @@ -31,14 +32,17 @@ export class TranscriptionService { private openRouterProvider: OpenRouterProvider | null = null; private formatterEnabled = false; private streamingSessions: Map = new Map(); - private vadService: VADService | null = null; + private vadService: VADService; + private settingsService: SettingsService; constructor( modelManagerService: ModelManagerService, - vadService: VADService | null = null, + vadService: VADService, + settingsService: SettingsService, ) { this.whisperProvider = new WhisperProvider(modelManagerService); this.vadService = vadService; + this.settingsService = settingsService; } async initialize(): Promise { @@ -47,9 +51,66 @@ export class TranscriptionService { } else { logger.transcription.warn("VAD service not available"); } + + // Check if we should preload Whisper model + const transcriptionSettings = + await this.settingsService.getTranscriptionSettings(); + const shouldPreload = transcriptionSettings?.preloadWhisperModel !== false; // Default to true + + if (shouldPreload) { + logger.transcription.info("Preloading Whisper model..."); + await this.preloadWhisperModel(); + logger.transcription.info("Whisper model preloaded successfully"); + } else { + logger.transcription.info("Whisper model preloading disabled"); + } + logger.transcription.info("Transcription service initialized"); } + /** + * Preload Whisper model into memory + */ + async preloadWhisperModel(): Promise { + try { + // This will trigger the model initialization in WhisperProvider + await this.whisperProvider.preloadModel(); + logger.transcription.info("Whisper model preloaded successfully"); + } catch (error) { + logger.transcription.error("Failed to preload Whisper model:", error); + throw error; + } + } + + /** + * Handle model change - dispose old model and load new one if preloading is enabled + */ + async handleModelChange(): Promise { + try { + // Dispose current model + await this.whisperProvider.dispose(); + + // Check if preloading is enabled + if (this.settingsService) { + const transcriptionSettings = + await this.settingsService.getTranscriptionSettings(); + const shouldPreload = + transcriptionSettings?.preloadWhisperModel !== false; + + if (shouldPreload) { + logger.transcription.info( + "Reloading Whisper model after model change...", + ); + await this.whisperProvider.preloadModel(); + logger.transcription.info("Whisper model reloaded successfully"); + } + } + } catch (error) { + logger.transcription.error("Failed to handle model change:", error); + // Don't throw - model will be loaded on first use + } + } + /** * Configure formatter for post-processing */ diff --git a/apps/desktop/src/services/vad-service.ts b/apps/desktop/src/services/vad-service.ts index 59c7d70..f6d5ff2 100644 --- a/apps/desktop/src/services/vad-service.ts +++ b/apps/desktop/src/services/vad-service.ts @@ -55,7 +55,7 @@ export class VADService extends EventEmitter { // Load ONNX model this.session = await ort.InferenceSession.create(this.modelPath, { - executionProviders: ["cpu"], // Use CPU provider for compatibility + executionProviders: ["coreml", "cpu"], }); // Initialize hidden states (h and c) diff --git a/apps/desktop/src/trpc/routers/models.ts b/apps/desktop/src/trpc/routers/models.ts index 17a3ec6..6a0ad6b 100644 --- a/apps/desktop/src/trpc/routers/models.ts +++ b/apps/desktop/src/trpc/routers/models.ts @@ -19,6 +19,7 @@ const t = initTRPC.create({ declare global { var modelManagerService: any; + var transcriptionService: any; } export const modelsRouter = t.router({ @@ -120,9 +121,14 @@ export const modelsRouter = t.router({ if (!globalThis.modelManagerService) { throw new Error("Model manager service not initialized"); } - return await globalThis.modelManagerService.setSelectedModel( - input.modelId, - ); + await globalThis.modelManagerService.setSelectedModel(input.modelId); + + // Notify transcription service about model change + if (globalThis.transcriptionService) { + await globalThis.transcriptionService.handleModelChange(); + } + + return true; }), // Subscriptions using Observables diff --git a/apps/desktop/src/trpc/routers/settings.ts b/apps/desktop/src/trpc/routers/settings.ts index d9b0eaa..1181fd3 100644 --- a/apps/desktop/src/trpc/routers/settings.ts +++ b/apps/desktop/src/trpc/routers/settings.ts @@ -32,6 +32,66 @@ const SetShortcutSchema = z.object({ shortcut: z.string(), }); export const settingsRouter = t.router({ + // Get all settings + getSettings: t.procedure.query(async () => { + try { + if (!globalThis.settingsService) { + throw new Error("SettingsService not available"); + } + return await globalThis.settingsService.getAllSettings(); + } catch (error) { + if (globalThis.logger) { + globalThis.logger.main.error("Error getting settings:", error); + } + return {}; + } + }), + + // Update transcription settings + updateTranscriptionSettings: t.procedure + .input( + z.object({ + language: z.string().optional(), + autoTranscribe: z.boolean().optional(), + confidenceThreshold: z.number().optional(), + enablePunctuation: z.boolean().optional(), + enableTimestamps: z.boolean().optional(), + preloadWhisperModel: z.boolean().optional(), + }), + ) + .mutation(async ({ input }) => { + try { + if (!globalThis.settingsService) { + throw new Error("SettingsService not available"); + } + + // Check if preloadWhisperModel setting is changing + const currentSettings = + await globalThis.settingsService.getTranscriptionSettings(); + const preloadChanged = + input.preloadWhisperModel !== undefined && + currentSettings && + input.preloadWhisperModel !== currentSettings.preloadWhisperModel; + + await globalThis.settingsService.setTranscriptionSettings(input); + + // Handle model preloading change + if (preloadChanged && globalThis.transcriptionService) { + await globalThis.transcriptionService.handleModelChange(); + } + + return true; + } catch (error) { + if (globalThis.logger) { + globalThis.logger.main.error( + "Error updating transcription settings:", + error, + ); + } + throw error; + } + }), + // Get formatter configuration getFormatterConfig: t.procedure.query(async () => { try {