218 lines
7.2 KiB
TypeScript
218 lines
7.2 KiB
TypeScript
import {
|
|
normalizeEnabledChatTools,
|
|
type ToolAwareCompletionParams,
|
|
type ToolAwareCompletionResult,
|
|
type ToolAwareStreamingEvent,
|
|
} from "./chat-tools.js";
|
|
import { completeWithChatCompletionsApi, streamWithChatCompletionsApi } from "./protocols/chat-completions-api.js";
|
|
import { completeWithMessagesApi, streamWithMessagesApi } from "./protocols/messages-api.js";
|
|
import { completeWithResponsesApi, streamWithResponsesApi } from "./protocols/responses-api.js";
|
|
import { env } from "../env.js";
|
|
import { anthropicClient, hermesAgentClient, isHermesAgentConfigured, openaiClient, xaiClient } from "./providers.js";
|
|
import type { ChatMessage, Provider } from "./types.js";
|
|
|
|
/**
 * Caller-facing options for one chat request routed through a provider adapter.
 * The adapter supplies the SDK client itself, so callers never pass one.
 */
type ProviderAdapterParams = {
  /** Provider-specific model identifier. */
  model: string;
  /** Conversation history sent to the model. */
  messages: ChatMessage[];
  /** Sampling temperature, forwarded unchanged to the protocol. */
  temperature?: number;
  /** Output token cap, forwarded unchanged to the protocol. */
  maxTokens?: number;
  /** Requested tool names; normalized during protocol selection and possibly dropped. */
  enabledTools?: string[];
  /** Optional user location, forwarded unchanged to the protocol. */
  userLocation?: string;
  /** Logging metadata in the same shape the protocol layer accepts. */
  logContext?: ToolAwareCompletionParams["logContext"];
};
|
|
|
|
/**
 * Provider-bound chat entry points. Client creation and protocol selection
 * are hidden behind these two calls.
 */
export type ProviderChatAdapter = {
  /** The provider this adapter talks to. */
  provider: Provider;
  /** Runs a single non-streaming completion. */
  complete(params: ProviderAdapterParams): Promise<ToolAwareCompletionResult>;
  /** Produces streaming completion events. */
  stream(params: ProviderAdapterParams): AsyncGenerator<ToolAwareStreamingEvent>;
};
|
|
|
|
/** Identifiers of the wire protocols a provider backend can speak. */
type ChatProtocolId =
  | "chat-completions"
  | "messages"
  | "responses";

/** A wire protocol: its identifier plus completion and streaming implementations. */
type ChatProtocol = {
  /** Stable identifier, also reported by describeProviderChatBackend. */
  id: ChatProtocolId;
  /** Performs one non-streaming request. */
  complete(params: ToolAwareCompletionParams): Promise<ToolAwareCompletionResult>;
  /** Performs one streaming request. */
  stream(params: ToolAwareCompletionParams): AsyncGenerator<ToolAwareStreamingEvent>;
};
|
|
|
|
/** How a provider's available models are discovered. */
type ModelCatalogSpec = {
  /**
   * Optional gate. When omitted, the catalog counts as enabled; only an
   * explicit `false` return hides the provider from listModelCatalogProviders.
   */
  enabled?: () => boolean;
  /**
   * Lists model ids using the client returned by the backend's `createClient`.
   * The client is untyped because each provider's SDK client has a different shape.
   */
  fetchModels(client: any): Promise<string[]>;
  /**
   * Static model ids exposed through getProviderCatalogFallbackModels
   * (presumably used when live discovery fails; confirm with callers).
   */
  fallbackModels?: () => string[];
};
|
|
|
|
/** Everything needed to serve one provider: client factory, protocols, and catalog. */
type ProviderBackendSpec = {
  /**
   * Returns the provider SDK client. It is called on every request and every
   * catalog fetch; whether it caches is up to providers.js.
   */
  createClient: () => any;
  /** Protocol used when managed tools are not in play. */
  plainProtocol: ChatProtocol;
  /** Only when true may tool-enabled requests be routed to `toolProtocol`. */
  managedTools?: boolean;
  /** Protocol used when `managedTools` is true and at least one tool is enabled. */
  toolProtocol?: ChatProtocol;
  /** Model discovery settings; absent when the provider has no catalog. */
  modelCatalog?: ModelCatalogSpec;
};
|
|
|
|
// Each protocol object pairs an id with the imported complete/stream implementations.

/** Chat Completions API protocol. */
const chatCompletionsProtocol: ChatProtocol = { id: "chat-completions", complete: completeWithChatCompletionsApi, stream: streamWithChatCompletionsApi };

/** Messages API protocol. */
const messagesProtocol: ChatProtocol = { id: "messages", complete: completeWithMessagesApi, stream: streamWithMessagesApi };

/** Responses API protocol. */
const responsesProtocol: ChatProtocol = { id: "responses", complete: completeWithResponsesApi, stream: streamWithResponsesApi };
|
|
|
|
/**
 * Normalizes a list of ids.
 *
 * Each value is trimmed, empty results are dropped, duplicates are removed,
 * and the remainder is sorted with `localeCompare`.
 *
 * @param values - Raw ids, possibly padded or repeated.
 * @returns A new sorted array of distinct, non-empty, trimmed ids.
 */
function uniqSorted(values: string[]) {
  const distinct = new Set<string>();
  for (const raw of values) {
    const trimmed = raw.trim();
    if (trimmed) distinct.add(trimmed);
  }
  return Array.from(distinct).sort((left, right) => left.localeCompare(right));
}
|
|
|
|
/**
 * Extracts string model ids from a provider SDK `models.list()` result.
 *
 * Only `page.data` is read. Entries without a string `id` are skipped. A missing
 * page, or a `data` that is not an array, yields an empty list. Only the
 * entries on this one page are returned; no further pages are fetched.
 *
 * @param page - Raw list response. Its shape is not trusted, hence `unknown`.
 * @returns Model ids in the order they appear in `page.data`.
 */
function modelIdsFromListResponse(page: unknown): string[] {
  // Equivalent to `page?.data`, but the narrowing is explicit instead of via `any`.
  const data = page == null ? undefined : (page as { data?: unknown }).data;
  if (!Array.isArray(data)) return [];
  return data
    .map((entry: unknown) => (entry == null ? undefined : (entry as { id?: unknown }).id))
    .filter((id): id is string => typeof id === "string");
}
|
|
|
|
/**
 * Substrings that mark model ids that are not general text chat models:
 * embeddings/moderation, audio/realtime/speech, image/video, and search/computer-use variants.
 */
const RESPONSES_EXCLUDED_MARKERS = [
  "embedding",
  "moderation",
  "audio",
  "realtime",
  "transcribe",
  "tts",
  "image",
  "dall-e",
  "sora",
  "search",
  "computer-use",
] as const;

/** Id prefixes accepted as chat models: `gpt-…`, `o<digit>…`, `chatgpt-…`. */
const RESPONSES_MODEL_PREFIX = /^(gpt-|o\d|chatgpt-)/;

/**
 * Heuristically decides whether an OpenAI model id is a text chat model worth
 * listing in the catalog.
 *
 * The check is case-insensitive. Any excluded marker rejects the id; otherwise
 * it must start with one of the accepted prefixes.
 */
function isLikelyResponsesApiModel(model: string) {
  const normalized = model.toLowerCase();
  const hasExcludedMarker = RESPONSES_EXCLUDED_MARKERS.some((marker) => normalized.includes(marker));
  if (hasExcludedMarker) return false;
  return RESPONSES_MODEL_PREFIX.test(normalized);
}
|
|
|
|
/**
 * Converts adapter-level params into protocol-level params by attaching a client.
 *
 * `enabledTools` comes from the explicit argument, not from `params.enabledTools`.
 * Within this file, callers pass the tool list chosen by selectChatProtocol.
 * Every other field is copied from `params` as-is, including `undefined` values.
 */
function withClient(params: ProviderAdapterParams, client: any, enabledTools?: string[]): ToolAwareCompletionParams {
  const { model, messages, userLocation, temperature, maxTokens, logContext } = params;
  return { client, model, messages, enabledTools, userLocation, temperature, maxTokens, logContext };
}
|
|
|
|
/**
 * Chooses the protocol that serves a request and the tools it may use.
 *
 * Tool routing applies only when all three hold: the backend opts in
 * (`managedTools === true`), it defines a `toolProtocol`, and at least one
 * tool survives `normalizeEnabledChatTools`. Otherwise the plain protocol is
 * used and the tool list is emptied, so tools are never forwarded to a backend
 * that does not manage them.
 *
 * @param spec - Backend wiring for the provider.
 * @param params - Only `enabledTools` is read.
 * @returns `{ protocol, enabledTools, managedTools }` for the request.
 */
function selectChatProtocol(spec: ProviderBackendSpec, params: Pick<ProviderAdapterParams, "enabledTools">) {
  const enabledTools = normalizeEnabledChatTools(params.enabledTools);
  // Narrow on the protocol itself rather than asserting it with `!` afterwards.
  const toolProtocol = spec.managedTools === true && enabledTools.length > 0 ? spec.toolProtocol : undefined;
  if (toolProtocol) {
    return { protocol: toolProtocol, enabledTools, managedTools: true };
  }
  return { protocol: spec.plainProtocol, enabledTools: [], managedTools: false };
}
|
|
|
|
/**
 * Wraps a backend spec in the public {@link ProviderChatAdapter} interface.
 *
 * Each call re-runs protocol selection and calls `spec.createClient()`, so the
 * requested tools are evaluated per request.
 *
 * `complete` is `async`: a synchronous failure in protocol selection or client
 * creation reaches the caller as a rejected promise, not a thrown exception,
 * which matches the Promise-returning contract. `stream` deliberately stays
 * eager: selection and client creation happen inside the `stream()` call
 * itself, so such failures still surface before the caller starts iterating.
 *
 * @param provider - Provider key reported on the adapter.
 * @param spec - Client factory and protocol wiring for that provider.
 */
function createProviderChatAdapter(provider: Provider, spec: ProviderBackendSpec): ProviderChatAdapter {
  return {
    provider,
    async complete(params) {
      const selected = selectChatProtocol(spec, params);
      return selected.protocol.complete(withClient(params, spec.createClient(), selected.enabledTools));
    },
    stream(params) {
      const selected = selectChatProtocol(spec, params);
      return selected.protocol.stream(withClient(params, spec.createClient(), selected.enabledTools));
    },
  };
}
|
|
|
|
/**
 * Per-provider wiring: client factory, plain and tool protocols, and model catalog.
 *
 * The key order here is the order that Object.entries/Object.keys report,
 * and therefore the order in which listModelCatalogProviders returns providers.
 */
const backendSpecs: Record<Provider, ProviderBackendSpec> = {
  openai: {
    createClient: openaiClient,
    plainProtocol: chatCompletionsProtocol,
    // Tool-enabled requests switch from Chat Completions to the Responses API.
    toolProtocol: responsesProtocol,
    managedTools: true,
    modelCatalog: {
      // Keep only ids that pass the chat-model heuristic.
      fetchModels: async (client) => modelIdsFromListResponse(await client.models.list()).filter(isLikelyResponsesApiModel),
    },
  },
  anthropic: {
    createClient: anthropicClient,
    plainProtocol: messagesProtocol,
    toolProtocol: messagesProtocol,
    managedTools: true,
    modelCatalog: {
      // Only the first page (up to 200 entries) is read.
      fetchModels: async (client) => modelIdsFromListResponse(await client.models.list({ limit: 200 })),
    },
  },
  xai: {
    createClient: xaiClient,
    plainProtocol: chatCompletionsProtocol,
    toolProtocol: chatCompletionsProtocol,
    managedTools: true,
    modelCatalog: {
      fetchModels: async (client) => modelIdsFromListResponse(await client.models.list()),
    },
  },
  "hermes-agent": {
    createClient: hermesAgentClient,
    plainProtocol: chatCompletionsProtocol,
    // No managed tools: selectChatProtocol always uses the plain protocol with no tools.
    managedTools: false,
    modelCatalog: {
      enabled: isHermesAgentConfigured,
      fetchModels: async (client) => {
        const listed = modelIdsFromListResponse(await client.models.list());
        // Always offer the configured model, even when the server does not list it.
        return env.HERMES_AGENT_MODEL ? [...listed, env.HERMES_AGENT_MODEL] : listed;
      },
      fallbackModels: () => (env.HERMES_AGENT_MODEL ? [env.HERMES_AGENT_MODEL] : []),
    },
  },
};
|
|
|
|
/** One prebuilt adapter per entry in backendSpecs, created once at module load. */
const providerChatAdapters = {} as Record<Provider, ProviderChatAdapter>;
for (const [provider, spec] of Object.entries(backendSpecs) as [Provider, ProviderBackendSpec][]) {
  providerChatAdapters[provider] = createProviderChatAdapter(provider, spec);
}
|
|
|
|
/**
 * Looks up the prebuilt chat adapter for a provider.
 *
 * NOTE(review): there is no runtime validation. A string outside `Provider`
 * that reaches this function yields `undefined`.
 */
export function getProviderChatAdapter(provider: Provider) {
  const adapter = providerChatAdapters[provider];
  return adapter;
}
|
|
|
|
/**
 * Reports which protocol and tools a request with `enabledTools` would use,
 * without creating a client or sending anything.
 *
 * @returns `{ provider, protocol, managedTools, enabledTools }`, where
 * `protocol` is the protocol id and `enabledTools` is the effective tool list.
 */
export function describeProviderChatBackend(provider: Provider, enabledTools?: string[]) {
  const { protocol, managedTools, enabledTools: effectiveTools } = selectChatProtocol(backendSpecs[provider], { enabledTools });
  return { provider, protocol: protocol.id, managedTools, enabledTools: effectiveTools };
}
|
|
|
|
/**
 * Lists the providers whose model catalog is currently available, in backendSpecs order.
 *
 * A provider qualifies when it has a `modelCatalog` and that catalog's
 * `enabled` gate is absent or returns something other than `false`.
 */
export function listModelCatalogProviders(): Provider[] {
  const available: Provider[] = [];
  for (const provider of Object.keys(backendSpecs) as Provider[]) {
    const catalog = backendSpecs[provider].modelCatalog;
    if (!catalog) continue;
    if (catalog.enabled?.() === false) continue;
    available.push(provider);
  }
  return available;
}
|
|
|
|
/**
 * Fetches a provider's live model list, deduplicated, trimmed, and sorted.
 *
 * Resolves to `[]` when the provider has no catalog. Errors from client
 * creation or the list call are not caught here.
 * NOTE(review): the catalog's `enabled` gate is not consulted; callers
 * presumably filter through listModelCatalogProviders first.
 */
export async function fetchProviderCatalogModels(provider: Provider) {
  const backend = backendSpecs[provider];
  const catalog = backend.modelCatalog;
  if (catalog === undefined) return [];
  const rawModels = await catalog.fetchModels(backend.createClient());
  return uniqSorted(rawModels);
}
|
|
|
|
/**
 * Returns the provider's static fallback model ids, deduplicated, trimmed, and sorted.
 * Returns `[]` when the provider has no catalog or no `fallbackModels`.
 */
export function getProviderCatalogFallbackModels(provider: Provider) {
  const catalog = backendSpecs[provider].modelCatalog;
  // Call through `catalog` so the callback keeps its original `this` binding.
  const fallback = catalog?.fallbackModels ? catalog.fallbackModels() : [];
  return uniqSorted(fallback);
}
|