535 lines
21 KiB
JavaScript
535 lines
21 KiB
JavaScript
"use client";
|
|
|
|
import { useState, useMemo, useEffect } from "react";
|
|
import PropTypes from "prop-types";
|
|
import Modal from "./Modal";
|
|
import Button from "./Button";
|
|
import { getModelsByProviderId } from "@/shared/constants/models";
|
|
import { OAUTH_PROVIDERS, APIKEY_PROVIDERS, FREE_PROVIDERS, FREE_TIER_PROVIDERS, AI_PROVIDERS, isOpenAICompatibleProvider, isAnthropicCompatibleProvider, getProviderAlias } from "@/shared/constants/providers";
|
|
|
|
// Display order of provider groups: OAuth providers first, then free providers,
// then free-tier providers, then API-key providers (matches dashboard/providers).
const PROVIDER_ORDER = [
  OAUTH_PROVIDERS,
  FREE_PROVIDERS,
  FREE_TIER_PROVIDERS,
  APIKEY_PROVIDERS,
].flatMap((registry) => Object.keys(registry));
|
|
|
|
// IDs of free providers flagged `noAuth`: they need no connection, so the model
// selector always lists them.
const NO_AUTH_PROVIDER_IDS = Object.entries(FREE_PROVIDERS)
  .filter(([, provider]) => provider.noAuth)
  .map(([providerId]) => providerId);
|
|
|
|
export default function ModelSelectModal({
|
|
isOpen,
|
|
onClose,
|
|
onSelect,
|
|
onDeselect,
|
|
selectedModel,
|
|
activeProviders = [],
|
|
title = "Select Model",
|
|
modelAliases = {},
|
|
kindFilter = null,
|
|
addedModelValues = [],
|
|
closeOnSelect = true,
|
|
}) {
|
|
// Filter activeProviders by serviceKinds when kindFilter set (e.g. "webSearch", "webFetch")
|
|
const filteredActiveProviders = useMemo(() => {
|
|
if (!kindFilter) return activeProviders;
|
|
return activeProviders.filter((p) => {
|
|
const info = AI_PROVIDERS[p.provider];
|
|
const kinds = info?.serviceKinds || ["llm"];
|
|
return kinds.includes(kindFilter);
|
|
});
|
|
}, [activeProviders, kindFilter]);
|
|
const [searchQuery, setSearchQuery] = useState("");
|
|
const [combos, setCombos] = useState([]);
|
|
const [providerNodes, setProviderNodes] = useState([]);
|
|
const [customModels, setCustomModels] = useState([]);
|
|
const [disabledModels, setDisabledModels] = useState({});
|
|
|
|
const fetchCombos = async () => {
|
|
try {
|
|
const res = await fetch("/api/combos");
|
|
if (!res.ok) throw new Error(`Failed to fetch combos: ${res.status}`);
|
|
const data = await res.json();
|
|
setCombos(data.combos || []);
|
|
} catch (error) {
|
|
console.error("Error fetching combos:", error);
|
|
setCombos([]);
|
|
}
|
|
};
|
|
|
|
useEffect(() => {
|
|
if (isOpen) fetchCombos();
|
|
}, [isOpen]);
|
|
|
|
const fetchProviderNodes = async () => {
|
|
try {
|
|
const res = await fetch("/api/provider-nodes");
|
|
if (!res.ok) throw new Error(`Failed to fetch provider nodes: ${res.status}`);
|
|
const data = await res.json();
|
|
setProviderNodes(data.nodes || []);
|
|
} catch (error) {
|
|
console.error("Error fetching provider nodes:", error);
|
|
setProviderNodes([]);
|
|
}
|
|
};
|
|
|
|
useEffect(() => {
|
|
if (isOpen) fetchProviderNodes();
|
|
}, [isOpen]);
|
|
|
|
const fetchCustomModels = async () => {
|
|
try {
|
|
const res = await fetch("/api/models/custom");
|
|
if (!res.ok) throw new Error(`Failed to fetch custom models: ${res.status}`);
|
|
const data = await res.json();
|
|
setCustomModels(data.models || []);
|
|
} catch (error) {
|
|
console.error("Error fetching custom models:", error);
|
|
setCustomModels([]);
|
|
}
|
|
};
|
|
|
|
useEffect(() => {
|
|
if (isOpen) fetchCustomModels();
|
|
}, [isOpen]);
|
|
|
|
const fetchDisabledModels = async () => {
|
|
try {
|
|
const res = await fetch("/api/models/disabled");
|
|
if (!res.ok) throw new Error(`Failed to fetch disabled models: ${res.status}`);
|
|
const data = await res.json();
|
|
setDisabledModels(data.disabled || {});
|
|
} catch (error) {
|
|
console.error("Error fetching disabled models:", error);
|
|
setDisabledModels({});
|
|
}
|
|
};
|
|
|
|
useEffect(() => {
|
|
if (isOpen) fetchDisabledModels();
|
|
}, [isOpen]);
|
|
|
|
const allProviders = useMemo(() => ({ ...OAUTH_PROVIDERS, ...FREE_PROVIDERS, ...FREE_TIER_PROVIDERS, ...APIKEY_PROVIDERS }), []);
|
|
|
|
// Group models by provider with priority order
|
|
const groupedModels = useMemo(() => {
|
|
const groups = {};
|
|
|
|
// Kinds where the provider IS the model (no per-model selection needed)
|
|
const PROVIDER_AS_MODEL_KINDS = new Set(["webSearch", "webFetch"]);
|
|
// Kinds that map directly to model.type field
|
|
const TYPED_KINDS = new Set(["image", "tts", "stt", "embedding", "imageToText"]);
|
|
// For these kinds, providers without hardcoded models can still be picked (provider-as-model fallback)
|
|
const ALLOW_PROVIDER_FALLBACK_KINDS = new Set(["tts", "image", "webFetch"]);
|
|
|
|
// Filter a models[] array by kindFilter (keep only matching m.type)
|
|
const filterByKind = (models) => {
|
|
// No kindFilter → LLM context: keep only LLM models (no type or type === "llm")
|
|
if (!kindFilter) return models.filter((m) => m.isPlaceholder || !m.type || m.type === "llm");
|
|
if (!TYPED_KINDS.has(kindFilter)) return models;
|
|
return models.filter((m) => m.isPlaceholder || m.type === kindFilter);
|
|
};
|
|
|
|
// Get all active provider IDs from connections (filtered by kindFilter if set)
|
|
const activeConnectionIds = filteredActiveProviders.map(p => p.provider);
|
|
|
|
// No-auth providers: filter by kindFilter as well
|
|
const noAuthIds = kindFilter
|
|
? NO_AUTH_PROVIDER_IDS.filter((id) => (AI_PROVIDERS[id]?.serviceKinds || ["llm"]).includes(kindFilter))
|
|
: NO_AUTH_PROVIDER_IDS;
|
|
|
|
// Only show connected providers (including both standard and custom)
|
|
const providerIdsToShow = new Set([
|
|
...activeConnectionIds, // Only connected providers
|
|
...noAuthIds, // No-auth providers (kind-filtered)
|
|
]);
|
|
|
|
// Sort by PROVIDER_ORDER
|
|
const sortedProviderIds = [...providerIdsToShow].sort((a, b) => {
|
|
const indexA = PROVIDER_ORDER.indexOf(a);
|
|
const indexB = PROVIDER_ORDER.indexOf(b);
|
|
return (indexA === -1 ? 999 : indexA) - (indexB === -1 ? 999 : indexB);
|
|
});
|
|
|
|
sortedProviderIds.forEach((providerId) => {
|
|
const alias = getProviderAlias(providerId);
|
|
const providerInfo = allProviders[providerId] || { name: providerId, color: "#666" };
|
|
const isCustomProvider = isOpenAICompatibleProvider(providerId) || isAnthropicCompatibleProvider(providerId);
|
|
|
|
// For provider-as-model kinds (webSearch/webFetch): emit a single entry where value === providerId
|
|
if (kindFilter && PROVIDER_AS_MODEL_KINDS.has(kindFilter)) {
|
|
groups[providerId] = {
|
|
name: providerInfo.name,
|
|
alias,
|
|
color: providerInfo.color,
|
|
models: [{ id: providerId, name: providerInfo.name, value: providerId }],
|
|
};
|
|
return;
|
|
}
|
|
|
|
if (providerInfo.passthroughModels) {
|
|
const aliasModels = Object.entries(modelAliases)
|
|
.filter(([, fullModel]) => fullModel.startsWith(`${alias}/`))
|
|
.map(([aliasName, fullModel]) => ({
|
|
id: fullModel.replace(`${alias}/`, ""),
|
|
name: aliasName,
|
|
value: fullModel,
|
|
}));
|
|
|
|
// For typed kinds, only include hardcoded typed models (aliases are typically LLM-only and lack type info)
|
|
let combined = aliasModels;
|
|
if (kindFilter && TYPED_KINDS.has(kindFilter)) {
|
|
combined = getModelsByProviderId(providerId)
|
|
.filter((m) => m.type === kindFilter)
|
|
.map((m) => ({ id: m.id, name: m.name, value: `${alias}/${m.id}`, type: m.type }));
|
|
// Fallback: provider-as-model when no hardcoded models match (tts/image/webFetch only)
|
|
if (combined.length === 0 && ALLOW_PROVIDER_FALLBACK_KINDS.has(kindFilter)) {
|
|
const supports = (providerInfo.serviceKinds || ["llm"]).includes(kindFilter);
|
|
if (supports) combined = [{ id: providerId, name: providerInfo.name, value: alias }];
|
|
}
|
|
}
|
|
|
|
if (combined.length > 0) {
|
|
// Check for custom name from providerNodes (for compatible providers)
|
|
const matchedNode = providerNodes.find(node => node.id === providerId);
|
|
const displayName = matchedNode?.name || providerInfo.name;
|
|
|
|
groups[providerId] = {
|
|
name: displayName,
|
|
alias: alias,
|
|
color: providerInfo.color,
|
|
models: combined,
|
|
};
|
|
}
|
|
} else if (isCustomProvider) {
|
|
// Custom (openai/anthropic-compatible) providers are LLM-only — skip for typed media kinds
|
|
if (kindFilter && TYPED_KINDS.has(kindFilter)) return;
|
|
// Find connection object to get prefix synchronously without waiting for providerNodes fetch
|
|
const connection = activeProviders.find(p => p.provider === providerId);
|
|
const matchedNode = providerNodes.find(node => node.id === providerId);
|
|
const displayName = connection?.name || matchedNode?.name || providerInfo.name;
|
|
const nodePrefix = connection?.providerSpecificData?.prefix || matchedNode?.prefix || providerId;
|
|
|
|
// Aliases are stored using the raw providerId as key (e.g. "openai-compatible-chat-<uuid>/glm-4.7"),
|
|
// so we must filter by providerId, not by the display prefix.
|
|
const nodeModels = Object.entries(modelAliases)
|
|
.filter(([, fullModel]) => fullModel.startsWith(`${providerId}/`))
|
|
.map(([aliasName, fullModel]) => ({
|
|
id: fullModel.replace(`${providerId}/`, ""),
|
|
name: aliasName,
|
|
value: `${nodePrefix}/${fullModel.replace(`${providerId}/`, "")}`,
|
|
}));
|
|
|
|
// Always show compatible providers that are connected, even with no aliases.
|
|
// When no aliases exist, show a placeholder so users know it's available.
|
|
const modelsToShow = nodeModels.length > 0 ? nodeModels : [{
|
|
id: `__placeholder__${providerId}`,
|
|
name: `${nodePrefix}/model-id`,
|
|
value: `${nodePrefix}/model-id`,
|
|
isPlaceholder: true,
|
|
}];
|
|
|
|
groups[providerId] = {
|
|
name: displayName,
|
|
alias: nodePrefix,
|
|
color: providerInfo.color,
|
|
models: modelsToShow,
|
|
isCustom: true,
|
|
hasModels: nodeModels.length > 0,
|
|
};
|
|
} else {
|
|
const hardcodedModels = getModelsByProviderId(providerId);
|
|
const hardcodedIds = new Set(hardcodedModels.map((m) => m.id));
|
|
|
|
// Custom models: if no hardcoded models (e.g. openrouter), show all aliases for this provider
|
|
// Otherwise only show aliases where aliasName === modelId ("Add Model" button pattern)
|
|
const hasHardcoded = hardcodedModels.length > 0;
|
|
const customAliasModels = Object.entries(modelAliases)
|
|
.filter(([aliasName, fullModel]) =>
|
|
fullModel.startsWith(`${alias}/`) &&
|
|
(hasHardcoded ? aliasName === fullModel.replace(`${alias}/`, "") : true) &&
|
|
!hardcodedIds.has(fullModel.replace(`${alias}/`, ""))
|
|
)
|
|
.map(([aliasName, fullModel]) => {
|
|
const modelId = fullModel.replace(`${alias}/`, "");
|
|
return { id: modelId, name: aliasName, value: fullModel, isCustom: true };
|
|
});
|
|
|
|
// Custom models registered via /api/models/custom (provider "Add Model" button)
|
|
const customAliasIds = new Set(customAliasModels.map((m) => m.id));
|
|
const customRegisteredModels = customModels
|
|
.filter((m) => m.providerAlias === alias && !hardcodedIds.has(m.id) && !customAliasIds.has(m.id))
|
|
.map((m) => ({ id: m.id, name: m.name || m.id, value: `${alias}/${m.id}`, isCustom: true }));
|
|
|
|
const merged = [
|
|
...hardcodedModels.map((m) => ({ id: m.id, name: m.name, value: `${alias}/${m.id}`, type: m.type })),
|
|
...customAliasModels,
|
|
...customRegisteredModels,
|
|
];
|
|
// Dedupe by value (alias may equal hardcoded id, causing React key collision)
|
|
const seen = new Set();
|
|
let allModels = filterByKind(merged.filter((m) => {
|
|
if (seen.has(m.value)) return false;
|
|
seen.add(m.value);
|
|
return true;
|
|
}));
|
|
|
|
// Provider-as-model fallback: providers that support the kind but have no hardcoded models
|
|
// can still be picked (value = providerAlias). Skips embedding (always needs model).
|
|
if (allModels.length === 0 && kindFilter && ALLOW_PROVIDER_FALLBACK_KINDS.has(kindFilter)) {
|
|
const supports = (providerInfo.serviceKinds || ["llm"]).includes(kindFilter);
|
|
if (supports) {
|
|
allModels = [{ id: providerId, name: providerInfo.name, value: alias }];
|
|
}
|
|
}
|
|
|
|
if (allModels.length > 0) {
|
|
groups[providerId] = {
|
|
name: providerInfo.name,
|
|
alias: alias,
|
|
color: providerInfo.color,
|
|
models: allModels,
|
|
};
|
|
}
|
|
}
|
|
});
|
|
|
|
// Filter out disabled models per provider (disabled keyed by storage alias OR providerId)
|
|
Object.entries(groups).forEach(([providerId, group]) => {
|
|
const aliasKey = getProviderAlias(providerId);
|
|
const disabledIds = new Set([
|
|
...(disabledModels[aliasKey] || []),
|
|
...(disabledModels[providerId] || []),
|
|
]);
|
|
if (disabledIds.size === 0) return;
|
|
group.models = group.models.filter((m) => !disabledIds.has(m.id));
|
|
if (group.models.length === 0) delete groups[providerId];
|
|
});
|
|
|
|
return groups;
|
|
}, [filteredActiveProviders, modelAliases, allProviders, providerNodes, customModels, disabledModels, kindFilter, activeProviders]);
|
|
|
|
// Filter combos by search query (and hide combos when kindFilter is set — combos are LLM-only by design)
|
|
const filteredCombos = useMemo(() => {
|
|
if (kindFilter) return [];
|
|
if (!searchQuery.trim()) return combos;
|
|
const query = searchQuery.toLowerCase();
|
|
return combos.filter(c => c.name.toLowerCase().includes(query));
|
|
}, [combos, searchQuery, kindFilter]);
|
|
|
|
// Filter models by search query
|
|
const filteredGroups = useMemo(() => {
|
|
if (!searchQuery.trim()) return groupedModels;
|
|
|
|
const query = searchQuery.toLowerCase();
|
|
const filtered = {};
|
|
|
|
Object.entries(groupedModels).forEach(([providerId, group]) => {
|
|
const matchedModels = group.models.filter(
|
|
(m) =>
|
|
m.name.toLowerCase().includes(query) ||
|
|
m.id.toLowerCase().includes(query)
|
|
);
|
|
|
|
const providerNameMatches = group.name.toLowerCase().includes(query);
|
|
|
|
if (matchedModels.length > 0 || providerNameMatches) {
|
|
filtered[providerId] = {
|
|
...group,
|
|
models: matchedModels,
|
|
};
|
|
}
|
|
});
|
|
|
|
return filtered;
|
|
}, [groupedModels, searchQuery]);
|
|
|
|
const handleSelect = (model) => {
|
|
const value = model?.value || model?.name || model;
|
|
const isAdded = addedModelValues.includes(value);
|
|
|
|
if (isAdded && onDeselect) {
|
|
onDeselect(model);
|
|
} else {
|
|
onSelect(model);
|
|
}
|
|
|
|
if (closeOnSelect) {
|
|
onClose();
|
|
setSearchQuery("");
|
|
}
|
|
};
|
|
|
|
return (
|
|
<Modal
|
|
isOpen={isOpen}
|
|
onClose={() => {
|
|
onClose();
|
|
setSearchQuery("");
|
|
}}
|
|
title={title}
|
|
size="md"
|
|
className="p-4!"
|
|
footer={
|
|
!closeOnSelect ? (
|
|
<Button
|
|
onClick={() => {
|
|
onClose();
|
|
setSearchQuery("");
|
|
}}
|
|
fullWidth
|
|
>
|
|
Done
|
|
</Button>
|
|
) : null
|
|
}
|
|
>
|
|
{/* Search - compact */}
|
|
<div className="mb-3">
|
|
<div className="relative">
|
|
<span className="material-symbols-outlined absolute left-2.5 top-1/2 -translate-y-1/2 text-text-muted text-[16px]">
|
|
search
|
|
</span>
|
|
<input
|
|
type="text"
|
|
placeholder="Search..."
|
|
value={searchQuery}
|
|
onChange={(e) => setSearchQuery(e.target.value)}
|
|
className="w-full pl-8 pr-3 py-1.5 bg-surface border border-border rounded text-xs focus:outline-none focus:ring-1 focus:ring-primary/50"
|
|
/>
|
|
</div>
|
|
</div>
|
|
|
|
{/* Models grouped by provider - compact */}
|
|
<div className="max-h-[400px] overflow-y-auto space-y-3">
|
|
{/* Combos section - always first */}
|
|
{filteredCombos.length > 0 && (
|
|
<div>
|
|
<div className="flex items-center gap-1.5 mb-1.5 sticky top-0 bg-surface py-0.5">
|
|
<span className="material-symbols-outlined text-primary text-[14px]">layers</span>
|
|
<span className="text-xs font-medium text-primary">Combos</span>
|
|
<span className="text-[10px] text-text-muted">({filteredCombos.length})</span>
|
|
</div>
|
|
<div className="flex flex-wrap gap-1.5">
|
|
{filteredCombos.map((combo) => {
|
|
const isSelected = selectedModel === combo.name;
|
|
return (
|
|
<button
|
|
key={combo.id}
|
|
onClick={() => handleSelect({ id: combo.name, name: combo.name, value: combo.name })}
|
|
className={`
|
|
px-2 py-1 rounded-xl text-xs font-medium transition-all border hover:cursor-pointer flex items-center gap-1
|
|
${isSelected
|
|
? "bg-primary text-white border-primary"
|
|
: addedModelValues.includes(combo.name)
|
|
? "bg-green-500/10 border-green-500/30 text-green-700 dark:text-green-400 hover:border-green-500/50"
|
|
: "bg-surface border-border text-text-main hover:border-primary/50 hover:bg-primary/5"
|
|
}
|
|
`}
|
|
>
|
|
{addedModelValues.includes(combo.name) && (
|
|
<span className="material-symbols-outlined text-[12px]">check_circle</span>
|
|
)}
|
|
{combo.name}
|
|
</button>
|
|
);
|
|
})}
|
|
</div>
|
|
</div>
|
|
)}
|
|
|
|
{/* Provider models */}
|
|
{Object.entries(filteredGroups).map(([providerId, group]) => (
|
|
<div key={providerId}>
|
|
{/* Provider header */}
|
|
<div className="flex items-center gap-1.5 mb-1.5 sticky top-0 bg-surface py-0.5">
|
|
<div
|
|
className="w-2 h-2 rounded-full"
|
|
style={{ backgroundColor: group.color }}
|
|
/>
|
|
<span className="text-xs font-medium text-primary">
|
|
{group.name}
|
|
</span>
|
|
<span className="text-[10px] text-text-muted">
|
|
({group.models.length})
|
|
</span>
|
|
</div>
|
|
|
|
<div className="flex flex-wrap gap-1.5">
|
|
{group.models.map((model) => {
|
|
const isSelected = selectedModel === model.value;
|
|
const isPlaceholder = model.isPlaceholder;
|
|
return (
|
|
<button
|
|
key={model.value}
|
|
onClick={() => handleSelect(model)}
|
|
title={isPlaceholder ? "Select to pre-fill, then edit model ID in the input" : undefined}
|
|
className={`
|
|
px-2 py-1 rounded-xl text-xs font-medium transition-all border hover:cursor-pointer
|
|
${isPlaceholder
|
|
? "border-dashed border-border text-text-muted hover:border-primary/50 hover:text-primary bg-surface italic"
|
|
: isSelected
|
|
? "bg-primary text-white border-primary"
|
|
: addedModelValues.includes(model.value)
|
|
? "bg-green-500/10 border-green-500/30 text-green-700 dark:text-green-400 hover:border-green-500/50"
|
|
: "bg-surface border-border text-text-main hover:border-primary/50 hover:bg-primary/5"
|
|
}
|
|
`}
|
|
>
|
|
<span className="flex items-center gap-1">
|
|
{addedModelValues.includes(model.value) && !isPlaceholder && (
|
|
<span className="material-symbols-outlined text-[12px]">check_circle</span>
|
|
)}
|
|
{isPlaceholder ? (
|
|
<>
|
|
<span className="material-symbols-outlined text-[11px]">edit</span>
|
|
{model.name}
|
|
</>
|
|
) : model.isCustom ? (
|
|
<>
|
|
{model.name}
|
|
<span className="text-[9px] opacity-60 font-normal">custom</span>
|
|
</>
|
|
) : (
|
|
model.name
|
|
)}
|
|
</span>
|
|
</button>
|
|
);
|
|
})}
|
|
</div>
|
|
</div>
|
|
))}
|
|
|
|
{Object.keys(filteredGroups).length === 0 && filteredCombos.length === 0 && (
|
|
<div className="text-center py-4 text-text-muted">
|
|
<span className="material-symbols-outlined text-2xl mb-1 block">
|
|
search_off
|
|
</span>
|
|
<p className="text-xs">No models found</p>
|
|
</div>
|
|
)}
|
|
</div>
|
|
</Modal>
|
|
);
|
|
}
|
|
|
|
// One entry of `activeProviders`; only the provider ID is validated.
const activeProviderShape = PropTypes.shape({
  provider: PropTypes.string.isRequired,
});

ModelSelectModal.propTypes = {
  // Visibility and callbacks
  isOpen: PropTypes.bool.isRequired,
  onClose: PropTypes.func.isRequired,
  onSelect: PropTypes.func.isRequired,
  onDeselect: PropTypes.func,

  // Current selection state
  selectedModel: PropTypes.string,
  addedModelValues: PropTypes.arrayOf(PropTypes.string),

  // Data used to build the provider groups
  activeProviders: PropTypes.arrayOf(activeProviderShape),
  modelAliases: PropTypes.object,
  kindFilter: PropTypes.string,

  // Presentation
  title: PropTypes.string,
  closeOnSelect: PropTypes.bool,
};
|
|
|