chore: preload whisper models
This commit is contained in:
parent
3f8ea518f6
commit
c0cbbcdda9
9 changed files with 198 additions and 12 deletions
|
|
@ -96,6 +96,7 @@ export interface AppSettingsData {
|
|||
confidenceThreshold: number;
|
||||
enablePunctuation: boolean;
|
||||
enableTimestamps: boolean;
|
||||
preloadWhisperModel?: boolean;
|
||||
};
|
||||
recording?: {
|
||||
defaultFormat: "wav" | "mp3" | "flac";
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import { useState, useEffect } from "react";
|
||||
import { useState } from "react";
|
||||
import { api } from "@/trpc/react";
|
||||
import type { RecordingState } from "@/types/recording";
|
||||
|
||||
|
|
|
|||
|
|
@ -87,6 +87,7 @@ export class ServiceManager {
|
|||
this.transcriptionService = new TranscriptionService(
|
||||
this.modelManagerService,
|
||||
this.vadService,
|
||||
this.settingsService,
|
||||
);
|
||||
await this.transcriptionService.initialize();
|
||||
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ export class WhisperProvider implements TranscriptionProvider {
|
|||
// Configuration
|
||||
private readonly FRAME_SIZE = 512; // 32ms at 16kHz
|
||||
private readonly MIN_SPEECH_DURATION_MS = 500; // Minimum speech duration to transcribe
|
||||
private readonly MAX_SILENCE_DURATION_MS = 2000; // Max silence before cutting
|
||||
private readonly MAX_SILENCE_DURATION_MS = 800; // Max silence before cutting
|
||||
private readonly SAMPLE_RATE = 16000;
|
||||
private readonly SPEECH_PROBABILITY_THRESHOLD = 0.2; // Threshold for speech detection
|
||||
|
||||
|
|
@ -29,6 +29,13 @@ export class WhisperProvider implements TranscriptionProvider {
|
|||
this.modelManager = modelManager;
|
||||
}
|
||||
|
||||
/**
|
||||
* Preload the Whisper model into memory
*
* Thin public wrapper that awaits `initializeWhisper()`. That method returns
* immediately when `this.whisperInstance` is already set, so calling this on
* an already-initialized provider does not initialize it again.
*
* @returns A promise that settles when `initializeWhisper()` settles; any
* rejection from it propagates to the caller unchanged.
|
||||
*/
|
||||
async preloadModel(): Promise<void> {
|
||||
await this.initializeWhisper();
|
||||
}
|
||||
|
||||
async transcribe(params: TranscribeParams): Promise<string> {
|
||||
try {
|
||||
await this.initializeWhisper();
|
||||
|
|
@ -71,13 +78,13 @@ export class WhisperProvider implements TranscriptionProvider {
|
|||
const aggregatedAudio = this.aggregateFrames();
|
||||
|
||||
// Skip if too short or only silence
|
||||
if (aggregatedAudio.length < this.FRAME_SIZE * 2) {
|
||||
/* if (aggregatedAudio.length < this.FRAME_SIZE * 2) {
|
||||
logger.transcription.debug("Skipping transcription - audio too short");
|
||||
this.frameBuffer = [];
|
||||
this.frameBufferSpeechProbabilities = [];
|
||||
this.silenceFrameCount = 0;
|
||||
return "";
|
||||
}
|
||||
} */
|
||||
|
||||
logger.transcription.debug(
|
||||
`Starting transcription of ${aggregatedAudio.length} samples (${((aggregatedAudio.length / this.SAMPLE_RATE) * 1000).toFixed(0)}ms)`,
|
||||
|
|
@ -266,7 +273,7 @@ export class WhisperProvider implements TranscriptionProvider {
|
|||
return prompt;
|
||||
}
|
||||
|
||||
private async initializeWhisper(): Promise<void> {
|
||||
async initializeWhisper(): Promise<void> {
|
||||
if (this.whisperInstance) {
|
||||
return; // Already initialized
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import React from "react";
|
||||
import React, { useState, useEffect } from "react";
|
||||
import {
|
||||
Card,
|
||||
CardHeader,
|
||||
|
|
@ -9,8 +9,44 @@ import {
|
|||
import { Button } from "@/components/ui/button";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { Switch } from "@/components/ui/switch";
|
||||
import { api } from "@/trpc/react";
|
||||
import { toast } from "sonner";
|
||||
|
||||
export function AdvancedSettings() {
|
||||
const [preloadWhisperModel, setPreloadWhisperModel] = useState(true);
|
||||
|
||||
// tRPC queries and mutations
|
||||
const settingsQuery = api.settings.getSettings.useQuery();
|
||||
const utils = api.useUtils();
|
||||
|
||||
const updateTranscriptionSettingsMutation =
|
||||
api.settings.updateTranscriptionSettings.useMutation({
|
||||
onSuccess: () => {
|
||||
utils.settings.getSettings.invalidate();
|
||||
toast.success("Settings updated");
|
||||
},
|
||||
onError: (error) => {
|
||||
console.error("Failed to update transcription settings:", error);
|
||||
toast.error("Failed to update settings. Please try again.");
|
||||
},
|
||||
});
|
||||
|
||||
// Load settings when query data is available
|
||||
useEffect(() => {
|
||||
if (settingsQuery.data?.transcription) {
|
||||
setPreloadWhisperModel(
|
||||
settingsQuery.data.transcription.preloadWhisperModel !== false,
|
||||
);
|
||||
}
|
||||
}, [settingsQuery.data]);
|
||||
|
||||
const handlePreloadWhisperModelChange = (checked: boolean) => {
|
||||
setPreloadWhisperModel(checked);
|
||||
updateTranscriptionSettingsMutation.mutate({
|
||||
preloadWhisperModel: checked,
|
||||
});
|
||||
};
|
||||
|
||||
return (
|
||||
<Card>
|
||||
<CardHeader>
|
||||
|
|
@ -18,6 +54,20 @@ export function AdvancedSettings() {
|
|||
<CardDescription>Advanced configuration options</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent className="space-y-4">
|
||||
<div className="flex items-center justify-between">
|
||||
<div>
|
||||
<Label htmlFor="preload-whisper">Preload Whisper Model</Label>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Load AI model at startup for faster transcription
|
||||
</p>
|
||||
</div>
|
||||
<Switch
|
||||
id="preload-whisper"
|
||||
checked={preloadWhisperModel}
|
||||
onCheckedChange={handlePreloadWhisperModelChange}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="flex items-center justify-between">
|
||||
<div>
|
||||
<Label htmlFor="debug-mode">Debug Mode</Label>
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import { createDefaultContext } from "../pipeline/core/context";
|
|||
import { WhisperProvider } from "../pipeline/providers/transcription/whisper-provider";
|
||||
import { OpenRouterProvider } from "../pipeline/providers/formatting/openrouter-formatter";
|
||||
import { ModelManagerService } from "../services/model-manager";
|
||||
import { SettingsService } from "../services/settings-service";
|
||||
import { appContextStore } from "../stores/app-context";
|
||||
import { createTranscription } from "../db/transcriptions";
|
||||
import { logger } from "../main/logger";
|
||||
|
|
@ -31,14 +32,17 @@ export class TranscriptionService {
|
|||
private openRouterProvider: OpenRouterProvider | null = null;
|
||||
private formatterEnabled = false;
|
||||
private streamingSessions: Map<string, ExtendedStreamingSession> = new Map();
|
||||
private vadService: VADService | null = null;
|
||||
private vadService: VADService;
|
||||
private settingsService: SettingsService;
|
||||
|
||||
constructor(
|
||||
modelManagerService: ModelManagerService,
|
||||
vadService: VADService | null = null,
|
||||
vadService: VADService,
|
||||
settingsService: SettingsService,
|
||||
) {
|
||||
this.whisperProvider = new WhisperProvider(modelManagerService);
|
||||
this.vadService = vadService;
|
||||
this.settingsService = settingsService;
|
||||
}
|
||||
|
||||
async initialize(): Promise<void> {
|
||||
|
|
@ -47,9 +51,66 @@ export class TranscriptionService {
|
|||
} else {
|
||||
logger.transcription.warn("VAD service not available");
|
||||
}
|
||||
|
||||
// Check if we should preload Whisper model
|
||||
const transcriptionSettings =
|
||||
await this.settingsService.getTranscriptionSettings();
|
||||
const shouldPreload = transcriptionSettings?.preloadWhisperModel !== false; // Default to true
|
||||
|
||||
if (shouldPreload) {
|
||||
logger.transcription.info("Preloading Whisper model...");
|
||||
await this.preloadWhisperModel();
|
||||
logger.transcription.info("Whisper model preloaded successfully");
|
||||
} else {
|
||||
logger.transcription.info("Whisper model preloading disabled");
|
||||
}
|
||||
|
||||
logger.transcription.info("Transcription service initialized");
|
||||
}
|
||||
|
||||
/**
|
||||
* Preload Whisper model into memory
*
* Delegates to `WhisperProvider.preloadModel()`, logs the outcome, and
* rethrows on failure so the caller decides whether a failed preload is fatal.
*
* NOTE(review): `initialize()` awaits this method without a try/catch and
* then logs "Whisper model preloaded successfully" itself, so on success the
* message is emitted twice and on failure service initialization rejects.
*
* @throws Whatever `whisperProvider.preloadModel()` rejects with, after it
* has been logged.
|
||||
*/
|
||||
async preloadWhisperModel(): Promise<void> {
|
||||
try {
|
||||
// This will trigger the model initialization in WhisperProvider
|
||||
await this.whisperProvider.preloadModel();
|
||||
logger.transcription.info("Whisper model preloaded successfully");
|
||||
} catch (error) {
|
||||
logger.transcription.error("Failed to preload Whisper model:", error);
|
||||
// Rethrow after logging: unlike handleModelChange(), this path does not swallow the error.
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle model change - dispose old model and load new one if preloading is enabled
*
* Always disposes the provider first, then reads the transcription settings
* and preloads again unless `preloadWhisperModel` is explicitly `false`
* (a missing value counts as enabled).
*
* Errors from any step are logged and swallowed; this method never rejects.
*
* NOTE(review): the settings router appears to call this when the
* `preloadWhisperModel` setting itself changes, not only when the selected
* model changes. Because the dispose is unconditional, turning preloading
* off unloads the current model - confirm that is intended.
|
||||
*/
|
||||
async handleModelChange(): Promise<void> {
|
||||
try {
|
||||
// Dispose current model
|
||||
await this.whisperProvider.dispose();
|
||||
|
||||
// Check if preloading is enabled
|
||||
// NOTE(review): settingsService is a required constructor argument typed as
// SettingsService, so this guard looks always-true - confirm before removing.
if (this.settingsService) {
|
||||
const transcriptionSettings =
|
||||
await this.settingsService.getTranscriptionSettings();
|
||||
// Same default as initialize(): only an explicit `false` disables preloading.
const shouldPreload =
|
||||
transcriptionSettings?.preloadWhisperModel !== false;
|
||||
|
||||
if (shouldPreload) {
|
||||
logger.transcription.info(
|
||||
"Reloading Whisper model after model change...",
|
||||
);
|
||||
// Calls the provider directly rather than this.preloadWhisperModel(),
// so the "preloaded successfully" log line is not emitted on this path.
await this.whisperProvider.preloadModel();
|
||||
logger.transcription.info("Whisper model reloaded successfully");
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.transcription.error("Failed to handle model change:", error);
|
||||
// Don't throw - model will be loaded on first use
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Configure formatter for post-processing
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ export class VADService extends EventEmitter {
|
|||
|
||||
// Load ONNX model
|
||||
this.session = await ort.InferenceSession.create(this.modelPath, {
|
||||
executionProviders: ["cpu"], // Use CPU provider for compatibility
|
||||
executionProviders: ["coreml", "cpu"],
|
||||
});
|
||||
|
||||
// Initialize hidden states (h and c)
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ const t = initTRPC.create({
|
|||
|
||||
declare global {
|
||||
var modelManagerService: any;
|
||||
var transcriptionService: any;
|
||||
}
|
||||
|
||||
export const modelsRouter = t.router({
|
||||
|
|
@ -120,9 +121,14 @@ export const modelsRouter = t.router({
|
|||
if (!globalThis.modelManagerService) {
|
||||
throw new Error("Model manager service not initialized");
|
||||
}
|
||||
return await globalThis.modelManagerService.setSelectedModel(
|
||||
input.modelId,
|
||||
);
|
||||
await globalThis.modelManagerService.setSelectedModel(input.modelId);
|
||||
|
||||
// Notify transcription service about model change
|
||||
if (globalThis.transcriptionService) {
|
||||
await globalThis.transcriptionService.handleModelChange();
|
||||
}
|
||||
|
||||
return true;
|
||||
}),
|
||||
|
||||
// Subscriptions using Observables
|
||||
|
|
|
|||
|
|
@ -32,6 +32,66 @@ const SetShortcutSchema = z.object({
|
|||
shortcut: z.string(),
|
||||
});
|
||||
export const settingsRouter = t.router({
|
||||
// Get all settings
// Query procedure returning `settingsService.getAllSettings()`.
// Every failure - including the "SettingsService not available" error thrown
// just below, which is caught by this same try/catch - is logged (when a
// global logger exists) and turned into an empty object, so this query
// never rejects.
|
||||
getSettings: t.procedure.query(async () => {
|
||||
try {
|
||||
if (!globalThis.settingsService) {
|
||||
throw new Error("SettingsService not available");
|
||||
}
|
||||
return await globalThis.settingsService.getAllSettings();
|
||||
} catch (error) {
|
||||
if (globalThis.logger) {
|
||||
globalThis.logger.main.error("Error getting settings:", error);
|
||||
}
|
||||
// Fall back to an empty settings object instead of surfacing the error.
return {};
|
||||
}
|
||||
}),
|
||||
|
||||
// Update transcription settings
// Mutation procedure: validates a partial set of transcription settings,
// persists it through `settingsService.setTranscriptionSettings`, and asks
// the transcription service to dispose/reload its model when the
// `preloadWhisperModel` flag differs from the stored value.
// Returns `true` on success; on failure the error is logged (when a global
// logger exists) and rethrown to the client.
|
||||
updateTranscriptionSettings: t.procedure
|
||||
.input(
|
||||
// Every field is optional, so callers may send only the keys they change.
z.object({
|
||||
language: z.string().optional(),
|
||||
autoTranscribe: z.boolean().optional(),
|
||||
confidenceThreshold: z.number().optional(),
|
||||
enablePunctuation: z.boolean().optional(),
|
||||
enableTimestamps: z.boolean().optional(),
|
||||
preloadWhisperModel: z.boolean().optional(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input }) => {
|
||||
try {
|
||||
if (!globalThis.settingsService) {
|
||||
throw new Error("SettingsService not available");
|
||||
}
|
||||
|
||||
// Check if preloadWhisperModel setting is changing
|
||||
// The current value is read before the write below so old and new can be compared.
const currentSettings =
|
||||
await globalThis.settingsService.getTranscriptionSettings();
|
||||
// NOTE(review): this compares against the raw stored value. If the stored
// flag is undefined, sending `true` counts as a change even though
// TranscriptionService treats anything other than `false` as enabled -
// confirm the resulting handleModelChange() call is wanted in that case.
// Also, when `currentSettings` is falsy no change is ever detected.
const preloadChanged =
|
||||
input.preloadWhisperModel !== undefined &&
|
||||
currentSettings &&
|
||||
input.preloadWhisperModel !== currentSettings.preloadWhisperModel;
|
||||
|
||||
await globalThis.settingsService.setTranscriptionSettings(input);
|
||||
|
||||
// Handle model preloading change
|
||||
// Runs after the write; presumably so handleModelChange() reads the updated
// setting - verify both use the same SettingsService instance.
if (preloadChanged && globalThis.transcriptionService) {
|
||||
await globalThis.transcriptionService.handleModelChange();
|
||||
}
|
||||
|
||||
return true;
|
||||
} catch (error) {
|
||||
if (globalThis.logger) {
|
||||
globalThis.logger.main.error(
|
||||
"Error updating transcription settings:",
|
||||
error,
|
||||
);
|
||||
}
|
||||
// Unlike getSettings, failures here are propagated to the caller.
throw error;
|
||||
}
|
||||
}),
|
||||
|
||||
// Get formatter configuration
|
||||
getFormatterConfig: t.procedure.query(async () => {
|
||||
try {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue