From e7d9e91bf40f56622d69876e0e8a3140d56bed48 Mon Sep 17 00:00:00 2001 From: nchopra Date: Tue, 26 Aug 2025 01:29:35 +0530 Subject: [PATCH] fix: wire up transcription playback --- apps/desktop/forge.config.ts | 38 +-- apps/desktop/src/hooks/useAudioPlayer.ts | 111 +++++++++ .../components/TranscriptionsList.tsx | 65 +++++- .../src/trpc/routers/transcriptions.ts | 33 ++- packages/smart-whisper/package.json | 78 +++---- packages/smart-whisper/src/binding.ts | 111 ++++----- packages/smart-whisper/src/build.ts | 137 +++++------ .../smart-whisper/src/model-manager/index.ts | 105 +++++---- packages/smart-whisper/src/transcribe.ts | 186 +++++++-------- packages/smart-whisper/src/types.ts | 138 +++++------ packages/smart-whisper/src/whisper.ts | 217 +++++++++--------- packages/smart-whisper/tsup.config.ts | 23 +- turbo.json | 19 +- 13 files changed, 750 insertions(+), 511 deletions(-) create mode 100644 apps/desktop/src/hooks/useAudioPlayer.ts diff --git a/apps/desktop/forge.config.ts b/apps/desktop/forge.config.ts index 3d7ee77..5fddb37 100644 --- a/apps/desktop/forge.config.ts +++ b/apps/desktop/forge.config.ts @@ -149,7 +149,11 @@ const config: ForgeConfig = { // Copy the package console.log(`Copying ${dep}...`); - cpSync(rootDepPath, localDepPath, { recursive: true, dereference: true, force: true }); + cpSync(rootDepPath, localDepPath, { + recursive: true, + dereference: true, + force: true, + }); console.log(`✓ Successfully copied ${dep}`); } catch (error) { console.error(`Failed to copy ${dep}:`, error); @@ -160,31 +164,35 @@ const config: ForgeConfig = { console.log("Checking for symlinks in copied dependencies..."); for (const dep of nativeModuleDependenciesToPackage) { const localDepPath = join(localNodeModules, dep); - + try { if (existsSync(localDepPath)) { const stats = lstatSync(localDepPath); if (stats.isSymbolicLink()) { - console.log(`Found symlink for ${dep}, replacing with dereferenced copy...`); - + console.log( + `Found symlink for ${dep}, replacing with dereferenced copy...`, + ); + // Read where the symlink points to const symlinkTarget = readlinkSync(localDepPath); const absoluteTarget = join(localDepPath, "..", symlinkTarget); const sourcePath = normalize(absoluteTarget); - + console.log(` Symlink points to: ${sourcePath}`); - + // Remove the symlink rmSync(localDepPath, { recursive: true, force: true }); - + // Copy with dereference to get actual content - cpSync(sourcePath, localDepPath, { - recursive: true, + cpSync(sourcePath, localDepPath, { + recursive: true, force: true, - dereference: true // Follow symlinks and copy actual content + dereference: true, // Follow symlinks and copy actual content }); - - console.log(`✓ Successfully replaced symlink for ${dep} with actual content`); + + console.log( + `✓ Successfully replaced symlink for ${dep} with actual content`, + ); } } } catch (error) { @@ -404,8 +412,10 @@ const config: ForgeConfig = { const scopeDir = dep.split("/")[0]; // @libsql/client -> @libsql // for workspace packages only keep the actual package if (scopeDir === "@amical") { - if (filePath.startsWith(`/node_modules/${dep}`) || - filePath === `/node_modules/${scopeDir}`) { + if ( + filePath.startsWith(`/node_modules/${dep}`) || + filePath === `/node_modules/${scopeDir}` + ) { KEEP_FILE.keep = true; KEEP_FILE.log = true; } diff --git a/apps/desktop/src/hooks/useAudioPlayer.ts b/apps/desktop/src/hooks/useAudioPlayer.ts new file mode 100644 index 0000000..c5d704e --- /dev/null +++ b/apps/desktop/src/hooks/useAudioPlayer.ts @@ -0,0 +1,111 @@ +import { useRef, useState, useCallback, useEffect } from "react"; + +interface UseAudioPlayerReturn { + isPlaying: boolean; + currentPlayingId: number | null; + play: ( + audioData: ArrayBuffer, + transcriptionId: number, + mimeType?: string, + ) => void; + pause: () => void; + stop: () => void; + toggle: ( + audioData: ArrayBuffer, + transcriptionId: number, + mimeType?: string, + ) => void; +} + +export function useAudioPlayer(): UseAudioPlayerReturn { + const audioRef = useRef(null); + const currentBlobUrlRef = useRef(null); + const [isPlaying, setIsPlaying] = useState(false); + const [currentPlayingId, setCurrentPlayingId] = useState(null); + + const cleanup = useCallback(() => { + if (audioRef.current) { + audioRef.current.pause(); + audioRef.current.src = ""; + } + if (currentBlobUrlRef.current) { + URL.revokeObjectURL(currentBlobUrlRef.current); + currentBlobUrlRef.current = null; + } + setIsPlaying(false); + setCurrentPlayingId(null); + }, []); + + const play = useCallback( + ( + audioData: ArrayBuffer, + transcriptionId: number, + mimeType: string = "audio/wav", + ) => { + cleanup(); + + const blob = new Blob([audioData], { type: mimeType }); + const blobUrl = URL.createObjectURL(blob); + currentBlobUrlRef.current = blobUrl; + + if (!audioRef.current) { + audioRef.current = new Audio(); + } + + audioRef.current.src = blobUrl; + audioRef.current.onended = () => { + setIsPlaying(false); + setCurrentPlayingId(null); + }; + + audioRef.current + .play() + .then(() => { + setIsPlaying(true); + setCurrentPlayingId(transcriptionId); + }) + .catch((error) => { + console.error("Failed to play audio:", error); + cleanup(); + }); + }, + [cleanup], + ); + + const pause = useCallback(() => { + if (audioRef.current && !audioRef.current.paused) { + audioRef.current.pause(); + setIsPlaying(false); + } + }, []); + + const stop = useCallback(() => { + cleanup(); + }, [cleanup]); + + const toggle = useCallback( + (audioData: ArrayBuffer, transcriptionId: number, mimeType?: string) => { + if (currentPlayingId === transcriptionId && isPlaying) { + pause(); + } else { + play(audioData, transcriptionId, mimeType); + } + }, + [currentPlayingId, isPlaying, pause, play], + ); + + useEffect(() => { + return () => { + cleanup(); + }; + }, [cleanup]); + + return { + isPlaying, + currentPlayingId, + play, + pause, + stop, + toggle, + }; +} diff --git a/apps/desktop/src/renderer/main/pages/transcriptions/components/TranscriptionsList.tsx b/apps/desktop/src/renderer/main/pages/transcriptions/components/TranscriptionsList.tsx index 203d103..1c2b0a0 100644 --- a/apps/desktop/src/renderer/main/pages/transcriptions/components/TranscriptionsList.tsx +++ b/apps/desktop/src/renderer/main/pages/transcriptions/components/TranscriptionsList.tsx @@ -4,6 +4,7 @@ import { Button } from "@/components/ui/button"; import { Card, CardContent } from "@/components/ui/card"; import { Badge } from "@/components/ui/badge"; import { api } from "@/trpc/react"; +import { useAudioPlayer } from "@/hooks/useAudioPlayer"; import { Tooltip, @@ -14,8 +15,8 @@ import { import { Copy, Play, + Pause, Trash2, - Download, FileText, Search, MoreHorizontal, @@ -35,6 +36,7 @@ import { export const TranscriptionsList: React.FC = () => { const [searchTerm, setSearchTerm] = useState(""); const [openDropdownId, setOpenDropdownId] = useState(null); + const audioPlayer = useAudioPlayer(); // Get shortcuts data const shortcutsQuery = api.settings.getShortcuts.useQuery(); @@ -85,6 +87,34 @@ export const TranscriptionsList: React.FC = () => { }, }); + // Using mutation for fetching audio data instead of query to: + // - Prevent caching of large binary audio files in memory + // - Avoid automatic refetching behaviors (window focus, network reconnect) + // - Clearly indicate this is a user-triggered action (play button click) + // - Track loading state per transcription ID efficiently + const getAudioFileMutation = api.transcriptions.getAudioFile.useMutation({ + onSuccess: (data, variables) => { + if (data?.data) { + // Decode base64 to ArrayBuffer + const base64 = data.data; + const binaryString = atob(base64); + const bytes = new Uint8Array(binaryString.length); + for (let i = 0; i < binaryString.length; i++) { + bytes[i] = binaryString.charCodeAt(i); + } + // Pass the MIME type from the server response + audioPlayer.toggle( + bytes.buffer, + variables.transcriptionId, + data.mimeType, + ); + } + }, + onError: (error) => { + console.error("Error fetching audio file:", error); + }, + }); + const transcriptions = transcriptionsQuery.data || []; const totalCount = transcriptionsCountQuery.data || 0; const loading = @@ -103,9 +133,15 @@ export const TranscriptionsList: React.FC = () => { deleteTranscriptionMutation.mutate({ id }); }; - const handlePlayAudio = (audioFile: string) => { - // Implement audio playback functionality - console.log("Playing audio:", audioFile); + const handlePlayAudio = (transcriptionId: number) => { + if ( + audioPlayer.currentPlayingId === transcriptionId && + audioPlayer.isPlaying + ) { + audioPlayer.stop(); + } else { + getAudioFileMutation.mutate({ transcriptionId }); + } }; const handleDownloadAudio = async (transcriptionId: number) => { @@ -219,12 +255,27 @@ export const TranscriptionsList: React.FC = () => { variant="ghost" size="sm" className="h-8 w-8 p-0" - onClick={() => handlePlayAudio(transcription.audioFile!)} + onClick={() => handlePlayAudio(transcription.id)} + disabled={ + getAudioFileMutation.isPending && + getAudioFileMutation.variables?.transcriptionId === + transcription.id + } > - + {audioPlayer.currentPlayingId === transcription.id && + audioPlayer.isPlaying ? ( + + ) : ( + + )} - Play audio + + {audioPlayer.currentPlayingId === transcription.id && + audioPlayer.isPlaying + ? "Pause audio" + : "Play audio"} + )} diff --git a/apps/desktop/src/trpc/routers/transcriptions.ts b/apps/desktop/src/trpc/routers/transcriptions.ts index 4154ad3..5b310d9 100644 --- a/apps/desktop/src/trpc/routers/transcriptions.ts +++ b/apps/desktop/src/trpc/routers/transcriptions.ts @@ -120,10 +120,15 @@ export const transcriptionsRouter = createRouter({ return result; }), - // Get audio file for download + // Get audio file for playback + // Implemented as mutation instead of query because: + // 1. Large binary data (audio files) shouldn't be cached by React Query + // 2. Prevents automatic refetching on window focus/network reconnect + // 3. Represents an explicit user action (clicking play), not passive data fetching + // 4. Avoids memory overhead from React Query's caching system getAudioFile: procedure .input(z.object({ transcriptionId: z.number() })) - .query(async ({ input, ctx }) => { + .mutation(async ({ input, ctx }) => { const transcription = await getTranscriptionById(input.transcriptionId); if (!transcription?.audioFile) { @@ -138,10 +143,28 @@ export const transcriptionsRouter = createRouter({ const audioData = await fs.promises.readFile(transcription.audioFile); const filename = path.basename(transcription.audioFile); + // Detect MIME type based on file extension + const ext = path.extname(transcription.audioFile).toLowerCase(); + let mimeType = "audio/wav"; // Default for our WAV files + + // Map common audio extensions to MIME types + const mimeTypes: Record = { + ".wav": "audio/wav", + ".mp3": "audio/mpeg", + ".webm": "audio/webm", + ".ogg": "audio/ogg", + ".m4a": "audio/mp4", + ".flac": "audio/flac", + }; + + if (ext in mimeTypes) { + mimeType = mimeTypes[ext]; + } + return { - data: audioData, + data: audioData.toString("base64"), filename, - mimeType: "audio/webm", + mimeType, }; } catch (error) { const logger = ctx.serviceManager.getLogger(); @@ -155,6 +178,8 @@ export const transcriptionsRouter = createRouter({ }), // Download audio file with save dialog + // Mutation because this triggers a system dialog and file write operation + // Not a query since it has side effects beyond just fetching data downloadAudioFile: procedure .input(z.object({ transcriptionId: z.number() })) .mutation(async ({ input, ctx }) => { diff --git a/packages/smart-whisper/package.json b/packages/smart-whisper/package.json index 3f0cdd5..0dd7b80 100644 --- a/packages/smart-whisper/package.json +++ b/packages/smart-whisper/package.json @@ -1,41 +1,41 @@ { - "name": "@amical/smart-whisper", - "version": "0.1.0", - "description": "Whisper.cpp Node.js binding with auto model offloading strategy.", - "main": "dist/index.js", - "types": "dist/index.d.ts", - "keywords": [ - "whisper", - "whisper.cpp", - "native", - "binding", - "addon" - ], - "gypfile": true, - "files": [ - "dist", - "src", - "scripts", - "binding.gyp", - "whisper.cpp/**/*.{c,h,cpp,hpp,m,cu,metal}", - "whisper.cpp/Makefile", - "whisper.cpp/LICENSE" - ], - "scripts": { - "install": "tsup", - "postinstall": "node-gyp rebuild", - "build": "tsup && node-gyp rebuild", - "build:ts": "tsup", - "build:native": "node-gyp rebuild" - }, - "dependencies": { - "node-addon-api": "^8.5.0", - "minimatch": "10.0.3" - }, - "devDependencies": { - "@amical/typescript-config": "workspace:*", - "@types/node": "^24.3.0", - "tsup": "^8.5.0", - "typescript": "^5.8.2" - } + "name": "@amical/smart-whisper", + "version": "0.1.0", + "description": "Whisper.cpp Node.js binding with auto model offloading strategy.", + "main": "dist/index.js", + "types": "dist/index.d.ts", + "keywords": [ + "whisper", + "whisper.cpp", + "native", + "binding", + "addon" + ], + "gypfile": true, + "files": [ + "dist", + "src", + "scripts", + "binding.gyp", + "whisper.cpp/**/*.{c,h,cpp,hpp,m,cu,metal}", + "whisper.cpp/Makefile", + "whisper.cpp/LICENSE" + ], + "scripts": { + "install": "tsup", + "postinstall": "node-gyp rebuild", + "build": "tsup && node-gyp rebuild", + "build:ts": "tsup", + "build:native": "node-gyp rebuild" + }, + "dependencies": { + "node-addon-api": "^8.5.0", + "minimatch": "10.0.3" + }, + "devDependencies": { + "@amical/typescript-config": "workspace:*", + "@types/node": "^24.3.0", + "tsup": "^8.5.0", + "typescript": "^5.8.2" + } } diff --git a/packages/smart-whisper/src/binding.ts b/packages/smart-whisper/src/binding.ts index f8b421e..5781cb8 100644 --- a/packages/smart-whisper/src/binding.ts +++ b/packages/smart-whisper/src/binding.ts @@ -1,5 +1,6 @@ process.env.GGML_METAL_PATH_RESOURCES = - process.env.GGML_METAL_PATH_RESOURCES || path.join(__dirname, "../whisper.cpp/ggml/src"); + process.env.GGML_METAL_PATH_RESOURCES || + path.join(__dirname, "../whisper.cpp/ggml/src"); import path from "node:path"; import { TranscribeFormat, TranscribeParams, TranscribeResult } from "./types"; @@ -9,66 +10,66 @@ const module = require(path.join(__dirname, "../build/Release/smart-whisper")); * A external handle to a model. */ export type Handle = { - readonly "": unique symbol; + readonly "": unique symbol; }; export namespace Binding { - /** - * Load a model from a whisper weights file. - * @param file The path to the whisper weights file. - * @param gpu Whether to use the GPU or not. - * @param callback A callback that will be called with the handle to the model. - */ - export declare function load( - file: string, - gpu: boolean, - callback: (handle: Handle) => void, - ): void; + /** + * Load a model from a whisper weights file. + * @param file The path to the whisper weights file. + * @param gpu Whether to use the GPU or not. + * @param callback A callback that will be called with the handle to the model. + */ + export declare function load( + file: string, + gpu: boolean, + callback: (handle: Handle) => void, + ): void; - /** - * Release the memory of the model, it will be unusable after this. - * @param handle The handle to the model. - * @param callback A callback that will be called when the model is freed. - */ - export declare function free(handle: Handle, callback: () => void): void; + /** + * Release the memory of the model, it will be unusable after this. + * @param handle The handle to the model. + * @param callback A callback that will be called when the model is freed. + */ + export declare function free(handle: Handle, callback: () => void): void; - /** - * Transcribe a PCM buffer. - * @param handle The handle to the model. - * @param pcm The PCM buffer. - * @param params The parameters to use for transcription. - * @param finish A callback that will be called when the transcription is finished. - * @param progress A callback that will be called when a new result is available. - */ - export declare function transcribe< - Format extends TranscribeFormat, - TokenTimestamp extends boolean, - >( - handle: Handle, - pcm: Float32Array, - params: Partial>, - finish: (results: TranscribeResult[]) => void, - progress: (result: TranscribeResult) => void, - ): void; + /** + * Transcribe a PCM buffer. + * @param handle The handle to the model. + * @param pcm The PCM buffer. + * @param params The parameters to use for transcription. + * @param finish A callback that will be called when the transcription is finished. + * @param progress A callback that will be called when a new result is available. + */ + export declare function transcribe< + Format extends TranscribeFormat, + TokenTimestamp extends boolean, + >( + handle: Handle, + pcm: Float32Array, + params: Partial>, + finish: (results: TranscribeResult[]) => void, + progress: (result: TranscribeResult) => void, + ): void; - export declare class WhisperModel { - private _ctx; - constructor(handle: Handle); - get handle(): Handle | null; - get freed(): boolean; - /** - * Release the memory of the model, it will be unusable after this. - * It's safe to call this multiple times, but it will only free the model once. - */ - free(): Promise; - /** - * Load a model from a whisper weights file. - * @param file The path to the whisper weights file. - * @param gpu Whether to use the GPU or not. - * @returns A promise that resolves to a {@link WhisperModel}. - */ - static load(file: string, gpu?: boolean): Promise; - } + export declare class WhisperModel { + private _ctx; + constructor(handle: Handle); + get handle(): Handle | null; + get freed(): boolean; + /** + * Release the memory of the model, it will be unusable after this. + * It's safe to call this multiple times, but it will only free the model once. + */ + free(): Promise; + /** + * Load a model from a whisper weights file. + * @param file The path to the whisper weights file. + * @param gpu Whether to use the GPU or not. + * @returns A promise that resolves to a {@link WhisperModel}. + */ + static load(file: string, gpu?: boolean): Promise; + } } /** diff --git a/packages/smart-whisper/src/build.ts b/packages/smart-whisper/src/build.ts index 209c3c6..3d851d3 100644 --- a/packages/smart-whisper/src/build.ts +++ b/packages/smart-whisper/src/build.ts @@ -10,85 +10,88 @@ export const defines = cfg.defines.join(" "); export const libraries = cfg.libraries.join(" "); function config(): { - sources: string[]; - defines: string[]; - libraries: string[]; + sources: string[]; + defines: string[]; + libraries: string[]; } { - if (process.env.BYOL) { - return { - sources: [], - defines: [], - libraries: [process.env.BYOL], - }; - } + if (process.env.BYOL) { + return { + sources: [], + defines: [], + libraries: [process.env.BYOL], + }; + } - const COMPUTE_BACKEND: ComputeBackend = - (process.env.COMPUTE_BACKEND as ComputeBackend | undefined) ?? infer_backend(); + const COMPUTE_BACKEND: ComputeBackend = + (process.env.COMPUTE_BACKEND as ComputeBackend | undefined) ?? + infer_backend(); - const cfg = { - sources: [ - "whisper.cpp/src/whisper.cpp", - "whisper.cpp/ggml/src/ggml.c", - "whisper.cpp/ggml/src/ggml-alloc.c", - "whisper.cpp/ggml/src/ggml-backend.c", - "whisper.cpp/ggml/src/ggml-quants.c", - "whisper.cpp/ggml/src/ggml-aarch64.c", - ] as string[], - defines: [] as string[], - libraries: [] as string[], - }; + const cfg = { + sources: [ + "whisper.cpp/src/whisper.cpp", + "whisper.cpp/ggml/src/ggml.c", + "whisper.cpp/ggml/src/ggml-alloc.c", + "whisper.cpp/ggml/src/ggml-backend.c", + "whisper.cpp/ggml/src/ggml-quants.c", + "whisper.cpp/ggml/src/ggml-aarch64.c", + ] as string[], + defines: [] as string[], + libraries: [] as string[], + }; - switch (COMPUTE_BACKEND) { - case "accelerate": { - cfg.defines.push("GGML_USE_ACCELERATE"); + switch (COMPUTE_BACKEND) { + case "accelerate": { + cfg.defines.push("GGML_USE_ACCELERATE"); - cfg.libraries.push('"-framework Foundation"'); - cfg.libraries.push('"-framework Accelerate"'); - break; - } - case "metal": { - cfg.sources.push("whisper.cpp/ggml/src/ggml-metal.m"); + cfg.libraries.push('"-framework Foundation"'); + cfg.libraries.push('"-framework Accelerate"'); + break; + } + case "metal": { + cfg.sources.push("whisper.cpp/ggml/src/ggml-metal.m"); - cfg.defines.push("GGML_USE_ACCELERATE"); - cfg.defines.push("GGML_USE_METAL"); + cfg.defines.push("GGML_USE_ACCELERATE"); + cfg.defines.push("GGML_USE_METAL"); - cfg.libraries.push('"-framework Foundation"'); - cfg.libraries.push('"-framework Accelerate"'); - cfg.libraries.push('"-framework Metal"'); - cfg.libraries.push('"-framework MetalKit"'); - break; - } - case "openblas": { - cfg.defines.push("GGML_USE_OPENBLAS"); + cfg.libraries.push('"-framework Foundation"'); + cfg.libraries.push('"-framework Accelerate"'); + cfg.libraries.push('"-framework Metal"'); + cfg.libraries.push('"-framework MetalKit"'); + break; + } + case "openblas": { + cfg.defines.push("GGML_USE_OPENBLAS"); - cfg.libraries.push("-lopenblas"); - break; - } - default: { - } - } + cfg.libraries.push("-lopenblas"); + break; + } + default: { + } + } - return cfg; + return cfg; } function infer_backend(): ComputeBackend { - let backend: ComputeBackend = "cpu"; + let backend: ComputeBackend = "cpu"; - try { - if (os.platform() === "darwin") { - backend = "accelerate"; - if (os.arch() === "arm64") { - backend = "metal"; - } - } else if (os.platform() === "linux") { - const has_libopenblas = !!execSync("ldconfig -p | grep libopenblas").toString().trim(); - if (has_libopenblas) { - backend = "openblas"; - } - } - } catch { - // if anything goes wrong, just use the default cpu backend - } + try { + if (os.platform() === "darwin") { + backend = "accelerate"; + if (os.arch() === "arm64") { + backend = "metal"; + } + } else if (os.platform() === "linux") { + const has_libopenblas = !!execSync("ldconfig -p | grep libopenblas") + .toString() + .trim(); + if (has_libopenblas) { + backend = "openblas"; + } + } + } catch { + // if anything goes wrong, just use the default cpu backend + } - return backend; + return backend; } diff --git a/packages/smart-whisper/src/model-manager/index.ts b/packages/smart-whisper/src/model-manager/index.ts index 03340e0..d0dc41a 100644 --- a/packages/smart-whisper/src/model-manager/index.ts +++ b/packages/smart-whisper/src/model-manager/index.ts @@ -10,7 +10,8 @@ const ext = ".bin"; fs.mkdirSync(models, { recursive: true }); -const BASE_MODELS_URL = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main"; +const BASE_MODELS_URL = + "https://huggingface.co/ggerganov/whisper.cpp/resolve/main"; /** * MODELS is an object that contains the URLs of different ggml whisper models. @@ -18,18 +19,18 @@ const BASE_MODELS_URL = "https://huggingface.co/ggerganov/whisper.cpp/resolve/ma * and the value is the URL of the model. */ export const MODELS = { - tiny: `${BASE_MODELS_URL}/ggml-tiny.bin`, - "tiny.en": `${BASE_MODELS_URL}/ggml-tiny.en.bin`, - small: `${BASE_MODELS_URL}/ggml-small.bin`, - "small.en": `${BASE_MODELS_URL}/ggml-small.en.bin`, - base: `${BASE_MODELS_URL}/ggml-base.bin`, - "base.en": `${BASE_MODELS_URL}/ggml-base.en.bin`, - medium: `${BASE_MODELS_URL}/ggml-medium.bin`, - "medium.en": `${BASE_MODELS_URL}/ggml-medium.en.bin`, - "large-v1": `${BASE_MODELS_URL}/ggml-large-v1.bin`, - "large-v2": `${BASE_MODELS_URL}/ggml-large-v2.bin`, - "large-v3": `${BASE_MODELS_URL}/ggml-large-v3.bin`, - "large-v3-turbo": `${BASE_MODELS_URL}/ggml-large-v3-turbo.bin`, + tiny: `${BASE_MODELS_URL}/ggml-tiny.bin`, + "tiny.en": `${BASE_MODELS_URL}/ggml-tiny.en.bin`, + small: `${BASE_MODELS_URL}/ggml-small.bin`, + "small.en": `${BASE_MODELS_URL}/ggml-small.en.bin`, + base: `${BASE_MODELS_URL}/ggml-base.bin`, + "base.en": `${BASE_MODELS_URL}/ggml-base.en.bin`, + medium: `${BASE_MODELS_URL}/ggml-medium.bin`, + "medium.en": `${BASE_MODELS_URL}/ggml-medium.en.bin`, + "large-v1": `${BASE_MODELS_URL}/ggml-large-v1.bin`, + "large-v2": `${BASE_MODELS_URL}/ggml-large-v2.bin`, + "large-v3": `${BASE_MODELS_URL}/ggml-large-v3.bin`, + "large-v3-turbo": `${BASE_MODELS_URL}/ggml-large-v3-turbo.bin`, } as const; export type ModelName = keyof typeof MODELS | (string & {}); @@ -42,39 +43,41 @@ export type ModelName = keyof typeof MODELS | (string & {}); * @throws An error if the model URL or shorthand is invalid, or if the model fails to download. */ export async function download(model: ModelName): Promise { - let url = "", - name = ""; - if (model in MODELS) { - url = MODELS[model as keyof typeof MODELS]; - name = model; - } else { - try { - url = new URL(model).href; - name = new URL(url).pathname.split("/").pop() ?? ""; - } catch {} - } + let url = "", + name = ""; + if (model in MODELS) { + url = MODELS[model as keyof typeof MODELS]; + name = model; + } else { + try { + url = new URL(model).href; + name = new URL(url).pathname.split("/").pop() ?? ""; + } catch {} + } - if (!url) { - throw new Error(`Invalid model URL or shorthand: ${model}`); - } + if (!url) { + throw new Error(`Invalid model URL or shorthand: ${model}`); + } - if (!name) { - throw new Error(`Failed to parse model name: ${url}`); - } + if (!name) { + throw new Error(`Failed to parse model name: ${url}`); + } - if (check(name)) { - return name; - } + if (check(name)) { + return name; + } - const res = await fetch(url); - if (!res.ok || !res.body) { - throw new Error(`Failed to download model: ${res.statusText}`); - } + const res = await fetch(url); + if (!res.ok || !res.body) { + throw new Error(`Failed to download model: ${res.statusText}`); + } - const stream = fs.createWriteStream(path.join(models, name.endsWith(ext) ? name : name + ext)); - Readable.fromWeb(res.body as ReadableStream).pipe(stream); + const stream = fs.createWriteStream( + path.join(models, name.endsWith(ext) ? name : name + ext), + ); + Readable.fromWeb(res.body as ReadableStream).pipe(stream); - return new Promise((resolve) => stream.on("finish", () => resolve(name))); + return new Promise((resolve) => stream.on("finish", () => resolve(name))); } /** @@ -82,9 +85,9 @@ export async function download(model: ModelName): Promise { * @param model - The name of the model to remove. */ export function remove(model: ModelName): void { - if (check(model)) { - fs.unlinkSync(path.join(models, model + ext)); - } + if (check(model)) { + fs.unlinkSync(path.join(models, model + ext)); + } } /** @@ -92,8 +95,8 @@ export function remove(model: ModelName): void { * @returns An array of model names. */ export function list(): ModelName[] { - const files = fs.readdirSync(models).filter((file) => file.endsWith(ext)); - return files.map((file) => file.slice(0, -ext.length)); + const files = fs.readdirSync(models).filter((file) => file.endsWith(ext)); + return files.map((file) => file.slice(0, -ext.length)); } /** @@ -102,7 +105,7 @@ export function list(): ModelName[] { * @returns True if the model exists, false otherwise. */ export function check(model: ModelName): boolean { - return fs.existsSync(path.join(models, model + ext)); + return fs.existsSync(path.join(models, model + ext)); } /** @@ -112,11 +115,11 @@ export function check(model: ModelName): boolean { * @throws Error if the model is not found. */ export function resolve(model: ModelName): string { - if (check(model)) { - return path.join(models, model + ext); - } else { - throw new Error(`Model not found: ${model}`); - } + if (check(model)) { + return path.join(models, model + ext); + } else { + throw new Error(`Model not found: ${model}`); + } } export const dir = { root, models }; diff --git a/packages/smart-whisper/src/transcribe.ts b/packages/smart-whisper/src/transcribe.ts index a166a86..376e2f5 100644 --- a/packages/smart-whisper/src/transcribe.ts +++ b/packages/smart-whisper/src/transcribe.ts @@ -4,107 +4,111 @@ import { TranscribeFormat, TranscribeParams, TranscribeResult } from "./types"; import { binding } from "./binding"; export class TranscribeTask< - Format extends TranscribeFormat, - TokenTimestamp extends boolean, + Format extends TranscribeFormat, + TokenTimestamp extends boolean, > extends EventEmitter { - private _model: WhisperModel; - private _result: Promise[]> | null = null; + private _model: WhisperModel; + private _result: Promise[]> | null = + null; - /** - * You should not construct this class directly, use {@link TranscribeTask.run} instead. - */ - constructor(model: WhisperModel) { - super(); - this._model = model; - } + /** + * You should not construct this class directly, use {@link TranscribeTask.run} instead. + */ + constructor(model: WhisperModel) { + super(); + this._model = model; + } - get model(): WhisperModel { - return this._model; - } + get model(): WhisperModel { + return this._model; + } - /** - * A promise that resolves to the result of the transcription task. - */ - get result(): Promise[]> { - if (this._result === null) { - throw new Error("Task has not been started"); - } - return this._result; - } + /** + * A promise that resolves to the result of the transcription task. + */ + get result(): Promise[]> { + if (this._result === null) { + throw new Error("Task has not been started"); + } + return this._result; + } - private async _run( - pcm: Float32Array, - params: Partial>, - ): Promise[]> { - return new Promise((resolve) => { - const handle = this.model.handle; - if (!handle) { - throw new Error("Model has been freed"); - } + private async _run( + pcm: Float32Array, + params: Partial>, + ): Promise[]> { + return new Promise((resolve) => { + const handle = this.model.handle; + if (!handle) { + throw new Error("Model has been freed"); + } - binding.transcribe( - handle, - pcm, - params, - (results) => { - this.emit("finish"); - resolve(results); - }, - (result) => { - this.emit("transcribed", result); - }, - ); - }); - } + binding.transcribe( + handle, + pcm, + params, + (results) => { + this.emit("finish"); + resolve(results); + }, + (result) => { + this.emit("transcribed", result); + }, + ); + }); + } - static async run( - model: WhisperModel, - pcm: Float32Array, - params: Partial>, - ): Promise> { - if (model.freed) { - throw new Error("Model has been freed"); - } + static async run< + Format extends TranscribeFormat, + TokenTimestamp extends boolean, + >( + model: WhisperModel, + pcm: Float32Array, + params: Partial>, + ): Promise> { + if (model.freed) { + throw new Error("Model has been freed"); + } - const task = new TranscribeTask(model); - task._result = task._run(pcm, params); + const task = new TranscribeTask(model); + task._result = task._run(pcm, params); - return task; - } + return task; + } - on( - event: "finish", - listener: (results: TranscribeResult[]) => void, - ): this; - on( - event: "transcribed", - listener: (result: TranscribeResult) => void, - ): this; - on(event: string, listener: (...args: any[]) => void): this { - return super.on(event, listener); - } + on( + event: "finish", + listener: (results: TranscribeResult[]) => void, + ): this; + on( + event: "transcribed", + listener: (result: TranscribeResult) => void, + ): this; + on(event: string, listener: (...args: any[]) => void): this { + return super.on(event, listener); + } - once( - event: "finish", - listener: (results: TranscribeResult[]) => void, - ): this; - once( - event: "transcribed", - listener: (result: TranscribeResult) => void, - ): this; - once(event: string, listener: (...args: any[]) => void): this { - return super.once(event, listener); - } + once( + event: "finish", + listener: (results: TranscribeResult[]) => void, + ): this; + once( + event: "transcribed", + listener: (result: TranscribeResult) => void, + ): this; + once(event: string, listener: (...args: any[]) => void): this { + return super.once(event, listener); + } - off( - event: "finish", - listener: (results: TranscribeResult[]) => void, - ): this; - off( - event: "transcribed", - listener: (result: TranscribeResult) => void, - ): this; - off(event: string, listener: (...args: any[]) => void): this { - return super.off(event, listener); - } + off( + event: "finish", + listener: (results: TranscribeResult[]) => void, + ): this; + off( + event: "transcribed", + listener: (result: TranscribeResult) => void, + ): this; + off(event: string, listener: (...args: any[]) => void): this { + return super.off(event, listener); + } } diff --git a/packages/smart-whisper/src/types.ts b/packages/smart-whisper/src/types.ts index 1cf96bf..a1c627d 100644 --- a/packages/smart-whisper/src/types.ts +++ b/packages/smart-whisper/src/types.ts @@ -1,6 +1,6 @@ export enum WhisperSamplingStrategy { - WHISPER_SAMPLING_GREEDY, - WHISPER_SAMPLING_BEAM_SEARCH, + WHISPER_SAMPLING_GREEDY, + WHISPER_SAMPLING_BEAM_SEARCH, } export type TranscribeFormat = "simple" | "detail"; @@ -9,94 +9,96 @@ export type TranscribeFormat = "simple" | "detail"; * See {@link https://github.com/ggerganov/whisper.cpp/blob/00b7a4be02ca82d53ac69dd2dd438c16e2af7658/whisper.h#L433C19-L433C19} for details. */ export interface TranscribeParams< - Format extends TranscribeFormat = TranscribeFormat, - TokenTimestamp extends boolean = false, + Format extends TranscribeFormat = TranscribeFormat, + TokenTimestamp extends boolean = false, > { - strategy: WhisperSamplingStrategy; - n_threads: number; - n_max_text_ctx: number; - offset_ms: number; - duration_ms: number; + strategy: WhisperSamplingStrategy; + n_threads: number; + n_max_text_ctx: number; + offset_ms: number; + duration_ms: number; - translate: boolean; - no_context: boolean; - no_timestamps: boolean; - single_segment: boolean; - print_special: boolean; - print_progress: boolean; - print_realtime: boolean; - print_timestamps: boolean; + translate: boolean; + no_context: boolean; + no_timestamps: boolean; + single_segment: boolean; + print_special: boolean; + print_progress: boolean; + print_realtime: boolean; + print_timestamps: boolean; - token_timestamps: TokenTimestamp; - thold_pt: number; - thold_ptsum: number; - max_len: number; - split_on_word: boolean; - max_tokens: number; + token_timestamps: TokenTimestamp; + thold_pt: number; + thold_ptsum: number; + max_len: number; + split_on_word: boolean; + max_tokens: number; - speed_up: boolean; - debug_mode: boolean; - audio_ctx: number; + speed_up: boolean; + debug_mode: boolean; + audio_ctx: number; - tdrz_enable: boolean; + tdrz_enable: boolean; - initial_prompt: string; + initial_prompt: string; - /** - * Language code, e.g. "en", "de", "fr", "es", "it", "nl", "pt", "ru", "tr", "uk", "pl", "sv", "cs", "zh", "ja", "ko" - */ - language: string; + /** + * Language code, e.g. "en", "de", "fr", "es", "it", "nl", "pt", "ru", "tr", "uk", "pl", "sv", "cs", "zh", "ja", "ko" + */ + language: string; - suppress_blank: boolean; - suppress_non_speech_tokens: boolean; + suppress_blank: boolean; + suppress_non_speech_tokens: boolean; - temperature: number; - max_initial_ts: number; - length_penalty: number; + temperature: number; + max_initial_ts: number; + length_penalty: number; - temperature_inc: number; - entropy_thold: number; - logprob_thold: number; - no_speech_thold: number; + temperature_inc: number; + entropy_thold: number; + logprob_thold: number; + no_speech_thold: number; - best_of: number; + best_of: number; - beam_size: number; + beam_size: number; - format: Format; + format: Format; } export interface TranscribeSimpleResult { - from: number; - to: number; - text: string; + from: number; + to: number; + text: string; } /** * Represents a detailed result of transcription. */ export interface TranscribeDetailedResult - extends TranscribeSimpleResult { - /** The detected spoken language. */ - lang: string; - /** The confidence level of the transcription, calculated by the average probability of the tokens. */ - confidence: number; - /** The tokens generated during the transcription process. */ - tokens: { - /** The text of the token, for CJK languages, due to the BPE encoding, the token text may not be readable. */ - text: string; - /** The ID of the token. */ - id: number; - /** The probability of the token. */ - p: number; - /** The start timestamp of the token, in milliseconds. Only available when `token_timestamps` of {@link TranscribeParams} is `true`. */ - from: TokenTimestamp extends true ? number : undefined; - /** The end timestamp of the token, in milliseconds. Only available when `token_timestamps` of {@link TranscribeParams} is `true`. */ - to: TokenTimestamp extends true ? number : undefined; - }[]; + extends TranscribeSimpleResult { + /** The detected spoken language. */ + lang: string; + /** The confidence level of the transcription, calculated by the average probability of the tokens. */ + confidence: number; + /** The tokens generated during the transcription process. */ + tokens: { + /** The text of the token, for CJK languages, due to the BPE encoding, the token text may not be readable. */ + text: string; + /** The ID of the token. */ + id: number; + /** The probability of the token. */ + p: number; + /** The start timestamp of the token, in milliseconds. Only available when `token_timestamps` of {@link TranscribeParams} is `true`. */ + from: TokenTimestamp extends true ? number : undefined; + /** The end timestamp of the token, in milliseconds. Only available when `token_timestamps` of {@link TranscribeParams} is `true`. */ + to: TokenTimestamp extends true ? number : undefined; + }[]; } export type TranscribeResult< - Format extends TranscribeFormat = TranscribeFormat, - TokenTimestamp extends boolean = boolean, -> = Format extends "simple" ? TranscribeSimpleResult : TranscribeDetailedResult; + Format extends TranscribeFormat = TranscribeFormat, + TokenTimestamp extends boolean = boolean, +> = Format extends "simple" + ? TranscribeSimpleResult + : TranscribeDetailedResult; diff --git a/packages/smart-whisper/src/whisper.ts b/packages/smart-whisper/src/whisper.ts index 4119be9..fad4071 100644 --- a/packages/smart-whisper/src/whisper.ts +++ b/packages/smart-whisper/src/whisper.ts @@ -1,17 +1,21 @@ -import type { TranscribeFormat, TranscribeParams, TranscribeResult } from "./types"; +import type { + TranscribeFormat, + TranscribeParams, + TranscribeResult, +} from "./types"; import { WhisperModel } from "./model"; import { TranscribeTask } from "./transcribe"; export interface WhisperConfig { - /** - * Time in seconds to wait before offloading the model if it's not being used. - */ - offload: number; + /** + * Time in seconds to wait before offloading the model if it's not being used. + */ + offload: number; - /** - * Whether to use the GPU or not. - */ - gpu: boolean; + /** + * Whether to use the GPU or not. + */ + gpu: boolean; } /** @@ -19,112 +23,119 @@ export interface WhisperConfig { * It handles the loading and offloading of the model, managing transcription tasks, and configuring model parameters. */ export class Whisper { - private _file: string; - private _available: WhisperModel | null = null; - private _loading: Promise | null = null; - private _tasks: Promise[] = []; - private _config: WhisperConfig; - private _offload_timer: NodeJS.Timeout | null = null; + private _file: string; + private _available: WhisperModel | null = null; + private _loading: Promise | null = null; + private _tasks: Promise[] = []; + private _config: WhisperConfig; + private _offload_timer: NodeJS.Timeout | null = null; - /** - * Constructs a new Whisper instance with a specified model file and configuration. - * @param file - The path to the Whisper model file. - * @param config - Optional configuration for the Whisper instance. - */ - constructor(file: string, config: Partial = {}) { - this._file = file; - this._config = { - offload: 300, - gpu: true, - ...config, - }; - } + /** + * Constructs a new Whisper instance with a specified model file and configuration. + * @param file - The path to the Whisper model file. + * @param config - Optional configuration for the Whisper instance. + */ + constructor(file: string, config: Partial = {}) { + this._file = file; + this._config = { + offload: 300, + gpu: true, + ...config, + }; + } - get file(): string { - return this._file; - } + get file(): string { + return this._file; + } - set file(file: string) { - this._file = file; - } + set file(file: string) { + this._file = file; + } - get config(): WhisperConfig { - return this._config; - } + get config(): WhisperConfig { + return this._config; + } - get tasks(): Promise[] { - return this._tasks; - } + get tasks(): Promise[] { + return this._tasks; + } - reset_offload_timer(): void { - this.clear_offload_timer(); - this._offload_timer = setTimeout(() => { - this.free(); - }, this.config.offload * 1000); - } + reset_offload_timer(): void { + this.clear_offload_timer(); + this._offload_timer = setTimeout(() => { + this.free(); + }, this.config.offload * 1000); + } - private clear_offload_timer(): void { - if (this._offload_timer !== null) { - clearTimeout(this._offload_timer); - this._offload_timer = null; - } - } + private clear_offload_timer(): void { + if (this._offload_timer !== null) { + clearTimeout(this._offload_timer); + this._offload_timer = null; + } + } - async model(): Promise { - if (this._available === null) { - return this.load(); - } - this.reset_offload_timer(); - return Promise.resolve(this._available); - } + async model(): Promise { + if (this._available === null) { + return this.load(); + } + this.reset_offload_timer(); + return Promise.resolve(this._available); + } - /** - * Loads the whisper model asynchronously. - * If the model is already being loaded, returns the existing one. - * - * You don't need to call this method directly, it's called automatically if necessary when you call {@link Whisper.transcribe}. - * - * @returns A Promise that resolves to the loaded model. - */ - async load(): Promise { - if (this._loading !== null) { - return this._loading; - } + /** + * Loads the whisper model asynchronously. + * If the model is already being loaded, returns the existing one. + * + * You don't need to call this method directly, it's called automatically if necessary when you call {@link Whisper.transcribe}. + * + * @returns A Promise that resolves to the loaded model. + */ + async load(): Promise { + if (this._loading !== null) { + return this._loading; + } - const model = WhisperModel.load(this.file, this.config.gpu); - this._loading = model; - this._available = await model; - this._loading = null; - this.reset_offload_timer(); - return this._available; - } + const model = WhisperModel.load(this.file, this.config.gpu); + this._loading = model; + this._available = await model; + this._loading = null; + this.reset_offload_timer(); + return this._available; + } - /** - * Transcribes the given PCM audio data using the Whisper model. - * @param pcm - The mono 16k PCM audio data to transcribe. - * @param params - Optional parameters for transcription. - * @returns A promise that resolves to the result of the transcription task. - */ - async transcribe( - pcm: Float32Array, - params: Partial> = {}, - ): Promise> { - const model = await this.model(); - const task = await TranscribeTask.run(model, pcm, params); - this._tasks.push(task.result); - return task; - } + /** + * Transcribes the given PCM audio data using the Whisper model. + * @param pcm - The mono 16k PCM audio data to transcribe. + * @param params - Optional parameters for transcription. + * @returns A promise that resolves to the result of the transcription task. + */ + async transcribe< + Format extends TranscribeFormat, + TokenTimestamp extends boolean, + >( + pcm: Float32Array, + params: Partial> = {}, + ): Promise> { + const model = await this.model(); + const task = await TranscribeTask.run( + model, + pcm, + params, + ); + this._tasks.push(task.result); + return task; + } - async free(): Promise { - if (this._available === null) { - return; - } - const model = this._available; - this._available = null; - this.clear_offload_timer(); - await Promise.all(this.tasks); - await model.free(); - } + async free(): Promise { + if (this._available === null) { + return; + } + const model = this._available; + this._available = null; + this.clear_offload_timer(); + await Promise.all(this.tasks); + await model.free(); + } } /** diff --git a/packages/smart-whisper/tsup.config.ts b/packages/smart-whisper/tsup.config.ts index d76bcf2..9f67551 100644 --- a/packages/smart-whisper/tsup.config.ts +++ b/packages/smart-whisper/tsup.config.ts @@ -2,14 +2,17 @@ import { defineConfig } from "tsup"; import { readFileSync, writeFileSync } from "node:fs"; export default defineConfig({ - entry: ["src/index.ts", "src/build.ts"], - outDir: "dist", - dts: true, - async onSuccess() { - // replace `#include "ggml-common.h" in whisper.cpp/ggml/src/ggml-metal.metal with full content - const metal = readFileSync("whisper.cpp/ggml/src/ggml-metal.metal", "utf-8"); - const common = readFileSync("whisper.cpp/ggml/src/ggml-common.h", "utf-8"); - const replaced = metal.replace(/#include "ggml-common.h"/, common); - writeFileSync("whisper.cpp/ggml/src/ggml-metal.metal", replaced); - }, + entry: ["src/index.ts", "src/build.ts"], + outDir: "dist", + dts: true, + async onSuccess() { + // replace `#include "ggml-common.h" in whisper.cpp/ggml/src/ggml-metal.metal with full content + const metal = readFileSync( + "whisper.cpp/ggml/src/ggml-metal.metal", + "utf-8", + ); + const common = readFileSync("whisper.cpp/ggml/src/ggml-common.h", "utf-8"); + const replaced = metal.replace(/#include "ggml-common.h"/, common); + writeFileSync("whisper.cpp/ggml/src/ggml-metal.metal", replaced); + }, }); diff --git a/turbo.json b/turbo.json index 4aecadf..93bad84 100644 --- a/turbo.json +++ b/turbo.json @@ -12,7 +12,14 @@ "build": { "dependsOn": ["^build"], "inputs": ["$TURBO_DEFAULT$", ".env*"], - "outputs": [".next/**", "!.next/cache/**", "bin/**", "out/**", "dist/**", "build/**"], + "outputs": [ + ".next/**", + "!.next/cache/**", + "bin/**", + "out/**", + "dist/**", + "build/**" + ], "env": [ "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", @@ -23,7 +30,15 @@ }, "build:native": { "dependsOn": [], - "inputs": ["Sources/**", "Package.swift", "main.swift", "scripts/**", "src/binding/**", "binding.gyp", "whisper.cpp/**"], + "inputs": [ + "Sources/**", + "Package.swift", + "main.swift", + "scripts/**", + "src/binding/**", + "binding.gyp", + "whisper.cpp/**" + ], "outputs": ["bin/**", "build/**"], "cache": true },