import { performance } from "node:perf_hooks";
import { prisma } from "../db.js";
import { anthropicClient, openaiClient, xaiClient } from "./providers.js";
import {
  buildToolLogMessageData,
  runToolAwareChatCompletionsStream,
  runToolAwareOpenAIChatStream,
  type ToolExecutionEvent,
} from "./chat-tools.js";
import { buildAnthropicConversationMessage, getAnthropicSystemPrompt } from "./message-content.js";
import type { MultiplexRequest, Provider } from "./types.js";

/** Token accounting reported by a provider for one streamed completion. */
type StreamUsage = {
  inputTokens?: number;
  outputTokens?: number;
  totalTokens?: number;
};

/**
 * Events emitted by {@link runMultiplexStream}, in order: one `meta`, then any
 * number of `delta` / `tool_call` events, then exactly one `done` or `error`.
 */
export type StreamEvent =
  | { type: "meta"; chatId: string | null; callId: string | null; provider: Provider; model: string }
  | { type: "tool_call"; event: ToolExecutionEvent }
  | { type: "delta"; text: string }
  | { type: "done"; text: string; usage?: StreamUsage }
  | { type: "error"; message: string };

/** Token counters as they appear on Anthropic streaming events. */
type AnthropicUsagePayload = {
  input_tokens?: number | null;
  output_tokens?: number | null;
};

/**
 * The subset of an Anthropic streaming event that this module reads.
 * Kept structural (and fully optional) so unknown event kinds are simply ignored.
 */
type AnthropicStreamEvent = {
  type?: string;
  delta?: { type?: string; text?: string };
  usage?: AnthropicUsagePayload;
  message?: { usage?: AnthropicUsagePayload };
};

/** `max_tokens` sent to Anthropic when the request does not specify `maxTokens`. */
const DEFAULT_ANTHROPIC_MAX_TOKENS = 1024;

/**
 * Returns `chatId` when it is a non-empty string; otherwise creates a new
 * chat row and returns its id.
 */
async function getChatIdOrCreate(chatId?: string): Promise<string> {
  if (chatId) return chatId;
  const created = await prisma.chat.create({ data: {}, select: { id: true } });
  return created.id;
}

/** Extracts a printable message from whatever value was thrown. */
function describeError(err: unknown): string {
  if (typeof err === "object" && err !== null && "message" in err) {
    const { message } = err as { message?: unknown };
    if (message != null) return String(message);
  }
  return String(err);
}

/**
 * Streams a completion from the provider named in `req` and, unless
 * `req.persist === false`, records the call and the resulting messages.
 *
 * Failures raised while streaming or while persisting the final result are
 * reported as a single `error` event rather than thrown. Failures during the
 * initial bookkeeping (before the `meta` event) propagate to the caller.
 *
 * @param req - Provider, model, messages and sampling options for the call.
 * @returns An async generator of {@link StreamEvent}s.
 */
export async function* runMultiplexStream(req: MultiplexRequest): AsyncGenerator<StreamEvent> {
  const t0 = performance.now();
  const shouldPersist = req.persist !== false;

  const chatId = shouldPersist ? await getChatIdOrCreate(req.chatId) : null;

  const call =
    shouldPersist && chatId
      ? await prisma.llmCall.create({
          data: {
            chatId,
            provider: req.provider as any,
            model: req.model,
            request: req as any,
          },
          select: { id: true },
        })
      : null;

  if (shouldPersist && chatId) {
    await prisma.$transaction([
      // Always refresh the "last used" pair.
      prisma.chat.update({
        where: { id: chatId },
        data: {
          lastUsedProvider: req.provider as any,
          lastUsedModel: req.model,
        },
      }),
      // Only stamp the "initiated" pair while it is still unset.
      prisma.chat.updateMany({
        where: { id: chatId, initiatedProvider: null },
        data: {
          initiatedProvider: req.provider as any,
          initiatedModel: req.model,
        },
      }),
    ]);
  }

  yield { type: "meta", chatId, callId: call?.id ?? null, provider: req.provider, model: req.model };

  let text = "";
  let usage: StreamUsage | undefined;
  let raw: unknown = { streamed: true };

  try {
    if (req.provider === "openai" || req.provider === "xai") {
      const client = req.provider === "openai" ? openaiClient() : xaiClient();

      // Both runners receive the same arguments; build them once.
      const streamArgs = {
        client,
        model: req.model,
        messages: req.messages,
        temperature: req.temperature,
        maxTokens: req.maxTokens,
        logContext: {
          provider: req.provider,
          model: req.model,
          chatId: chatId ?? undefined,
        },
      };

      const streamEvents =
        req.provider === "openai"
          ? runToolAwareOpenAIChatStream(streamArgs)
          : runToolAwareChatCompletionsStream(streamArgs);

      for await (const ev of streamEvents) {
        if (ev.type === "delta") {
          text += ev.text;
          yield { type: "delta", text: ev.text };
        } else if (ev.type === "tool_call") {
          if (shouldPersist && chatId) {
            const toolMessage = buildToolLogMessageData(chatId, ev.event);
            await prisma.message.create({
              data: {
                chatId: toolMessage.chatId,
                role: toolMessage.role as any,
                content: toolMessage.content,
                name: toolMessage.name,
                metadata: toolMessage.metadata as any,
              },
            });
          }
          yield { type: "tool_call", event: ev.event };
        } else {
          // Remaining event kind carries the final result; its text replaces
          // whatever was accumulated from deltas.
          raw = ev.result.raw;
          usage = ev.result.usage;
          text = ev.result.text;
        }
      }
    } else if (req.provider === "anthropic") {
      const client = anthropicClient();
      const system = getAnthropicSystemPrompt(req.messages);
      const msgs = req.messages
        .filter((message) => message.role !== "system")
        .map((message) => buildAnthropicConversationMessage(message));

      const stream = await client.messages.create({
        model: req.model,
        system,
        max_tokens: req.maxTokens ?? DEFAULT_ANTHROPIC_MAX_TOKENS,
        temperature: req.temperature,
        messages: msgs as any,
        stream: true,
      });

      let inputTokens: number | undefined;
      let outputTokens: number | undefined;

      for await (const ev of stream as unknown as AsyncIterable<AnthropicStreamEvent | undefined>) {
        // Text arrives as content_block_delta events carrying a text_delta.
        if (ev?.type === "content_block_delta" && ev.delta?.type === "text_delta") {
          const chunk = ev.delta.text ?? "";
          if (chunk) {
            text += chunk;
            yield { type: "delta", text: chunk };
          }
          continue;
        }

        // Per the Anthropic streaming docs, message_start reports the prompt
        // (input) token count and message_delta reports cumulative output
        // tokens, so both must be merged to get a complete picture.
        const reported =
          ev?.type === "message_start"
            ? ev.message?.usage
            : ev?.type === "message_delta"
              ? ev.usage
              : undefined;
        if (reported) {
          inputTokens = reported.input_tokens ?? inputTokens;
          outputTokens = reported.output_tokens ?? outputTokens;
        }
        // Other event kinds (e.g. message_stop) need no handling.
      }

      if (inputTokens !== undefined || outputTokens !== undefined) {
        usage = {
          inputTokens,
          outputTokens,
          totalTokens: (inputTokens ?? 0) + (outputTokens ?? 0),
        };
      }
      raw = { streamed: true, provider: "anthropic" };
    } else {
      throw new Error(`unknown provider: ${req.provider}`);
    }

    const latencyMs = Math.round(performance.now() - t0);

    if (shouldPersist && chatId && call) {
      // Store the assistant message and the call outcome atomically.
      await prisma.$transaction(async (tx) => {
        await tx.message.create({
          data: { chatId, role: "assistant" as any, content: text },
        });
        await tx.llmCall.update({
          where: { id: call.id },
          data: {
            response: raw as any,
            latencyMs,
            inputTokens: usage?.inputTokens,
            outputTokens: usage?.outputTokens,
            totalTokens: usage?.totalTokens,
          },
        });
      });
    }

    yield { type: "done", text, usage };
  } catch (e: unknown) {
    const latencyMs = Math.round(performance.now() - t0);
    const message = describeError(e);

    if (shouldPersist && call) {
      await prisma.llmCall.update({
        where: { id: call.id },
        data: {
          error: message,
          latencyMs,
        },
      });
    }

    yield { type: "error", message };
  }
}