chore: preload whisper models

This commit is contained in:
haritabh-z01 2025-07-06 16:02:51 +05:30
parent 3f8ea518f6
commit c0cbbcdda9
9 changed files with 198 additions and 12 deletions

View file

@ -96,6 +96,7 @@ export interface AppSettingsData {
confidenceThreshold: number;
enablePunctuation: boolean;
enableTimestamps: boolean;
preloadWhisperModel?: boolean;
};
recording?: {
defaultFormat: "wav" | "mp3" | "flac";

View file

@ -1,4 +1,4 @@
import { useState, useEffect } from "react";
import { useState } from "react";
import { api } from "@/trpc/react";
import type { RecordingState } from "@/types/recording";

View file

@ -87,6 +87,7 @@ export class ServiceManager {
this.transcriptionService = new TranscriptionService(
this.modelManagerService,
this.vadService,
this.settingsService,
);
await this.transcriptionService.initialize();

View file

@ -21,7 +21,7 @@ export class WhisperProvider implements TranscriptionProvider {
// Configuration
private readonly FRAME_SIZE = 512; // 32ms at 16kHz
private readonly MIN_SPEECH_DURATION_MS = 500; // Minimum speech duration to transcribe
private readonly MAX_SILENCE_DURATION_MS = 2000; // Max silence before cutting
private readonly MAX_SILENCE_DURATION_MS = 800; // Max silence before cutting
private readonly SAMPLE_RATE = 16000;
private readonly SPEECH_PROBABILITY_THRESHOLD = 0.2; // Threshold for speech detection
@ -29,6 +29,13 @@ export class WhisperProvider implements TranscriptionProvider {
this.modelManager = modelManager;
}
/**
* Preload the Whisper model into memory.
*
* Delegates to `initializeWhisper()`, which returns immediately when a
* Whisper instance already exists, so calling this more than once is safe.
* `transcribe()` performs the same initialization on demand; calling this
* ahead of time just moves that work before the first transcription.
*
* @returns A promise that settles once `initializeWhisper()` has finished.
*/
async preloadModel(): Promise<void> {
await this.initializeWhisper();
}
async transcribe(params: TranscribeParams): Promise<string> {
try {
await this.initializeWhisper();
@ -71,13 +78,13 @@ export class WhisperProvider implements TranscriptionProvider {
const aggregatedAudio = this.aggregateFrames();
// Skip if too short or only silence
if (aggregatedAudio.length < this.FRAME_SIZE * 2) {
/* if (aggregatedAudio.length < this.FRAME_SIZE * 2) {
logger.transcription.debug("Skipping transcription - audio too short");
this.frameBuffer = [];
this.frameBufferSpeechProbabilities = [];
this.silenceFrameCount = 0;
return "";
}
} */
logger.transcription.debug(
`Starting transcription of ${aggregatedAudio.length} samples (${((aggregatedAudio.length / this.SAMPLE_RATE) * 1000).toFixed(0)}ms)`,
@ -266,7 +273,7 @@ export class WhisperProvider implements TranscriptionProvider {
return prompt;
}
private async initializeWhisper(): Promise<void> {
async initializeWhisper(): Promise<void> {
if (this.whisperInstance) {
return; // Already initialized
}

View file

@ -1,4 +1,4 @@
import React from "react";
import React, { useState, useEffect } from "react";
import {
Card,
CardHeader,
@ -9,8 +9,44 @@ import {
import { Button } from "@/components/ui/button";
import { Label } from "@/components/ui/label";
import { Switch } from "@/components/ui/switch";
import { api } from "@/trpc/react";
import { toast } from "sonner";
export function AdvancedSettings() {
const [preloadWhisperModel, setPreloadWhisperModel] = useState(true);
// tRPC queries and mutations
const settingsQuery = api.settings.getSettings.useQuery();
const utils = api.useUtils();
// Mutation that persists transcription settings through tRPC.
// On success the cached settings query is invalidated so it refetches;
// on failure the optimistic switch state is rolled back to the last value
// the settings query returned.
const updateTranscriptionSettingsMutation =
  api.settings.updateTranscriptionSettings.useMutation({
    onSuccess: () => {
      void utils.settings.getSettings.invalidate();
      toast.success("Settings updated");
    },
    onError: (error) => {
      console.error("Failed to update transcription settings:", error);
      // The switch is flipped locally before the request is sent. Nothing
      // else resets it when the save fails, so restore the value from the
      // last loaded settings (an unset flag means "enabled").
      setPreloadWhisperModel(
        settingsQuery.data?.transcription?.preloadWhisperModel !== false,
      );
      toast.error("Failed to update settings. Please try again.");
    },
  });
// Load settings when query data is available
useEffect(() => {
if (settingsQuery.data?.transcription) {
setPreloadWhisperModel(
settingsQuery.data.transcription.preloadWhisperModel !== false,
);
}
}, [settingsQuery.data]);
/**
 * Switch handler for the "Preload Whisper Model" toggle.
 *
 * Updates the local switch state first so the UI responds immediately,
 * then sends only the `preloadWhisperModel` field to the
 * `updateTranscriptionSettings` mutation.
 *
 * @param checked - New on/off value reported by the switch.
 */
const handlePreloadWhisperModelChange = (checked: boolean) => {
// Optimistic local update; the mutation below persists the value.
setPreloadWhisperModel(checked);
updateTranscriptionSettingsMutation.mutate({
preloadWhisperModel: checked,
});
};
return (
<Card>
<CardHeader>
@ -18,6 +54,20 @@ export function AdvancedSettings() {
<CardDescription>Advanced configuration options</CardDescription>
</CardHeader>
<CardContent className="space-y-4">
<div className="flex items-center justify-between">
<div>
<Label htmlFor="preload-whisper">Preload Whisper Model</Label>
<p className="text-sm text-muted-foreground">
Load AI model at startup for faster transcription
</p>
</div>
<Switch
id="preload-whisper"
checked={preloadWhisperModel}
onCheckedChange={handlePreloadWhisperModelChange}
/>
</div>
<div className="flex items-center justify-between">
<div>
<Label htmlFor="debug-mode">Debug Mode</Label>

View file

@ -7,6 +7,7 @@ import { createDefaultContext } from "../pipeline/core/context";
import { WhisperProvider } from "../pipeline/providers/transcription/whisper-provider";
import { OpenRouterProvider } from "../pipeline/providers/formatting/openrouter-formatter";
import { ModelManagerService } from "../services/model-manager";
import { SettingsService } from "../services/settings-service";
import { appContextStore } from "../stores/app-context";
import { createTranscription } from "../db/transcriptions";
import { logger } from "../main/logger";
@ -31,14 +32,17 @@ export class TranscriptionService {
private openRouterProvider: OpenRouterProvider | null = null;
private formatterEnabled = false;
private streamingSessions: Map<string, ExtendedStreamingSession> = new Map();
private vadService: VADService | null = null;
private vadService: VADService;
private settingsService: SettingsService;
constructor(
modelManagerService: ModelManagerService,
vadService: VADService | null = null,
vadService: VADService,
settingsService: SettingsService,
) {
this.whisperProvider = new WhisperProvider(modelManagerService);
this.vadService = vadService;
this.settingsService = settingsService;
}
async initialize(): Promise<void> {
@ -47,9 +51,66 @@ export class TranscriptionService {
} else {
logger.transcription.warn("VAD service not available");
}
// Check if we should preload Whisper model
const transcriptionSettings =
await this.settingsService.getTranscriptionSettings();
const shouldPreload = transcriptionSettings?.preloadWhisperModel !== false; // Default to true
if (shouldPreload) {
logger.transcription.info("Preloading Whisper model...");
await this.preloadWhisperModel();
logger.transcription.info("Whisper model preloaded successfully");
} else {
logger.transcription.info("Whisper model preloading disabled");
}
logger.transcription.info("Transcription service initialized");
}
/**
 * Eagerly loads the Whisper model through the provider.
 *
 * Logs the outcome either way. A failure is logged and then re-thrown so
 * the caller decides whether it is fatal.
 *
 * @throws Whatever `WhisperProvider.preloadModel()` rejects with.
 */
async preloadWhisperModel(): Promise<void> {
  const log = logger.transcription;
  try {
    // Delegates to the provider, which performs the actual initialization.
    await this.whisperProvider.preloadModel();
    log.info("Whisper model preloaded successfully");
  } catch (preloadError) {
    log.error("Failed to preload Whisper model:", preloadError);
    throw preloadError;
  }
}
/**
 * Reacts to a model change: releases the current Whisper model and, when
 * preloading is enabled in the transcription settings, loads it again
 * straight away.
 *
 * Never rejects. Any failure is logged and swallowed, leaving the model to
 * be loaded on first use instead.
 */
async handleModelChange(): Promise<void> {
  try {
    // Drop whatever model is currently held by the provider.
    await this.whisperProvider.dispose();

    if (!this.settingsService) {
      return;
    }

    const settings = await this.settingsService.getTranscriptionSettings();
    // Anything other than an explicit `false` counts as "preload enabled".
    if (settings?.preloadWhisperModel === false) {
      return;
    }

    logger.transcription.info("Reloading Whisper model after model change...");
    await this.whisperProvider.preloadModel();
    logger.transcription.info("Whisper model reloaded successfully");
  } catch (changeError) {
    // Deliberately not re-thrown - the model will be loaded on first use.
    logger.transcription.error("Failed to handle model change:", changeError);
  }
}
/**
* Configure formatter for post-processing
*/

View file

@ -55,7 +55,7 @@ export class VADService extends EventEmitter {
// Load ONNX model
this.session = await ort.InferenceSession.create(this.modelPath, {
executionProviders: ["cpu"], // Use CPU provider for compatibility
executionProviders: ["coreml", "cpu"],
});
// Initialize hidden states (h and c)

View file

@ -19,6 +19,7 @@ const t = initTRPC.create({
// Ambient declarations for services that the procedures in this router read
// from `globalThis`. Both are typed `any`, so calls made through them are not
// type-checked.
// NOTE(review): where these globals are assigned is not visible here -
// confirm they are set before any procedure can run.
declare global {
// Used by the model procedures (e.g. `setSelectedModel`).
var modelManagerService: any;
// Notified via `handleModelChange()` after the selected model changes.
var transcriptionService: any;
}
export const modelsRouter = t.router({
@ -120,9 +121,14 @@ export const modelsRouter = t.router({
if (!globalThis.modelManagerService) {
throw new Error("Model manager service not initialized");
}
return await globalThis.modelManagerService.setSelectedModel(
input.modelId,
);
await globalThis.modelManagerService.setSelectedModel(input.modelId);
// Notify transcription service about model change
if (globalThis.transcriptionService) {
await globalThis.transcriptionService.handleModelChange();
}
return true;
}),
// Subscriptions using Observables

View file

@ -32,6 +32,66 @@ const SetShortcutSchema = z.object({
shortcut: z.string(),
});
export const settingsRouter = t.router({
// Get all settings
//
// Returns whatever `settingsService.getAllSettings()` yields. Any failure
// (including a missing settings service) is logged when a logger is
// available and reported to the client as an empty object.
getSettings: t.procedure.query(async () => {
  try {
    const service = globalThis.settingsService;
    if (!service) {
      throw new Error("SettingsService not available");
    }
    const allSettings = await service.getAllSettings();
    return allSettings;
  } catch (err) {
    const appLogger = globalThis.logger;
    if (appLogger) {
      appLogger.main.error("Error getting settings:", err);
    }
    return {};
  }
}),
// Update transcription settings
//
// Accepts any subset of the transcription options, hands it to
// `settingsService.setTranscriptionSettings`, and notifies the transcription
// service when the effective preload flag actually flips.
// Resolves to `true` on success; failures are logged (when a logger is
// available) and re-thrown to the client.
// NOTE(review): `input` is passed through as-is - confirm that
// `setTranscriptionSettings` merges partial input rather than replacing the
// stored object.
updateTranscriptionSettings: t.procedure
  .input(
    z.object({
      language: z.string().optional(),
      autoTranscribe: z.boolean().optional(),
      confidenceThreshold: z.number().optional(),
      enablePunctuation: z.boolean().optional(),
      enableTimestamps: z.boolean().optional(),
      preloadWhisperModel: z.boolean().optional(),
    }),
  )
  .mutation(async ({ input }) => {
    try {
      if (!globalThis.settingsService) {
        throw new Error("SettingsService not available");
      }

      // Compare effective values, not raw stored ones. The transcription
      // service and the settings UI both treat anything other than an
      // explicit `false` as "preload enabled", so an unset flag (or no
      // stored settings at all) must be read as `true` here too. Otherwise
      // saving `true` over an unset flag would needlessly dispose and reload
      // the model, and saving `false` with no stored settings would be
      // ignored.
      const currentSettings =
        await globalThis.settingsService.getTranscriptionSettings();
      const wasPreloadEnabled = currentSettings?.preloadWhisperModel !== false;
      const preloadChanged =
        input.preloadWhisperModel !== undefined &&
        input.preloadWhisperModel !== wasPreloadEnabled;

      await globalThis.settingsService.setTranscriptionSettings(input);

      // Let the transcription service dispose/reload the model to match the
      // new preload preference.
      if (preloadChanged && globalThis.transcriptionService) {
        await globalThis.transcriptionService.handleModelChange();
      }

      return true;
    } catch (error) {
      if (globalThis.logger) {
        globalThis.logger.main.error(
          "Error updating transcription settings:",
          error,
        );
      }
      throw error;
    }
  }),
// Get formatter configuration
getFormatterConfig: t.procedure.query(async () => {
try {