117 lines
3.4 KiB
JavaScript
117 lines
3.4 KiB
JavaScript
import { HTTP_STATUS } from "../config/constants.js";
|
|
|
|
/**
 * Release a response body that will not be read, so the underlying
 * connection can be freed instead of lingering until garbage collection.
 * Failures while cancelling are ignored: the response is being discarded
 * either way.
 *
 * @param {Response} response - Response whose body should be discarded.
 */
function discardBody(response) {
  try {
    // Not awaited: we only need to start the teardown. The trailing catch
    // keeps a rejected cancel() from surfacing as an unhandled rejection.
    void Promise.resolve(response?.body?.cancel?.()).catch(() => {});
  } catch {
    // A synchronous throw from cancel() is just as unimportant.
  }
}

/**
 * BaseExecutor - Base class for provider executors.
 *
 * Sends a POST request for a provider, walking through the configured base
 * URLs as fallbacks. Subclasses customize behavior by overriding
 * `buildUrl`, `buildHeaders`, `transformRequest`, `shouldRetry`,
 * `refreshCredentials`, `needsRefresh` and `parseError`.
 */
export class BaseExecutor {
  /**
   * @param {string} provider - Provider identifier; ids starting with
   *   "openai-compatible-" get special URL handling in `buildUrl`.
   * @param {object} config - Provider configuration; this class reads
   *   `baseUrl`, `baseUrls` and `headers` from it.
   */
  constructor(provider, config) {
    this.provider = provider;
    this.config = config;
  }

  /** @returns {string} The provider identifier passed to the constructor. */
  getProvider() {
    return this.provider;
  }

  /**
   * @returns {string[]} `config.baseUrls` when set, otherwise a one-element
   *   array holding `config.baseUrl`, otherwise an empty array.
   */
  getBaseUrls() {
    return this.config.baseUrls || (this.config.baseUrl ? [this.config.baseUrl] : []);
  }

  /**
   * @returns {number} How many URLs `execute` will try. It is at least 1, so
   *   one attempt is made even when no base URLs are configured.
   */
  getFallbackCount() {
    return this.getBaseUrls().length || 1;
  }

  /**
   * Resolve the request URL for one attempt.
   *
   * For "openai-compatible-*" providers:
   * - The base URL comes from `credentials.providerSpecificData.baseUrl`,
   *   defaulting to the OpenAI API.
   * - The path is `/responses` when the provider id contains "responses",
   *   otherwise `/chat/completions`.
   *
   * Every other provider picks from the configured base URL list.
   *
   * @param {string} model - Model name (unused here; available to subclasses).
   * @param {boolean} stream - Whether streaming was requested (unused here).
   * @param {number} [urlIndex=0] - Index of the fallback URL to use.
   * @param {object|null} [credentials=null] - Credentials for the request.
   * @returns {string|undefined} The URL, or undefined when none is configured.
   */
  buildUrl(model, stream, urlIndex = 0, credentials = null) {
    if (this.provider?.startsWith?.("openai-compatible-")) {
      const baseUrl = credentials?.providerSpecificData?.baseUrl || "https://api.openai.com/v1";
      // Strip every trailing slash so ".../v1//" does not yield ".../v1//chat/completions".
      const normalized = baseUrl.replace(/\/+$/, "");
      const path = this.provider.includes("responses") ? "/responses" : "/chat/completions";
      return `${normalized}${path}`;
    }

    const baseUrls = this.getBaseUrls();
    // Out-of-range indices fall back to the first URL, then to config.baseUrl.
    return baseUrls[urlIndex] || baseUrls[0] || this.config.baseUrl;
  }

  /**
   * Build the request headers.
   *
   * The result contains:
   * - a JSON content type, overridable by `config.headers`;
   * - a Bearer token, preferring the access token over the API key;
   * - `Accept: text/event-stream` when streaming.
   *
   * @param {object|null|undefined} credentials - When missing, no
   *   Authorization header is added.
   * @param {boolean} [stream=true] - Whether to request an event stream.
   * @returns {Record<string, string>} The header map.
   */
  buildHeaders(credentials, stream = true) {
    const headers = {
      "Content-Type": "application/json",
      ...this.config.headers
    };

    // Optional chaining: absent credentials mean "no auth", not a TypeError.
    const token = credentials?.accessToken || credentials?.apiKey;
    if (token) {
      headers["Authorization"] = `Bearer ${token}`;
    }

    if (stream) {
      headers["Accept"] = "text/event-stream";
    }

    return headers;
  }

  /**
   * Hook for provider-specific request body transformations.
   * The base implementation returns `body` unchanged.
   */
  transformRequest(model, body, stream, credentials) {
    return body;
  }

  /**
   * Decide whether to move on to the next fallback URL after a response.
   * The base implementation retries only on `HTTP_STATUS.RATE_LIMITED`, and
   * only while another URL remains to try.
   *
   * @param {number} status - HTTP status of the response.
   * @param {number} urlIndex - Index of the URL that produced it.
   * @returns {boolean}
   */
  shouldRetry(status, urlIndex) {
    return status === HTTP_STATUS.RATE_LIMITED && urlIndex + 1 < this.getFallbackCount();
  }

  /**
   * Hook for provider-specific credential refresh.
   * The base implementation performs no refresh and resolves to null.
   */
  async refreshCredentials(credentials, log) {
    return null;
  }

  /**
   * Whether the credentials expire within `bufferMs` from now.
   *
   * @param {object|null|undefined} credentials - Reads `expiresAt` (any value
   *   accepted by the `Date` constructor).
   * @param {number} [bufferMs=300000] - Refresh window; defaults to 5 minutes.
   * @returns {boolean} false when credentials or `expiresAt` are missing, or
   *   when `expiresAt` cannot be parsed as a date.
   */
  needsRefresh(credentials, bufferMs = 5 * 60 * 1000) {
    if (!credentials?.expiresAt) return false;
    const expiresAtMs = new Date(credentials.expiresAt).getTime();
    // An unparseable date yields NaN; treat it as "no refresh needed".
    if (Number.isNaN(expiresAtMs)) return false;
    return expiresAtMs - Date.now() < bufferMs;
  }

  /**
   * Normalize an error response into `{ status, message }`. The body text is
   * used as the message when non-empty, otherwise "HTTP <status>".
   */
  parseError(response, bodyText) {
    return { status: response.status, message: bodyText || `HTTP ${response.status}` };
  }

  /**
   * Send the request as a POST, falling back through the configured base URLs.
   *
   * A fallback is attempted when `fetch` rejects (e.g. a network error) or
   * when `shouldRetry` accepts the response status. A request aborted through
   * `signal` is rethrown immediately instead of being retried. Non-2xx
   * responses that are not retried are returned as-is, not thrown.
   *
   * @param {object} options
   * @param {string} options.model - Model name, forwarded to the hooks.
   * @param {object} options.body - Payload; passed through `transformRequest`
   *   and JSON-serialized.
   * @param {boolean} options.stream - Whether streaming was requested.
   * @param {object} options.credentials - Credentials for URL/header building.
   * @param {AbortSignal} [options.signal] - Cancels the in-flight request.
   * @param {object} [options.log] - Optional logger; `log.debug` is called on retries.
   * @returns {Promise<{response: Response, url: string, headers: object, transformedBody: *}>}
   * @throws The fetch error of the final (or an aborted) attempt.
   */
  async execute({ model, body, stream, credentials, signal, log }) {
    const fallbackCount = this.getFallbackCount();
    let lastError = null;
    let lastStatus = 0;

    for (let urlIndex = 0; urlIndex < fallbackCount; urlIndex++) {
      const url = this.buildUrl(model, stream, urlIndex, credentials);
      const headers = this.buildHeaders(credentials, stream);
      const transformedBody = this.transformRequest(model, body, stream, credentials);

      let response;
      try {
        response = await fetch(url, {
          method: "POST",
          headers,
          body: JSON.stringify(transformedBody),
          signal
        });
      } catch (error) {
        lastError = error;
        // A cancelled request must not hop to the next URL, and the final
        // attempt has nowhere left to go.
        if (signal?.aborted || urlIndex + 1 >= fallbackCount) {
          throw error;
        }
        log?.debug?.("RETRY", `Error on ${url}, trying fallback ${urlIndex + 1}`);
        continue;
      }

      if (this.shouldRetry(response.status, urlIndex)) {
        log?.debug?.("RETRY", `${response.status} on ${url}, trying fallback ${urlIndex + 1}`);
        lastStatus = response.status;
        // We will never read this body; release it so the connection is freed.
        discardBody(response);
        continue;
      }

      return { response, url, headers, transformedBody };
    }

    // Only reachable when a subclass's shouldRetry asks to retry the last URL.
    throw lastError || new Error(`All ${fallbackCount} URLs failed with status ${lastStatus}`);
  }
}

export default BaseExecutor;
|