fix: wire up transcription playback

This commit is contained in:
nchopra 2025-08-26 01:29:35 +05:30
parent 17d034be80
commit e7d9e91bf4
13 changed files with 750 additions and 511 deletions

View file

@ -149,7 +149,11 @@ const config: ForgeConfig = {
// Copy the package
console.log(`Copying ${dep}...`);
cpSync(rootDepPath, localDepPath, { recursive: true, dereference: true, force: true });
cpSync(rootDepPath, localDepPath, {
recursive: true,
dereference: true,
force: true,
});
console.log(`✓ Successfully copied ${dep}`);
} catch (error) {
console.error(`Failed to copy ${dep}:`, error);
@ -160,31 +164,35 @@ const config: ForgeConfig = {
console.log("Checking for symlinks in copied dependencies...");
for (const dep of nativeModuleDependenciesToPackage) {
const localDepPath = join(localNodeModules, dep);
try {
if (existsSync(localDepPath)) {
const stats = lstatSync(localDepPath);
if (stats.isSymbolicLink()) {
console.log(`Found symlink for ${dep}, replacing with dereferenced copy...`);
console.log(
`Found symlink for ${dep}, replacing with dereferenced copy...`,
);
// Read where the symlink points to
const symlinkTarget = readlinkSync(localDepPath);
const absoluteTarget = join(localDepPath, "..", symlinkTarget);
const sourcePath = normalize(absoluteTarget);
console.log(` Symlink points to: ${sourcePath}`);
// Remove the symlink
rmSync(localDepPath, { recursive: true, force: true });
// Copy with dereference to get actual content
cpSync(sourcePath, localDepPath, {
recursive: true,
cpSync(sourcePath, localDepPath, {
recursive: true,
force: true,
dereference: true // Follow symlinks and copy actual content
dereference: true, // Follow symlinks and copy actual content
});
console.log(`✓ Successfully replaced symlink for ${dep} with actual content`);
console.log(
`✓ Successfully replaced symlink for ${dep} with actual content`,
);
}
}
} catch (error) {
@ -404,8 +412,10 @@ const config: ForgeConfig = {
const scopeDir = dep.split("/")[0]; // @libsql/client -> @libsql
// for workspace packages only keep the actual package
if (scopeDir === "@amical") {
if (filePath.startsWith(`/node_modules/${dep}`) ||
filePath === `/node_modules/${scopeDir}`) {
if (
filePath.startsWith(`/node_modules/${dep}`) ||
filePath === `/node_modules/${scopeDir}`
) {
KEEP_FILE.keep = true;
KEEP_FILE.log = true;
}

View file

@ -0,0 +1,111 @@
import { useRef, useState, useCallback, useEffect } from "react";
interface UseAudioPlayerReturn {
isPlaying: boolean;
currentPlayingId: number | null;
play: (
audioData: ArrayBuffer,
transcriptionId: number,
mimeType?: string,
) => void;
pause: () => void;
stop: () => void;
toggle: (
audioData: ArrayBuffer,
transcriptionId: number,
mimeType?: string,
) => void;
}
export function useAudioPlayer(): UseAudioPlayerReturn {
const audioRef = useRef<HTMLAudioElement | null>(null);
const currentBlobUrlRef = useRef<string | null>(null);
const [isPlaying, setIsPlaying] = useState(false);
const [currentPlayingId, setCurrentPlayingId] = useState<number | null>(null);
const cleanup = useCallback(() => {
if (audioRef.current) {
audioRef.current.pause();
audioRef.current.src = "";
}
if (currentBlobUrlRef.current) {
URL.revokeObjectURL(currentBlobUrlRef.current);
currentBlobUrlRef.current = null;
}
setIsPlaying(false);
setCurrentPlayingId(null);
}, []);
const play = useCallback(
(
audioData: ArrayBuffer,
transcriptionId: number,
mimeType: string = "audio/wav",
) => {
cleanup();
const blob = new Blob([audioData], { type: mimeType });
const blobUrl = URL.createObjectURL(blob);
currentBlobUrlRef.current = blobUrl;
if (!audioRef.current) {
audioRef.current = new Audio();
}
audioRef.current.src = blobUrl;
audioRef.current.onended = () => {
setIsPlaying(false);
setCurrentPlayingId(null);
};
audioRef.current
.play()
.then(() => {
setIsPlaying(true);
setCurrentPlayingId(transcriptionId);
})
.catch((error) => {
console.error("Failed to play audio:", error);
cleanup();
});
},
[cleanup],
);
const pause = useCallback(() => {
if (audioRef.current && !audioRef.current.paused) {
audioRef.current.pause();
setIsPlaying(false);
}
}, []);
const stop = useCallback(() => {
cleanup();
}, [cleanup]);
const toggle = useCallback(
(audioData: ArrayBuffer, transcriptionId: number, mimeType?: string) => {
if (currentPlayingId === transcriptionId && isPlaying) {
pause();
} else {
play(audioData, transcriptionId, mimeType);
}
},
[currentPlayingId, isPlaying, pause, play],
);
useEffect(() => {
return () => {
cleanup();
};
}, [cleanup]);
return {
isPlaying,
currentPlayingId,
play,
pause,
stop,
toggle,
};
}

View file

@ -4,6 +4,7 @@ import { Button } from "@/components/ui/button";
import { Card, CardContent } from "@/components/ui/card";
import { Badge } from "@/components/ui/badge";
import { api } from "@/trpc/react";
import { useAudioPlayer } from "@/hooks/useAudioPlayer";
import {
Tooltip,
@ -14,8 +15,8 @@ import {
import {
Copy,
Play,
Pause,
Trash2,
Download,
FileText,
Search,
MoreHorizontal,
@ -35,6 +36,7 @@ import {
export const TranscriptionsList: React.FC = () => {
const [searchTerm, setSearchTerm] = useState("");
const [openDropdownId, setOpenDropdownId] = useState<number | null>(null);
const audioPlayer = useAudioPlayer();
// Get shortcuts data
const shortcutsQuery = api.settings.getShortcuts.useQuery();
@ -85,6 +87,34 @@ export const TranscriptionsList: React.FC = () => {
},
});
// Using mutation for fetching audio data instead of query to:
// - Prevent caching of large binary audio files in memory
// - Avoid automatic refetching behaviors (window focus, network reconnect)
// - Clearly indicate this is a user-triggered action (play button click)
// - Track loading state per transcription ID efficiently
const getAudioFileMutation = api.transcriptions.getAudioFile.useMutation({
onSuccess: (data, variables) => {
if (data?.data) {
// Decode base64 to ArrayBuffer
const base64 = data.data;
const binaryString = atob(base64);
const bytes = new Uint8Array(binaryString.length);
for (let i = 0; i < binaryString.length; i++) {
bytes[i] = binaryString.charCodeAt(i);
}
// Pass the MIME type from the server response
audioPlayer.toggle(
bytes.buffer,
variables.transcriptionId,
data.mimeType,
);
}
},
onError: (error) => {
console.error("Error fetching audio file:", error);
},
});
const transcriptions = transcriptionsQuery.data || [];
const totalCount = transcriptionsCountQuery.data || 0;
const loading =
@ -103,9 +133,15 @@ export const TranscriptionsList: React.FC = () => {
deleteTranscriptionMutation.mutate({ id });
};
const handlePlayAudio = (audioFile: string) => {
// Implement audio playback functionality
console.log("Playing audio:", audioFile);
const handlePlayAudio = (transcriptionId: number) => {
if (
audioPlayer.currentPlayingId === transcriptionId &&
audioPlayer.isPlaying
) {
audioPlayer.stop();
} else {
getAudioFileMutation.mutate({ transcriptionId });
}
};
const handleDownloadAudio = async (transcriptionId: number) => {
@ -219,12 +255,27 @@ export const TranscriptionsList: React.FC = () => {
variant="ghost"
size="sm"
className="h-8 w-8 p-0"
onClick={() => handlePlayAudio(transcription.audioFile!)}
onClick={() => handlePlayAudio(transcription.id)}
disabled={
getAudioFileMutation.isPending &&
getAudioFileMutation.variables?.transcriptionId ===
transcription.id
}
>
<Play className="h-4 w-4" />
{audioPlayer.currentPlayingId === transcription.id &&
audioPlayer.isPlaying ? (
<Pause className="h-4 w-4" />
) : (
<Play className="h-4 w-4" />
)}
</Button>
</TooltipTrigger>
<TooltipContent>Play audio</TooltipContent>
<TooltipContent>
{audioPlayer.currentPlayingId === transcription.id &&
audioPlayer.isPlaying
? "Pause audio"
: "Play audio"}
</TooltipContent>
</Tooltip>
</TooltipProvider>
)}

View file

@ -120,10 +120,15 @@ export const transcriptionsRouter = createRouter({
return result;
}),
// Get audio file for download
// Get audio file for playback
// Implemented as mutation instead of query because:
// 1. Large binary data (audio files) shouldn't be cached by React Query
// 2. Prevents automatic refetching on window focus/network reconnect
// 3. Represents an explicit user action (clicking play), not passive data fetching
// 4. Avoids memory overhead from React Query's caching system
getAudioFile: procedure
.input(z.object({ transcriptionId: z.number() }))
.query(async ({ input, ctx }) => {
.mutation(async ({ input, ctx }) => {
const transcription = await getTranscriptionById(input.transcriptionId);
if (!transcription?.audioFile) {
@ -138,10 +143,28 @@ export const transcriptionsRouter = createRouter({
const audioData = await fs.promises.readFile(transcription.audioFile);
const filename = path.basename(transcription.audioFile);
// Detect MIME type based on file extension
const ext = path.extname(transcription.audioFile).toLowerCase();
let mimeType = "audio/wav"; // Default for our WAV files
// Map common audio extensions to MIME types
const mimeTypes: Record<string, string> = {
".wav": "audio/wav",
".mp3": "audio/mpeg",
".webm": "audio/webm",
".ogg": "audio/ogg",
".m4a": "audio/mp4",
".flac": "audio/flac",
};
if (ext in mimeTypes) {
mimeType = mimeTypes[ext];
}
return {
data: audioData,
data: audioData.toString("base64"),
filename,
mimeType: "audio/webm",
mimeType,
};
} catch (error) {
const logger = ctx.serviceManager.getLogger();
@ -155,6 +178,8 @@ export const transcriptionsRouter = createRouter({
}),
// Download audio file with save dialog
// Mutation because this triggers a system dialog and file write operation
// Not a query since it has side effects beyond just fetching data
downloadAudioFile: procedure
.input(z.object({ transcriptionId: z.number() }))
.mutation(async ({ input, ctx }) => {

View file

@ -1,41 +1,41 @@
{
"name": "@amical/smart-whisper",
"version": "0.1.0",
"description": "Whisper.cpp Node.js binding with auto model offloading strategy.",
"main": "dist/index.js",
"types": "dist/index.d.ts",
"keywords": [
"whisper",
"whisper.cpp",
"native",
"binding",
"addon"
],
"gypfile": true,
"files": [
"dist",
"src",
"scripts",
"binding.gyp",
"whisper.cpp/**/*.{c,h,cpp,hpp,m,cu,metal}",
"whisper.cpp/Makefile",
"whisper.cpp/LICENSE"
],
"scripts": {
"install": "tsup",
"postinstall": "node-gyp rebuild",
"build": "tsup && node-gyp rebuild",
"build:ts": "tsup",
"build:native": "node-gyp rebuild"
},
"dependencies": {
"node-addon-api": "^8.5.0",
"minimatch": "10.0.3"
},
"devDependencies": {
"@amical/typescript-config": "workspace:*",
"@types/node": "^24.3.0",
"tsup": "^8.5.0",
"typescript": "^5.8.2"
}
"name": "@amical/smart-whisper",
"version": "0.1.0",
"description": "Whisper.cpp Node.js binding with auto model offloading strategy.",
"main": "dist/index.js",
"types": "dist/index.d.ts",
"keywords": [
"whisper",
"whisper.cpp",
"native",
"binding",
"addon"
],
"gypfile": true,
"files": [
"dist",
"src",
"scripts",
"binding.gyp",
"whisper.cpp/**/*.{c,h,cpp,hpp,m,cu,metal}",
"whisper.cpp/Makefile",
"whisper.cpp/LICENSE"
],
"scripts": {
"install": "tsup",
"postinstall": "node-gyp rebuild",
"build": "tsup && node-gyp rebuild",
"build:ts": "tsup",
"build:native": "node-gyp rebuild"
},
"dependencies": {
"node-addon-api": "^8.5.0",
"minimatch": "10.0.3"
},
"devDependencies": {
"@amical/typescript-config": "workspace:*",
"@types/node": "^24.3.0",
"tsup": "^8.5.0",
"typescript": "^5.8.2"
}
}

View file

@ -1,5 +1,6 @@
process.env.GGML_METAL_PATH_RESOURCES =
process.env.GGML_METAL_PATH_RESOURCES || path.join(__dirname, "../whisper.cpp/ggml/src");
process.env.GGML_METAL_PATH_RESOURCES ||
path.join(__dirname, "../whisper.cpp/ggml/src");
import path from "node:path";
import { TranscribeFormat, TranscribeParams, TranscribeResult } from "./types";
@ -9,66 +10,66 @@ const module = require(path.join(__dirname, "../build/Release/smart-whisper"));
* A external handle to a model.
*/
export type Handle = {
readonly "": unique symbol;
readonly "": unique symbol;
};
export namespace Binding {
/**
* Load a model from a whisper weights file.
* @param file The path to the whisper weights file.
* @param gpu Whether to use the GPU or not.
* @param callback A callback that will be called with the handle to the model.
*/
export declare function load(
file: string,
gpu: boolean,
callback: (handle: Handle) => void,
): void;
/**
* Load a model from a whisper weights file.
* @param file The path to the whisper weights file.
* @param gpu Whether to use the GPU or not.
* @param callback A callback that will be called with the handle to the model.
*/
export declare function load(
file: string,
gpu: boolean,
callback: (handle: Handle) => void,
): void;
/**
* Release the memory of the model, it will be unusable after this.
* @param handle The handle to the model.
* @param callback A callback that will be called when the model is freed.
*/
export declare function free(handle: Handle, callback: () => void): void;
/**
* Release the memory of the model, it will be unusable after this.
* @param handle The handle to the model.
* @param callback A callback that will be called when the model is freed.
*/
export declare function free(handle: Handle, callback: () => void): void;
/**
* Transcribe a PCM buffer.
* @param handle The handle to the model.
* @param pcm The PCM buffer.
* @param params The parameters to use for transcription.
* @param finish A callback that will be called when the transcription is finished.
* @param progress A callback that will be called when a new result is available.
*/
export declare function transcribe<
Format extends TranscribeFormat,
TokenTimestamp extends boolean,
>(
handle: Handle,
pcm: Float32Array,
params: Partial<TranscribeParams<Format, TokenTimestamp>>,
finish: (results: TranscribeResult<Format, TokenTimestamp>[]) => void,
progress: (result: TranscribeResult<Format, TokenTimestamp>) => void,
): void;
/**
* Transcribe a PCM buffer.
* @param handle The handle to the model.
* @param pcm The PCM buffer.
* @param params The parameters to use for transcription.
* @param finish A callback that will be called when the transcription is finished.
* @param progress A callback that will be called when a new result is available.
*/
export declare function transcribe<
Format extends TranscribeFormat,
TokenTimestamp extends boolean,
>(
handle: Handle,
pcm: Float32Array,
params: Partial<TranscribeParams<Format, TokenTimestamp>>,
finish: (results: TranscribeResult<Format, TokenTimestamp>[]) => void,
progress: (result: TranscribeResult<Format, TokenTimestamp>) => void,
): void;
export declare class WhisperModel {
private _ctx;
constructor(handle: Handle);
get handle(): Handle | null;
get freed(): boolean;
/**
* Release the memory of the model, it will be unusable after this.
* It's safe to call this multiple times, but it will only free the model once.
*/
free(): Promise<void>;
/**
* Load a model from a whisper weights file.
* @param file The path to the whisper weights file.
* @param gpu Whether to use the GPU or not.
* @returns A promise that resolves to a {@link WhisperModel}.
*/
static load(file: string, gpu?: boolean): Promise<WhisperModel>;
}
export declare class WhisperModel {
private _ctx;
constructor(handle: Handle);
get handle(): Handle | null;
get freed(): boolean;
/**
* Release the memory of the model, it will be unusable after this.
* It's safe to call this multiple times, but it will only free the model once.
*/
free(): Promise<void>;
/**
* Load a model from a whisper weights file.
* @param file The path to the whisper weights file.
* @param gpu Whether to use the GPU or not.
* @returns A promise that resolves to a {@link WhisperModel}.
*/
static load(file: string, gpu?: boolean): Promise<WhisperModel>;
}
}
/**

View file

@ -10,85 +10,88 @@ export const defines = cfg.defines.join(" ");
export const libraries = cfg.libraries.join(" ");
function config(): {
sources: string[];
defines: string[];
libraries: string[];
sources: string[];
defines: string[];
libraries: string[];
} {
if (process.env.BYOL) {
return {
sources: [],
defines: [],
libraries: [process.env.BYOL],
};
}
if (process.env.BYOL) {
return {
sources: [],
defines: [],
libraries: [process.env.BYOL],
};
}
const COMPUTE_BACKEND: ComputeBackend =
(process.env.COMPUTE_BACKEND as ComputeBackend | undefined) ?? infer_backend();
const COMPUTE_BACKEND: ComputeBackend =
(process.env.COMPUTE_BACKEND as ComputeBackend | undefined) ??
infer_backend();
const cfg = {
sources: [
"whisper.cpp/src/whisper.cpp",
"whisper.cpp/ggml/src/ggml.c",
"whisper.cpp/ggml/src/ggml-alloc.c",
"whisper.cpp/ggml/src/ggml-backend.c",
"whisper.cpp/ggml/src/ggml-quants.c",
"whisper.cpp/ggml/src/ggml-aarch64.c",
] as string[],
defines: [] as string[],
libraries: [] as string[],
};
const cfg = {
sources: [
"whisper.cpp/src/whisper.cpp",
"whisper.cpp/ggml/src/ggml.c",
"whisper.cpp/ggml/src/ggml-alloc.c",
"whisper.cpp/ggml/src/ggml-backend.c",
"whisper.cpp/ggml/src/ggml-quants.c",
"whisper.cpp/ggml/src/ggml-aarch64.c",
] as string[],
defines: [] as string[],
libraries: [] as string[],
};
switch (COMPUTE_BACKEND) {
case "accelerate": {
cfg.defines.push("GGML_USE_ACCELERATE");
switch (COMPUTE_BACKEND) {
case "accelerate": {
cfg.defines.push("GGML_USE_ACCELERATE");
cfg.libraries.push('"-framework Foundation"');
cfg.libraries.push('"-framework Accelerate"');
break;
}
case "metal": {
cfg.sources.push("whisper.cpp/ggml/src/ggml-metal.m");
cfg.libraries.push('"-framework Foundation"');
cfg.libraries.push('"-framework Accelerate"');
break;
}
case "metal": {
cfg.sources.push("whisper.cpp/ggml/src/ggml-metal.m");
cfg.defines.push("GGML_USE_ACCELERATE");
cfg.defines.push("GGML_USE_METAL");
cfg.defines.push("GGML_USE_ACCELERATE");
cfg.defines.push("GGML_USE_METAL");
cfg.libraries.push('"-framework Foundation"');
cfg.libraries.push('"-framework Accelerate"');
cfg.libraries.push('"-framework Metal"');
cfg.libraries.push('"-framework MetalKit"');
break;
}
case "openblas": {
cfg.defines.push("GGML_USE_OPENBLAS");
cfg.libraries.push('"-framework Foundation"');
cfg.libraries.push('"-framework Accelerate"');
cfg.libraries.push('"-framework Metal"');
cfg.libraries.push('"-framework MetalKit"');
break;
}
case "openblas": {
cfg.defines.push("GGML_USE_OPENBLAS");
cfg.libraries.push("-lopenblas");
break;
}
default: {
}
}
cfg.libraries.push("-lopenblas");
break;
}
default: {
}
}
return cfg;
return cfg;
}
function infer_backend(): ComputeBackend {
let backend: ComputeBackend = "cpu";
let backend: ComputeBackend = "cpu";
try {
if (os.platform() === "darwin") {
backend = "accelerate";
if (os.arch() === "arm64") {
backend = "metal";
}
} else if (os.platform() === "linux") {
const has_libopenblas = !!execSync("ldconfig -p | grep libopenblas").toString().trim();
if (has_libopenblas) {
backend = "openblas";
}
}
} catch {
// if anything goes wrong, just use the default cpu backend
}
try {
if (os.platform() === "darwin") {
backend = "accelerate";
if (os.arch() === "arm64") {
backend = "metal";
}
} else if (os.platform() === "linux") {
const has_libopenblas = !!execSync("ldconfig -p | grep libopenblas")
.toString()
.trim();
if (has_libopenblas) {
backend = "openblas";
}
}
} catch {
// if anything goes wrong, just use the default cpu backend
}
return backend;
return backend;
}

View file

@ -10,7 +10,8 @@ const ext = ".bin";
fs.mkdirSync(models, { recursive: true });
const BASE_MODELS_URL = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main";
const BASE_MODELS_URL =
"https://huggingface.co/ggerganov/whisper.cpp/resolve/main";
/**
* MODELS is an object that contains the URLs of different ggml whisper models.
@ -18,18 +19,18 @@ const BASE_MODELS_URL = "https://huggingface.co/ggerganov/whisper.cpp/resolve/ma
* and the value is the URL of the model.
*/
export const MODELS = {
tiny: `${BASE_MODELS_URL}/ggml-tiny.bin`,
"tiny.en": `${BASE_MODELS_URL}/ggml-tiny.en.bin`,
small: `${BASE_MODELS_URL}/ggml-small.bin`,
"small.en": `${BASE_MODELS_URL}/ggml-small.en.bin`,
base: `${BASE_MODELS_URL}/ggml-base.bin`,
"base.en": `${BASE_MODELS_URL}/ggml-base.en.bin`,
medium: `${BASE_MODELS_URL}/ggml-medium.bin`,
"medium.en": `${BASE_MODELS_URL}/ggml-medium.en.bin`,
"large-v1": `${BASE_MODELS_URL}/ggml-large-v1.bin`,
"large-v2": `${BASE_MODELS_URL}/ggml-large-v2.bin`,
"large-v3": `${BASE_MODELS_URL}/ggml-large-v3.bin`,
"large-v3-turbo": `${BASE_MODELS_URL}/ggml-large-v3-turbo.bin`,
tiny: `${BASE_MODELS_URL}/ggml-tiny.bin`,
"tiny.en": `${BASE_MODELS_URL}/ggml-tiny.en.bin`,
small: `${BASE_MODELS_URL}/ggml-small.bin`,
"small.en": `${BASE_MODELS_URL}/ggml-small.en.bin`,
base: `${BASE_MODELS_URL}/ggml-base.bin`,
"base.en": `${BASE_MODELS_URL}/ggml-base.en.bin`,
medium: `${BASE_MODELS_URL}/ggml-medium.bin`,
"medium.en": `${BASE_MODELS_URL}/ggml-medium.en.bin`,
"large-v1": `${BASE_MODELS_URL}/ggml-large-v1.bin`,
"large-v2": `${BASE_MODELS_URL}/ggml-large-v2.bin`,
"large-v3": `${BASE_MODELS_URL}/ggml-large-v3.bin`,
"large-v3-turbo": `${BASE_MODELS_URL}/ggml-large-v3-turbo.bin`,
} as const;
export type ModelName = keyof typeof MODELS | (string & {});
@ -42,39 +43,41 @@ export type ModelName = keyof typeof MODELS | (string & {});
* @throws An error if the model URL or shorthand is invalid, or if the model fails to download.
*/
export async function download(model: ModelName): Promise<string> {
let url = "",
name = "";
if (model in MODELS) {
url = MODELS[model as keyof typeof MODELS];
name = model;
} else {
try {
url = new URL(model).href;
name = new URL(url).pathname.split("/").pop() ?? "";
} catch {}
}
let url = "",
name = "";
if (model in MODELS) {
url = MODELS[model as keyof typeof MODELS];
name = model;
} else {
try {
url = new URL(model).href;
name = new URL(url).pathname.split("/").pop() ?? "";
} catch {}
}
if (!url) {
throw new Error(`Invalid model URL or shorthand: ${model}`);
}
if (!url) {
throw new Error(`Invalid model URL or shorthand: ${model}`);
}
if (!name) {
throw new Error(`Failed to parse model name: ${url}`);
}
if (!name) {
throw new Error(`Failed to parse model name: ${url}`);
}
if (check(name)) {
return name;
}
if (check(name)) {
return name;
}
const res = await fetch(url);
if (!res.ok || !res.body) {
throw new Error(`Failed to download model: ${res.statusText}`);
}
const res = await fetch(url);
if (!res.ok || !res.body) {
throw new Error(`Failed to download model: ${res.statusText}`);
}
const stream = fs.createWriteStream(path.join(models, name.endsWith(ext) ? name : name + ext));
Readable.fromWeb(res.body as ReadableStream<Uint8Array>).pipe(stream);
const stream = fs.createWriteStream(
path.join(models, name.endsWith(ext) ? name : name + ext),
);
Readable.fromWeb(res.body as ReadableStream<Uint8Array>).pipe(stream);
return new Promise((resolve) => stream.on("finish", () => resolve(name)));
return new Promise((resolve) => stream.on("finish", () => resolve(name)));
}
/**
@ -82,9 +85,9 @@ export async function download(model: ModelName): Promise<string> {
* @param model - The name of the model to remove.
*/
export function remove(model: ModelName): void {
if (check(model)) {
fs.unlinkSync(path.join(models, model + ext));
}
if (check(model)) {
fs.unlinkSync(path.join(models, model + ext));
}
}
/**
@ -92,8 +95,8 @@ export function remove(model: ModelName): void {
* @returns An array of model names.
*/
export function list(): ModelName[] {
const files = fs.readdirSync(models).filter((file) => file.endsWith(ext));
return files.map((file) => file.slice(0, -ext.length));
const files = fs.readdirSync(models).filter((file) => file.endsWith(ext));
return files.map((file) => file.slice(0, -ext.length));
}
/**
@ -102,7 +105,7 @@ export function list(): ModelName[] {
* @returns True if the model exists, false otherwise.
*/
export function check(model: ModelName): boolean {
return fs.existsSync(path.join(models, model + ext));
return fs.existsSync(path.join(models, model + ext));
}
/**
@ -112,11 +115,11 @@ export function check(model: ModelName): boolean {
* @throws Error if the model is not found.
*/
export function resolve(model: ModelName): string {
if (check(model)) {
return path.join(models, model + ext);
} else {
throw new Error(`Model not found: ${model}`);
}
if (check(model)) {
return path.join(models, model + ext);
} else {
throw new Error(`Model not found: ${model}`);
}
}
export const dir = { root, models };

View file

@ -4,107 +4,111 @@ import { TranscribeFormat, TranscribeParams, TranscribeResult } from "./types";
import { binding } from "./binding";
export class TranscribeTask<
Format extends TranscribeFormat,
TokenTimestamp extends boolean,
Format extends TranscribeFormat,
TokenTimestamp extends boolean,
> extends EventEmitter {
private _model: WhisperModel;
private _result: Promise<TranscribeResult<Format, TokenTimestamp>[]> | null = null;
private _model: WhisperModel;
private _result: Promise<TranscribeResult<Format, TokenTimestamp>[]> | null =
null;
/**
* You should not construct this class directly, use {@link TranscribeTask.run} instead.
*/
constructor(model: WhisperModel) {
super();
this._model = model;
}
/**
* You should not construct this class directly, use {@link TranscribeTask.run} instead.
*/
constructor(model: WhisperModel) {
super();
this._model = model;
}
get model(): WhisperModel {
return this._model;
}
get model(): WhisperModel {
return this._model;
}
/**
* A promise that resolves to the result of the transcription task.
*/
get result(): Promise<TranscribeResult<Format, TokenTimestamp>[]> {
if (this._result === null) {
throw new Error("Task has not been started");
}
return this._result;
}
/**
* A promise that resolves to the result of the transcription task.
*/
get result(): Promise<TranscribeResult<Format, TokenTimestamp>[]> {
if (this._result === null) {
throw new Error("Task has not been started");
}
return this._result;
}
private async _run(
pcm: Float32Array,
params: Partial<TranscribeParams<Format, TokenTimestamp>>,
): Promise<TranscribeResult<Format, TokenTimestamp>[]> {
return new Promise((resolve) => {
const handle = this.model.handle;
if (!handle) {
throw new Error("Model has been freed");
}
private async _run(
pcm: Float32Array,
params: Partial<TranscribeParams<Format, TokenTimestamp>>,
): Promise<TranscribeResult<Format, TokenTimestamp>[]> {
return new Promise((resolve) => {
const handle = this.model.handle;
if (!handle) {
throw new Error("Model has been freed");
}
binding.transcribe(
handle,
pcm,
params,
(results) => {
this.emit("finish");
resolve(results);
},
(result) => {
this.emit("transcribed", result);
},
);
});
}
binding.transcribe(
handle,
pcm,
params,
(results) => {
this.emit("finish");
resolve(results);
},
(result) => {
this.emit("transcribed", result);
},
);
});
}
static async run<Format extends TranscribeFormat, TokenTimestamp extends boolean>(
model: WhisperModel,
pcm: Float32Array,
params: Partial<TranscribeParams<Format, TokenTimestamp>>,
): Promise<TranscribeTask<Format, TokenTimestamp>> {
if (model.freed) {
throw new Error("Model has been freed");
}
static async run<
Format extends TranscribeFormat,
TokenTimestamp extends boolean,
>(
model: WhisperModel,
pcm: Float32Array,
params: Partial<TranscribeParams<Format, TokenTimestamp>>,
): Promise<TranscribeTask<Format, TokenTimestamp>> {
if (model.freed) {
throw new Error("Model has been freed");
}
const task = new TranscribeTask(model);
task._result = task._run(pcm, params);
const task = new TranscribeTask(model);
task._result = task._run(pcm, params);
return task;
}
return task;
}
on(
event: "finish",
listener: (results: TranscribeResult<Format, TokenTimestamp>[]) => void,
): this;
on(
event: "transcribed",
listener: (result: TranscribeResult<Format, TokenTimestamp>) => void,
): this;
on(event: string, listener: (...args: any[]) => void): this {
return super.on(event, listener);
}
on(
event: "finish",
listener: (results: TranscribeResult<Format, TokenTimestamp>[]) => void,
): this;
on(
event: "transcribed",
listener: (result: TranscribeResult<Format, TokenTimestamp>) => void,
): this;
on(event: string, listener: (...args: any[]) => void): this {
return super.on(event, listener);
}
once(
event: "finish",
listener: (results: TranscribeResult<Format, TokenTimestamp>[]) => void,
): this;
once(
event: "transcribed",
listener: (result: TranscribeResult<Format, TokenTimestamp>) => void,
): this;
once(event: string, listener: (...args: any[]) => void): this {
return super.once(event, listener);
}
once(
event: "finish",
listener: (results: TranscribeResult<Format, TokenTimestamp>[]) => void,
): this;
once(
event: "transcribed",
listener: (result: TranscribeResult<Format, TokenTimestamp>) => void,
): this;
once(event: string, listener: (...args: any[]) => void): this {
return super.once(event, listener);
}
off(
event: "finish",
listener: (results: TranscribeResult<Format, TokenTimestamp>[]) => void,
): this;
off(
event: "transcribed",
listener: (result: TranscribeResult<Format, TokenTimestamp>) => void,
): this;
off(event: string, listener: (...args: any[]) => void): this {
return super.off(event, listener);
}
off(
event: "finish",
listener: (results: TranscribeResult<Format, TokenTimestamp>[]) => void,
): this;
off(
event: "transcribed",
listener: (result: TranscribeResult<Format, TokenTimestamp>) => void,
): this;
off(event: string, listener: (...args: any[]) => void): this {
return super.off(event, listener);
}
}

View file

@ -1,6 +1,6 @@
export enum WhisperSamplingStrategy {
WHISPER_SAMPLING_GREEDY,
WHISPER_SAMPLING_BEAM_SEARCH,
WHISPER_SAMPLING_GREEDY,
WHISPER_SAMPLING_BEAM_SEARCH,
}
export type TranscribeFormat = "simple" | "detail";
@ -9,94 +9,96 @@ export type TranscribeFormat = "simple" | "detail";
* See {@link https://github.com/ggerganov/whisper.cpp/blob/00b7a4be02ca82d53ac69dd2dd438c16e2af7658/whisper.h#L433C19-L433C19} for details.
*/
export interface TranscribeParams<
Format extends TranscribeFormat = TranscribeFormat,
TokenTimestamp extends boolean = false,
Format extends TranscribeFormat = TranscribeFormat,
TokenTimestamp extends boolean = false,
> {
strategy: WhisperSamplingStrategy;
n_threads: number;
n_max_text_ctx: number;
offset_ms: number;
duration_ms: number;
strategy: WhisperSamplingStrategy;
n_threads: number;
n_max_text_ctx: number;
offset_ms: number;
duration_ms: number;
translate: boolean;
no_context: boolean;
no_timestamps: boolean;
single_segment: boolean;
print_special: boolean;
print_progress: boolean;
print_realtime: boolean;
print_timestamps: boolean;
translate: boolean;
no_context: boolean;
no_timestamps: boolean;
single_segment: boolean;
print_special: boolean;
print_progress: boolean;
print_realtime: boolean;
print_timestamps: boolean;
token_timestamps: TokenTimestamp;
thold_pt: number;
thold_ptsum: number;
max_len: number;
split_on_word: boolean;
max_tokens: number;
token_timestamps: TokenTimestamp;
thold_pt: number;
thold_ptsum: number;
max_len: number;
split_on_word: boolean;
max_tokens: number;
speed_up: boolean;
debug_mode: boolean;
audio_ctx: number;
speed_up: boolean;
debug_mode: boolean;
audio_ctx: number;
tdrz_enable: boolean;
tdrz_enable: boolean;
initial_prompt: string;
initial_prompt: string;
/**
* Language code, e.g. "en", "de", "fr", "es", "it", "nl", "pt", "ru", "tr", "uk", "pl", "sv", "cs", "zh", "ja", "ko"
*/
language: string;
/**
* Language code, e.g. "en", "de", "fr", "es", "it", "nl", "pt", "ru", "tr", "uk", "pl", "sv", "cs", "zh", "ja", "ko"
*/
language: string;
suppress_blank: boolean;
suppress_non_speech_tokens: boolean;
suppress_blank: boolean;
suppress_non_speech_tokens: boolean;
temperature: number;
max_initial_ts: number;
length_penalty: number;
temperature: number;
max_initial_ts: number;
length_penalty: number;
temperature_inc: number;
entropy_thold: number;
logprob_thold: number;
no_speech_thold: number;
temperature_inc: number;
entropy_thold: number;
logprob_thold: number;
no_speech_thold: number;
best_of: number;
best_of: number;
beam_size: number;
beam_size: number;
format: Format;
format: Format;
}
export interface TranscribeSimpleResult {
from: number;
to: number;
text: string;
from: number;
to: number;
text: string;
}
/**
* Represents a detailed result of transcription.
*/
export interface TranscribeDetailedResult<TokenTimestamp extends boolean>
extends TranscribeSimpleResult {
/** The detected spoken language. */
lang: string;
/** The confidence level of the transcription, calculated by the average probability of the tokens. */
confidence: number;
/** The tokens generated during the transcription process. */
tokens: {
/** The text of the token, for CJK languages, due to the BPE encoding, the token text may not be readable. */
text: string;
/** The ID of the token. */
id: number;
/** The probability of the token. */
p: number;
/** The start timestamp of the token, in milliseconds. Only available when `token_timestamps` of {@link TranscribeParams} is `true`. */
from: TokenTimestamp extends true ? number : undefined;
/** The end timestamp of the token, in milliseconds. Only available when `token_timestamps` of {@link TranscribeParams} is `true`. */
to: TokenTimestamp extends true ? number : undefined;
}[];
extends TranscribeSimpleResult {
/** The detected spoken language. */
lang: string;
/** The confidence level of the transcription, calculated by the average probability of the tokens. */
confidence: number;
/** The tokens generated during the transcription process. */
tokens: {
/** The text of the token, for CJK languages, due to the BPE encoding, the token text may not be readable. */
text: string;
/** The ID of the token. */
id: number;
/** The probability of the token. */
p: number;
/** The start timestamp of the token, in milliseconds. Only available when `token_timestamps` of {@link TranscribeParams} is `true`. */
from: TokenTimestamp extends true ? number : undefined;
/** The end timestamp of the token, in milliseconds. Only available when `token_timestamps` of {@link TranscribeParams} is `true`. */
to: TokenTimestamp extends true ? number : undefined;
}[];
}
export type TranscribeResult<
Format extends TranscribeFormat = TranscribeFormat,
TokenTimestamp extends boolean = boolean,
> = Format extends "simple" ? TranscribeSimpleResult : TranscribeDetailedResult<TokenTimestamp>;
Format extends TranscribeFormat = TranscribeFormat,
TokenTimestamp extends boolean = boolean,
> = Format extends "simple"
? TranscribeSimpleResult
: TranscribeDetailedResult<TokenTimestamp>;

View file

@ -1,17 +1,21 @@
import type { TranscribeFormat, TranscribeParams, TranscribeResult } from "./types";
import type {
TranscribeFormat,
TranscribeParams,
TranscribeResult,
} from "./types";
import { WhisperModel } from "./model";
import { TranscribeTask } from "./transcribe";
export interface WhisperConfig {
/**
* Time in seconds to wait before offloading the model if it's not being used.
*/
offload: number;
/**
* Time in seconds to wait before offloading the model if it's not being used.
*/
offload: number;
/**
* Whether to use the GPU or not.
*/
gpu: boolean;
/**
* Whether to use the GPU or not.
*/
gpu: boolean;
}
/**
@ -19,112 +23,119 @@ export interface WhisperConfig {
* It handles the loading and offloading of the model, managing transcription tasks, and configuring model parameters.
*/
export class Whisper {
private _file: string;
private _available: WhisperModel | null = null;
private _loading: Promise<WhisperModel> | null = null;
private _tasks: Promise<TranscribeResult[]>[] = [];
private _config: WhisperConfig;
private _offload_timer: NodeJS.Timeout | null = null;
private _file: string;
private _available: WhisperModel | null = null;
private _loading: Promise<WhisperModel> | null = null;
private _tasks: Promise<TranscribeResult[]>[] = [];
private _config: WhisperConfig;
private _offload_timer: NodeJS.Timeout | null = null;
/**
* Constructs a new Whisper instance with a specified model file and configuration.
* @param file - The path to the Whisper model file.
* @param config - Optional configuration for the Whisper instance.
*/
constructor(file: string, config: Partial<WhisperConfig> = {}) {
this._file = file;
this._config = {
offload: 300,
gpu: true,
...config,
};
}
/**
* Constructs a new Whisper instance with a specified model file and configuration.
* @param file - The path to the Whisper model file.
* @param config - Optional configuration for the Whisper instance.
*/
constructor(file: string, config: Partial<WhisperConfig> = {}) {
this._file = file;
this._config = {
offload: 300,
gpu: true,
...config,
};
}
get file(): string {
return this._file;
}
get file(): string {
return this._file;
}
set file(file: string) {
this._file = file;
}
set file(file: string) {
this._file = file;
}
get config(): WhisperConfig {
return this._config;
}
get config(): WhisperConfig {
return this._config;
}
get tasks(): Promise<TranscribeResult[]>[] {
return this._tasks;
}
get tasks(): Promise<TranscribeResult[]>[] {
return this._tasks;
}
reset_offload_timer(): void {
this.clear_offload_timer();
this._offload_timer = setTimeout(() => {
this.free();
}, this.config.offload * 1000);
}
reset_offload_timer(): void {
this.clear_offload_timer();
this._offload_timer = setTimeout(() => {
this.free();
}, this.config.offload * 1000);
}
private clear_offload_timer(): void {
if (this._offload_timer !== null) {
clearTimeout(this._offload_timer);
this._offload_timer = null;
}
}
private clear_offload_timer(): void {
if (this._offload_timer !== null) {
clearTimeout(this._offload_timer);
this._offload_timer = null;
}
}
async model(): Promise<WhisperModel> {
if (this._available === null) {
return this.load();
}
this.reset_offload_timer();
return Promise.resolve(this._available);
}
async model(): Promise<WhisperModel> {
if (this._available === null) {
return this.load();
}
this.reset_offload_timer();
return Promise.resolve(this._available);
}
/**
* Loads the whisper model asynchronously.
* If the model is already being loaded, returns the existing one.
*
* You don't need to call this method directly, it's called automatically if necessary when you call {@link Whisper.transcribe}.
*
* @returns A Promise that resolves to the loaded model.
*/
async load(): Promise<WhisperModel> {
if (this._loading !== null) {
return this._loading;
}
/**
* Loads the whisper model asynchronously.
* If the model is already being loaded, returns the existing one.
*
* You don't need to call this method directly, it's called automatically if necessary when you call {@link Whisper.transcribe}.
*
* @returns A Promise that resolves to the loaded model.
*/
async load(): Promise<WhisperModel> {
if (this._loading !== null) {
return this._loading;
}
const model = WhisperModel.load(this.file, this.config.gpu);
this._loading = model;
this._available = await model;
this._loading = null;
this.reset_offload_timer();
return this._available;
}
const model = WhisperModel.load(this.file, this.config.gpu);
this._loading = model;
this._available = await model;
this._loading = null;
this.reset_offload_timer();
return this._available;
}
/**
* Transcribes the given PCM audio data using the Whisper model.
* @param pcm - The mono 16k PCM audio data to transcribe.
* @param params - Optional parameters for transcription.
* @returns A promise that resolves to the result of the transcription task.
*/
async transcribe<Format extends TranscribeFormat, TokenTimestamp extends boolean>(
pcm: Float32Array,
params: Partial<TranscribeParams<Format, TokenTimestamp>> = {},
): Promise<TranscribeTask<Format, TokenTimestamp>> {
const model = await this.model();
const task = await TranscribeTask.run<Format, TokenTimestamp>(model, pcm, params);
this._tasks.push(task.result);
return task;
}
/**
* Transcribes the given PCM audio data using the Whisper model.
* @param pcm - The mono 16k PCM audio data to transcribe.
* @param params - Optional parameters for transcription.
* @returns A promise that resolves to the result of the transcription task.
*/
async transcribe<
Format extends TranscribeFormat,
TokenTimestamp extends boolean,
>(
pcm: Float32Array,
params: Partial<TranscribeParams<Format, TokenTimestamp>> = {},
): Promise<TranscribeTask<Format, TokenTimestamp>> {
const model = await this.model();
const task = await TranscribeTask.run<Format, TokenTimestamp>(
model,
pcm,
params,
);
this._tasks.push(task.result);
return task;
}
async free(): Promise<void> {
if (this._available === null) {
return;
}
const model = this._available;
this._available = null;
this.clear_offload_timer();
await Promise.all(this.tasks);
await model.free();
}
async free(): Promise<void> {
if (this._available === null) {
return;
}
const model = this._available;
this._available = null;
this.clear_offload_timer();
await Promise.all(this.tasks);
await model.free();
}
}
/**

View file

@ -2,14 +2,17 @@ import { defineConfig } from "tsup";
import { readFileSync, writeFileSync } from "node:fs";
export default defineConfig({
entry: ["src/index.ts", "src/build.ts"],
outDir: "dist",
dts: true,
async onSuccess() {
// replace `#include "ggml-common.h" in whisper.cpp/ggml/src/ggml-metal.metal with full content
const metal = readFileSync("whisper.cpp/ggml/src/ggml-metal.metal", "utf-8");
const common = readFileSync("whisper.cpp/ggml/src/ggml-common.h", "utf-8");
const replaced = metal.replace(/#include "ggml-common.h"/, common);
writeFileSync("whisper.cpp/ggml/src/ggml-metal.metal", replaced);
},
entry: ["src/index.ts", "src/build.ts"],
outDir: "dist",
dts: true,
async onSuccess() {
// replace `#include "ggml-common.h" in whisper.cpp/ggml/src/ggml-metal.metal with full content
const metal = readFileSync(
"whisper.cpp/ggml/src/ggml-metal.metal",
"utf-8",
);
const common = readFileSync("whisper.cpp/ggml/src/ggml-common.h", "utf-8");
const replaced = metal.replace(/#include "ggml-common.h"/, common);
writeFileSync("whisper.cpp/ggml/src/ggml-metal.metal", replaced);
},
});

View file

@ -12,7 +12,14 @@
"build": {
"dependsOn": ["^build"],
"inputs": ["$TURBO_DEFAULT$", ".env*"],
"outputs": [".next/**", "!.next/cache/**", "bin/**", "out/**", "dist/**", "build/**"],
"outputs": [
".next/**",
"!.next/cache/**",
"bin/**",
"out/**",
"dist/**",
"build/**"
],
"env": [
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
@ -23,7 +30,15 @@
},
"build:native": {
"dependsOn": [],
"inputs": ["Sources/**", "Package.swift", "main.swift", "scripts/**", "src/binding/**", "binding.gyp", "whisper.cpp/**"],
"inputs": [
"Sources/**",
"Package.swift",
"main.swift",
"scripts/**",
"src/binding/**",
"binding.gyp",
"whisper.cpp/**"
],
"outputs": ["bin/**", "build/**"],
"cache": true
},