Files
Sybil-2/server/src/llm/provider-adapters.ts

218 lines
7.2 KiB
TypeScript
Raw Normal View History

2026-06-13 12:02:22 -07:00
import {
normalizeEnabledChatTools,
type ToolAwareCompletionParams,
type ToolAwareCompletionResult,
type ToolAwareStreamingEvent,
} from "./chat-tools.js";
import { completeWithChatCompletionsApi, streamWithChatCompletionsApi } from "./protocols/chat-completions-api.js";
import { completeWithMessagesApi, streamWithMessagesApi } from "./protocols/messages-api.js";
import { completeWithResponsesApi, streamWithResponsesApi } from "./protocols/responses-api.js";
import { env } from "../env.js";
import { anthropicClient, hermesAgentClient, isHermesAgentConfigured, openaiClient, xaiClient } from "./providers.js";
import type { ChatMessage, Provider } from "./types.js";
/**
 * Arguments accepted by a provider adapter's `complete` and `stream` methods.
 * `withClient` copies these, together with a client and the selected tool list,
 * into the `ToolAwareCompletionParams` object handed to the chosen protocol.
 */
type ProviderAdapterParams = {
model: string;
messages: ChatMessage[];
// Requested tools; run through `normalizeEnabledChatTools` before a protocol is selected.
enabledTools?: string[];
userLocation?: string;
temperature?: number;
maxTokens?: number;
// Same type as the `logContext` field of `ToolAwareCompletionParams`.
logContext?: ToolAwareCompletionParams["logContext"];
};
/**
 * Public chat interface for a single provider, as returned by
 * `getProviderChatAdapter`.
 */
export type ProviderChatAdapter = {
// The provider this adapter was created for.
provider: Provider;
// Non-streaming call: resolves to a single completion result.
complete(params: ProviderAdapterParams): Promise<ToolAwareCompletionResult>;
// Streaming call: yields `ToolAwareStreamingEvent` values from an async generator.
stream(params: ProviderAdapterParams): AsyncGenerator<ToolAwareStreamingEvent>;
};
/** Identifiers of the three protocol descriptors defined in this module. */
type ChatProtocolId = "chat-completions" | "messages" | "responses";
/**
 * A wire protocol implementation: an identifier plus the non-streaming and
 * streaming entry points that take fully-populated `ToolAwareCompletionParams`
 * (including the client).
 */
type ChatProtocol = {
id: ChatProtocolId;
complete(params: ToolAwareCompletionParams): Promise<ToolAwareCompletionResult>;
stream(params: ToolAwareCompletionParams): AsyncGenerator<ToolAwareStreamingEvent>;
};
/**
 * Describes how to list the models a provider offers.
 */
type ModelCatalogSpec = {
// Optional gate: `listModelCatalogProviders` omits the provider when this returns `false`.
enabled?: () => boolean;
// Receives a client built by the backend's `createClient` and resolves to model ids.
fetchModels(client: any): Promise<string[]>;
// Optional static list read by `getProviderCatalogFallbackModels`; treated as empty when absent.
fallbackModels?: () => string[];
};
/**
 * Static wiring for one provider: how to obtain a client, which protocol to
 * speak, and (optionally) how to list its models.
 */
type ProviderBackendSpec = {
// Invoked on every `complete`/`stream` call and on every catalog fetch.
createClient: () => any;
// Protocol used whenever managed tools are not in effect for a request.
plainProtocol: ChatProtocol;
// Protocol used when `managedTools` is `true` and at least one normalized tool is enabled.
toolProtocol?: ChatProtocol;
// Must be exactly `true` for `toolProtocol` to be considered by `selectChatProtocol`.
managedTools?: boolean;
modelCatalog?: ModelCatalogSpec;
};
/** Protocol descriptor backed by the handlers imported from `./protocols/chat-completions-api.js`. */
const chatCompletionsProtocol: ChatProtocol = {
id: "chat-completions",
complete: completeWithChatCompletionsApi,
stream: streamWithChatCompletionsApi,
};
/** Protocol descriptor backed by the handlers imported from `./protocols/messages-api.js`. */
const messagesProtocol: ChatProtocol = {
id: "messages",
complete: completeWithMessagesApi,
stream: streamWithMessagesApi,
};
/** Protocol descriptor backed by the handlers imported from `./protocols/responses-api.js`. */
const responsesProtocol: ChatProtocol = {
id: "responses",
complete: completeWithResponsesApi,
stream: streamWithResponsesApi,
};
/**
 * Trims every entry, drops entries that are empty after trimming, removes
 * duplicates, and returns the remainder ordered with `String.localeCompare`.
 *
 * @param values Raw strings; the input array is not modified.
 * @returns A new array of unique, trimmed, non-empty strings in sorted order.
 */
function uniqSorted(values: string[]) {
  const unique = new Set<string>();
  for (const raw of values) {
    const trimmed = raw.trim();
    // Empty strings are the only falsy string value, so this skips blank entries.
    if (trimmed) unique.add(trimmed);
  }
  return Array.from(unique).sort((left, right) => left.localeCompare(right));
}
/**
 * Extracts the string `id` of every entry in a list response's `data` array.
 *
 * The return type is declared as `string[]` so the untyped (`any`) page does
 * not leak `any` into callers.
 *
 * @param page A list response; anything without an array-valued `data`
 *   property (including `null`/`undefined`) yields an empty result.
 * @returns The ids that are strings, in their original order. Entries that are
 *   nullish or whose `id` is not a string are skipped.
 */
function modelIdsFromListResponse(page: any): string[] {
  const entries: unknown = page?.data;
  if (!Array.isArray(entries)) return [];

  const ids: string[] = [];
  for (const entry of entries) {
    // Optional chaining tolerates null/undefined entries in the list.
    const id: unknown = entry?.id;
    if (typeof id === "string") ids.push(id);
  }
  return ids;
}
/**
 * Heuristic filter applied to the OpenAI model listing: accepts ids that start
 * with `gpt-`, `chatgpt-`, or `o` followed by a digit, unless the id contains
 * one of the excluded substrings below. Matching is case-insensitive.
 *
 * @param model Model id to classify.
 * @returns `true` when the id passes the prefix check and contains no excluded marker.
 */
function isLikelyResponsesApiModel(model: string) {
  const excludedMarkers = [
    "embedding",
    "moderation",
    "audio",
    "realtime",
    "transcribe",
    "tts",
    "image",
    "dall-e",
    "sora",
    "search",
    "computer-use",
  ];
  const id = model.toLowerCase();
  const isExcluded = excludedMarkers.some((marker) => id.includes(marker));
  if (isExcluded) return false;
  return /^(gpt-|o\d|chatgpt-)/.test(id);
}
/**
 * Builds the protocol-level parameter object from adapter-level params.
 *
 * @param params Adapter params supplied by the caller. Its own `enabledTools`
 *   is ignored here in favour of the explicit `enabledTools` argument.
 * @param client Provider client to attach.
 * @param enabledTools Tool list to forward to the protocol (may be `undefined`).
 * @returns A new `ToolAwareCompletionParams` object; `params` is not modified.
 */
function withClient(params: ProviderAdapterParams, client: any, enabledTools?: string[]): ToolAwareCompletionParams {
  const { model, messages, userLocation, temperature, maxTokens, logContext } = params;
  return { client, model, messages, enabledTools, userLocation, temperature, maxTokens, logContext };
}
/**
 * Decides which protocol a request should use and which tools it may carry.
 *
 * The tool protocol is chosen only when the spec sets `managedTools` to exactly
 * `true`, defines a `toolProtocol`, and the normalized tool list is non-empty.
 * In every other case the plain protocol is used and the tool list is emptied.
 *
 * @param spec Backend wiring for the provider.
 * @param params Object carrying the caller's requested `enabledTools`.
 * @returns The chosen protocol, the tool list to forward, and whether managed tools are in effect.
 */
function selectChatProtocol(spec: ProviderBackendSpec, params: Pick<ProviderAdapterParams, "enabledTools">) {
  const requestedTools = normalizeEnabledChatTools(params.enabledTools);
  const toolsRequested = spec.managedTools === true && requestedTools.length > 0;
  // Resolves to undefined unless tools were requested AND the spec provides a tool protocol.
  const toolProtocol = toolsRequested ? spec.toolProtocol : undefined;
  const managed = Boolean(toolProtocol);
  return {
    protocol: toolProtocol ? toolProtocol : spec.plainProtocol,
    enabledTools: managed ? requestedTools : [],
    managedTools: managed,
  };
}
/**
 * Wraps a backend spec in the public `ProviderChatAdapter` interface.
 *
 * Each `complete`/`stream` call selects a protocol for that request, calls
 * `spec.createClient()`, and delegates to the selected protocol.
 *
 * @param provider Provider id reported on the adapter.
 * @param spec Backend wiring used to serve requests.
 * @returns An adapter whose methods forward to the selected protocol.
 */
function createProviderChatAdapter(provider: Provider, spec: ProviderBackendSpec): ProviderChatAdapter {
  // Shared per-request setup: pick the protocol, then assemble its params.
  const prepare = (params: ProviderAdapterParams) => {
    const { protocol, enabledTools } = selectChatProtocol(spec, params);
    return { protocol, request: withClient(params, spec.createClient(), enabledTools) };
  };

  return {
    provider,
    complete(params) {
      const { protocol, request } = prepare(params);
      return protocol.complete(request);
    },
    stream(params) {
      const { protocol, request } = prepare(params);
      return protocol.stream(request);
    },
  };
}
/**
 * Backend wiring for every supported provider: client factory, plain and tool
 * protocols, whether managed tools are enabled, and how to list models.
 */
const backendSpecs: Record<Provider, ProviderBackendSpec> = {
  openai: {
    createClient: openaiClient,
    plainProtocol: chatCompletionsProtocol,
    // Requests with enabled tools switch from chat-completions to the responses protocol.
    toolProtocol: responsesProtocol,
    managedTools: true,
    modelCatalog: {
      // Only ids accepted by isLikelyResponsesApiModel are kept from the listing.
      fetchModels: async (client) => {
        const listed = modelIdsFromListResponse(await client.models.list());
        return listed.filter(isLikelyResponsesApiModel);
      },
    },
  },
  anthropic: {
    createClient: anthropicClient,
    plainProtocol: messagesProtocol,
    toolProtocol: messagesProtocol,
    managedTools: true,
    modelCatalog: {
      fetchModels: async (client) => modelIdsFromListResponse(await client.models.list({ limit: 200 })),
    },
  },
  xai: {
    createClient: xaiClient,
    plainProtocol: chatCompletionsProtocol,
    toolProtocol: chatCompletionsProtocol,
    managedTools: true,
    modelCatalog: {
      fetchModels: async (client) => modelIdsFromListResponse(await client.models.list()),
    },
  },
  "hermes-agent": {
    createClient: hermesAgentClient,
    plainProtocol: chatCompletionsProtocol,
    // No tool protocol: selectChatProtocol always strips enabled tools for this provider.
    managedTools: false,
    modelCatalog: {
      enabled: isHermesAgentConfigured,
      // The configured model (if any) is appended to whatever the endpoint lists.
      fetchModels: async (client) => {
        const listed = modelIdsFromListResponse(await client.models.list());
        const configured = env.HERMES_AGENT_MODEL;
        return configured ? [...listed, configured] : listed;
      },
      fallbackModels: () => {
        const configured = env.HERMES_AGENT_MODEL;
        return configured ? [configured] : [];
      },
    },
  },
};
/**
 * One adapter per entry in `backendSpecs`, built once at module load and keyed
 * by provider id.
 */
const providerChatAdapters = {} as Record<Provider, ProviderChatAdapter>;
for (const provider of Object.keys(backendSpecs) as Provider[]) {
  providerChatAdapters[provider] = createProviderChatAdapter(provider, backendSpecs[provider]);
}
/**
 * Looks up the chat adapter that was built for the given provider.
 *
 * @param provider Provider id.
 * @returns The adapter stored for that provider in `providerChatAdapters`.
 */
export function getProviderChatAdapter(provider: Provider) {
  const adapter = providerChatAdapters[provider];
  return adapter;
}
/**
 * Reports, without making any request, which protocol and tool list a call to
 * the given provider would use for the supplied tools.
 *
 * @param provider Provider id.
 * @param enabledTools Tools the caller intends to enable, if any.
 * @returns The provider, the selected protocol's id, whether managed tools are
 *   in effect, and the tool list that would be forwarded.
 */
export function describeProviderChatBackend(provider: Provider, enabledTools?: string[]) {
  const {
    protocol,
    managedTools,
    enabledTools: effectiveTools,
  } = selectChatProtocol(backendSpecs[provider], { enabledTools });
  return { provider, protocol: protocol.id, managedTools, enabledTools: effectiveTools };
}
/**
 * Lists the providers whose model catalog can be queried.
 *
 * A provider is included when its spec defines `modelCatalog` and the catalog's
 * optional `enabled` gate is either absent or does not return `false`.
 *
 * @returns Provider ids in `backendSpecs` declaration order.
 */
export function listModelCatalogProviders(): Provider[] {
  const providers: Provider[] = [];
  for (const [provider, spec] of Object.entries(backendSpecs) as [Provider, ProviderBackendSpec][]) {
    const catalog = spec.modelCatalog;
    if (catalog === undefined) continue;
    if (catalog.enabled?.() === false) continue;
    providers.push(provider);
  }
  return providers;
}
/**
 * Fetches the provider's model ids through its catalog spec.
 *
 * The catalog's `enabled` gate is not consulted here; only the presence of a
 * `modelCatalog` is checked.
 *
 * @param provider Provider id.
 * @returns Unique, trimmed, sorted model ids, or an empty array when the
 *   provider has no `modelCatalog`. Rejections from `fetchModels` propagate.
 */
export async function fetchProviderCatalogModels(provider: Provider) {
  const backend = backendSpecs[provider];
  const catalog = backend.modelCatalog;
  if (!catalog) return [];
  const models = await catalog.fetchModels(backend.createClient());
  return uniqSorted(models);
}
/**
 * Returns the provider's statically known model ids, without any network call.
 *
 * @param provider Provider id.
 * @returns Unique, trimmed, sorted fallback ids; empty when the provider has no
 *   catalog or the catalog defines no `fallbackModels`.
 */
export function getProviderCatalogFallbackModels(provider: Provider) {
  const catalog = backendSpecs[provider].modelCatalog;
  const fallback = catalog?.fallbackModels?.() ?? [];
  return uniqSorted(fallback);
}