big backend refactor
This commit is contained in:
217
server/src/llm/provider-adapters.ts
Normal file
217
server/src/llm/provider-adapters.ts
Normal file
@@ -0,0 +1,217 @@
|
||||
import {
|
||||
normalizeEnabledChatTools,
|
||||
type ToolAwareCompletionParams,
|
||||
type ToolAwareCompletionResult,
|
||||
type ToolAwareStreamingEvent,
|
||||
} from "./chat-tools.js";
|
||||
import { completeWithChatCompletionsApi, streamWithChatCompletionsApi } from "./protocols/chat-completions-api.js";
|
||||
import { completeWithMessagesApi, streamWithMessagesApi } from "./protocols/messages-api.js";
|
||||
import { completeWithResponsesApi, streamWithResponsesApi } from "./protocols/responses-api.js";
|
||||
import { env } from "../env.js";
|
||||
import { anthropicClient, hermesAgentClient, isHermesAgentConfigured, openaiClient, xaiClient } from "./providers.js";
|
||||
import type { ChatMessage, Provider } from "./types.js";
|
||||
|
||||
type ProviderAdapterParams = {
|
||||
model: string;
|
||||
messages: ChatMessage[];
|
||||
enabledTools?: string[];
|
||||
userLocation?: string;
|
||||
temperature?: number;
|
||||
maxTokens?: number;
|
||||
logContext?: ToolAwareCompletionParams["logContext"];
|
||||
};
|
||||
|
||||
export type ProviderChatAdapter = {
|
||||
provider: Provider;
|
||||
complete(params: ProviderAdapterParams): Promise<ToolAwareCompletionResult>;
|
||||
stream(params: ProviderAdapterParams): AsyncGenerator<ToolAwareStreamingEvent>;
|
||||
};
|
||||
|
||||
type ChatProtocolId = "chat-completions" | "messages" | "responses";
|
||||
|
||||
type ChatProtocol = {
|
||||
id: ChatProtocolId;
|
||||
complete(params: ToolAwareCompletionParams): Promise<ToolAwareCompletionResult>;
|
||||
stream(params: ToolAwareCompletionParams): AsyncGenerator<ToolAwareStreamingEvent>;
|
||||
};
|
||||
|
||||
type ModelCatalogSpec = {
|
||||
enabled?: () => boolean;
|
||||
fetchModels(client: any): Promise<string[]>;
|
||||
fallbackModels?: () => string[];
|
||||
};
|
||||
|
||||
type ProviderBackendSpec = {
|
||||
createClient: () => any;
|
||||
plainProtocol: ChatProtocol;
|
||||
toolProtocol?: ChatProtocol;
|
||||
managedTools?: boolean;
|
||||
modelCatalog?: ModelCatalogSpec;
|
||||
};
|
||||
|
||||
const chatCompletionsProtocol: ChatProtocol = {
|
||||
id: "chat-completions",
|
||||
complete: completeWithChatCompletionsApi,
|
||||
stream: streamWithChatCompletionsApi,
|
||||
};
|
||||
|
||||
const messagesProtocol: ChatProtocol = {
|
||||
id: "messages",
|
||||
complete: completeWithMessagesApi,
|
||||
stream: streamWithMessagesApi,
|
||||
};
|
||||
|
||||
const responsesProtocol: ChatProtocol = {
|
||||
id: "responses",
|
||||
complete: completeWithResponsesApi,
|
||||
stream: streamWithResponsesApi,
|
||||
};
|
||||
|
||||
function uniqSorted(values: string[]) {
|
||||
return [...new Set(values.map((value) => value.trim()).filter(Boolean))].sort((a, b) => a.localeCompare(b));
|
||||
}
|
||||
|
||||
function modelIdsFromListResponse(page: any) {
|
||||
return Array.isArray(page?.data)
|
||||
? page.data.map((model: any) => model?.id).filter((id: unknown): id is string => typeof id === "string")
|
||||
: [];
|
||||
}
|
||||
|
||||
function isLikelyResponsesApiModel(model: string) {
|
||||
const id = model.toLowerCase();
|
||||
if (id.includes("embedding") || id.includes("moderation")) return false;
|
||||
if (id.includes("audio") || id.includes("realtime") || id.includes("transcribe") || id.includes("tts")) return false;
|
||||
if (id.includes("image") || id.includes("dall-e") || id.includes("sora")) return false;
|
||||
if (id.includes("search") || id.includes("computer-use")) return false;
|
||||
return /^(gpt-|o\d|chatgpt-)/.test(id);
|
||||
}
|
||||
|
||||
function withClient(params: ProviderAdapterParams, client: any, enabledTools?: string[]): ToolAwareCompletionParams {
|
||||
return {
|
||||
client,
|
||||
model: params.model,
|
||||
messages: params.messages,
|
||||
enabledTools,
|
||||
userLocation: params.userLocation,
|
||||
temperature: params.temperature,
|
||||
maxTokens: params.maxTokens,
|
||||
logContext: params.logContext,
|
||||
};
|
||||
}
|
||||
|
||||
function selectChatProtocol(spec: ProviderBackendSpec, params: Pick<ProviderAdapterParams, "enabledTools">) {
|
||||
const enabledTools = normalizeEnabledChatTools(params.enabledTools);
|
||||
const useManagedTools = spec.managedTools === true && spec.toolProtocol && enabledTools.length > 0;
|
||||
return {
|
||||
protocol: useManagedTools ? spec.toolProtocol! : spec.plainProtocol,
|
||||
enabledTools: useManagedTools ? enabledTools : [],
|
||||
managedTools: Boolean(useManagedTools),
|
||||
};
|
||||
}
|
||||
|
||||
function createProviderChatAdapter(provider: Provider, spec: ProviderBackendSpec): ProviderChatAdapter {
|
||||
return {
|
||||
provider,
|
||||
complete(params) {
|
||||
const selected = selectChatProtocol(spec, params);
|
||||
return selected.protocol.complete(withClient(params, spec.createClient(), selected.enabledTools));
|
||||
},
|
||||
stream(params) {
|
||||
const selected = selectChatProtocol(spec, params);
|
||||
return selected.protocol.stream(withClient(params, spec.createClient(), selected.enabledTools));
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
const backendSpecs: Record<Provider, ProviderBackendSpec> = {
|
||||
openai: {
|
||||
createClient: openaiClient,
|
||||
plainProtocol: chatCompletionsProtocol,
|
||||
toolProtocol: responsesProtocol,
|
||||
managedTools: true,
|
||||
modelCatalog: {
|
||||
async fetchModels(client) {
|
||||
const page = await client.models.list();
|
||||
return modelIdsFromListResponse(page).filter(isLikelyResponsesApiModel);
|
||||
},
|
||||
},
|
||||
},
|
||||
anthropic: {
|
||||
createClient: anthropicClient,
|
||||
plainProtocol: messagesProtocol,
|
||||
toolProtocol: messagesProtocol,
|
||||
managedTools: true,
|
||||
modelCatalog: {
|
||||
async fetchModels(client) {
|
||||
const page = await client.models.list({ limit: 200 });
|
||||
return modelIdsFromListResponse(page);
|
||||
},
|
||||
},
|
||||
},
|
||||
xai: {
|
||||
createClient: xaiClient,
|
||||
plainProtocol: chatCompletionsProtocol,
|
||||
toolProtocol: chatCompletionsProtocol,
|
||||
managedTools: true,
|
||||
modelCatalog: {
|
||||
async fetchModels(client) {
|
||||
const page = await client.models.list();
|
||||
return modelIdsFromListResponse(page);
|
||||
},
|
||||
},
|
||||
},
|
||||
"hermes-agent": {
|
||||
createClient: hermesAgentClient,
|
||||
plainProtocol: chatCompletionsProtocol,
|
||||
managedTools: false,
|
||||
modelCatalog: {
|
||||
enabled: isHermesAgentConfigured,
|
||||
async fetchModels(client) {
|
||||
const page = await client.models.list();
|
||||
const models = modelIdsFromListResponse(page);
|
||||
if (env.HERMES_AGENT_MODEL) models.push(env.HERMES_AGENT_MODEL);
|
||||
return models;
|
||||
},
|
||||
fallbackModels() {
|
||||
return env.HERMES_AGENT_MODEL ? [env.HERMES_AGENT_MODEL] : [];
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const providerChatAdapters: Record<Provider, ProviderChatAdapter> = Object.fromEntries(
|
||||
Object.entries(backendSpecs).map(([provider, spec]) => [provider, createProviderChatAdapter(provider as Provider, spec)])
|
||||
) as Record<Provider, ProviderChatAdapter>;
|
||||
|
||||
export function getProviderChatAdapter(provider: Provider) {
|
||||
return providerChatAdapters[provider];
|
||||
}
|
||||
|
||||
export function describeProviderChatBackend(provider: Provider, enabledTools?: string[]) {
|
||||
const selected = selectChatProtocol(backendSpecs[provider], { enabledTools });
|
||||
return {
|
||||
provider,
|
||||
protocol: selected.protocol.id,
|
||||
managedTools: selected.managedTools,
|
||||
enabledTools: selected.enabledTools,
|
||||
};
|
||||
}
|
||||
|
||||
export function listModelCatalogProviders(): Provider[] {
|
||||
return (Object.entries(backendSpecs) as [Provider, ProviderBackendSpec][])
|
||||
.filter(([, spec]) => {
|
||||
const catalog = spec.modelCatalog;
|
||||
return catalog !== undefined && catalog.enabled?.() !== false;
|
||||
})
|
||||
.map(([provider]) => provider);
|
||||
}
|
||||
|
||||
export async function fetchProviderCatalogModels(provider: Provider) {
|
||||
const spec = backendSpecs[provider].modelCatalog;
|
||||
if (!spec) return [];
|
||||
return uniqSorted(await spec.fetchModels(backendSpecs[provider].createClient()));
|
||||
}
|
||||
|
||||
export function getProviderCatalogFallbackModels(provider: Provider) {
|
||||
return uniqSorted(backendSpecs[provider].modelCatalog?.fallbackModels?.() ?? []);
|
||||
}
|
||||
Reference in New Issue
Block a user