Files
Sybil-2/server/src/llm/streaming.ts

172 lines
5.4 KiB
TypeScript
Raw Normal View History

2026-02-13 22:43:55 -08:00
import { performance } from "node:perf_hooks";
import { prisma } from "../db.js";
import { anthropicClient, openaiClient, xaiClient } from "./providers.js";
2026-03-02 16:39:05 -08:00
import { buildToolLogMessageData, runToolAwareOpenAIChatStream, type ToolExecutionEvent } from "./chat-tools.js";
2026-02-13 22:43:55 -08:00
import type { MultiplexRequest, Provider } from "./types.js";
/**
 * Events emitted by `runMultiplexStream`, discriminated by `type`:
 * - `meta`: emitted once, after the chat and llmCall rows exist; identifies the call.
 * - `tool_call`: a tool execution reported by the OpenAI-compatible tool-aware stream.
 * - `delta`: an incremental chunk of assistant text.
 * - `done`: terminal success event with the final text and optional token usage.
 * - `error`: terminal failure event carrying a human-readable message.
 */
export type StreamEvent =
  | { type: "meta"; chatId: string; callId: string; provider: Provider; model: string }
  | { type: "tool_call"; event: ToolExecutionEvent }
  | { type: "delta"; text: string }
  | { type: "done"; text: string; usage?: { inputTokens?: number; outputTokens?: number; totalTokens?: number } }
  | { type: "error"; message: string };
/**
 * Resolves the chat id to use for a request.
 *
 * @param chatId - an existing chat id; a falsy value (undefined or "") means "none supplied".
 * @returns the supplied id unchanged, or the id of a newly created, empty chat row.
 */
async function getChatIdOrCreate(chatId?: string): Promise<string> {
  if (chatId) return chatId;
  const chat = await prisma.chat.create({ data: {}, select: { id: true } });
  return chat.id;
}
/** Token-usage payload of the terminal `done` event, derived from StreamEvent so the two cannot drift apart. */
type StreamUsage = Extract<StreamEvent, { type: "done" }>["usage"];

/**
 * Best-effort message for an arbitrary thrown value: its non-nullish `message`
 * property when present, otherwise `String(value)` (same result as `e?.message ?? String(e)`).
 */
function describeError(e: unknown): string {
  if (typeof e === "object" && e !== null && "message" in e) {
    const { message } = e as { message?: unknown };
    if (message != null) return String(message);
  }
  return String(e);
}

/**
 * Streams one chat completion from the requested provider and persists the exchange.
 *
 * Flow:
 * 1. Resolves (or creates) the chat and inserts an `llmCall` row holding the raw request.
 * 2. Stamps the chat's last-used provider/model, and its initiating provider/model if not yet set.
 * 3. Yields `meta`, then `delta` events (plus `tool_call` events for OpenAI/xAI) as the provider streams.
 * 4. On success, stores tool log messages and the assistant reply, finalizes the `llmCall` row
 *    (response, latency, token usage) in one transaction, then yields `done`.
 * 5. On failure inside the streaming/persistence phase, records the error and latency on the
 *    `llmCall` row and yields `error` instead of throwing.
 *
 * Failures before `meta` (chat/llmCall creation, chat stamping) propagate to the caller.
 *
 * NOTE(review): if the consumer stops iterating early (e.g. calls `return()` on disconnect),
 * neither the success nor the error bookkeeping runs, so the llmCall row keeps no latency/response.
 *
 * @param req - provider, model, messages and sampling options for this call.
 */
export async function* runMultiplexStream(req: MultiplexRequest): AsyncGenerator<StreamEvent> {
  const t0 = performance.now();
  const chatId = await getChatIdOrCreate(req.chatId);
  const call = await prisma.llmCall.create({
    data: {
      chatId,
      provider: req.provider as any,
      model: req.model,
      request: req as any,
    },
    select: { id: true },
  });

  // Always record the latest provider/model; the `initiatedProvider: null` filter makes the
  // "initiated" stamp a no-op once it has been set on an earlier call.
  await prisma.$transaction([
    prisma.chat.update({
      where: { id: chatId },
      data: {
        lastUsedProvider: req.provider as any,
        lastUsedModel: req.model,
      },
    }),
    prisma.chat.updateMany({
      where: { id: chatId, initiatedProvider: null },
      data: {
        initiatedProvider: req.provider as any,
        initiatedModel: req.model,
      },
    }),
  ]);

  yield { type: "meta", chatId, callId: call.id, provider: req.provider, model: req.model };

  let text = "";
  let usage: StreamUsage = undefined;
  let raw: unknown = { streamed: true };
  const toolMessages: ReturnType<typeof buildToolLogMessageData>[] = [];

  try {
    if (req.provider === "openai" || req.provider === "xai") {
      const client = req.provider === "openai" ? openaiClient() : xaiClient();
      for await (const ev of runToolAwareOpenAIChatStream({
        client,
        model: req.model,
        messages: req.messages,
        temperature: req.temperature,
        maxTokens: req.maxTokens,
        logContext: {
          provider: req.provider,
          model: req.model,
          chatId,
        },
      })) {
        if (ev.type === "delta") {
          text += ev.text;
          yield { type: "delta", text: ev.text };
          continue;
        }
        if (ev.type === "tool_call") {
          toolMessages.push(buildToolLogMessageData(chatId, ev.event));
          yield { type: "tool_call", event: ev.event };
          continue;
        }
        // Any other event carries the final result; its text supersedes the accumulated deltas.
        raw = ev.result.raw;
        usage = ev.result.usage;
        text = ev.result.text;
      }
    } else if (req.provider === "anthropic") {
      const client = anthropicClient();
      // NOTE(review): only the first system message is forwarded; any later ones are dropped.
      const system = req.messages.find((m) => m.role === "system")?.content;
      const msgs = req.messages
        .filter((m) => m.role !== "system")
        .map((m) => ({ role: m.role === "assistant" ? "assistant" : "user", content: m.content }));
      const stream = await client.messages.create({
        model: req.model,
        system,
        max_tokens: req.maxTokens ?? 1024,
        temperature: req.temperature,
        messages: msgs as any,
        stream: true,
      });

      // Anthropic reports input tokens on `message_start`; `message_delta` carries the cumulative
      // output count (newer API versions may repeat input tokens there, possibly as null).
      let inputTokens: number | undefined;
      let outputTokens: number | undefined;
      for await (const ev of stream as any as AsyncIterable<any>) {
        if (ev?.type === "message_start") {
          inputTokens = ev?.message?.usage?.input_tokens ?? inputTokens;
          outputTokens = ev?.message?.usage?.output_tokens ?? outputTokens;
        } else if (ev?.type === "content_block_delta" && ev?.delta?.type === "text_delta") {
          const delta = ev.delta.text ?? "";
          if (delta) {
            text += delta;
            yield { type: "delta", text: delta };
          }
        } else if (ev?.type === "message_delta" && ev?.usage) {
          inputTokens = ev.usage.input_tokens ?? inputTokens;
          outputTokens = ev.usage.output_tokens ?? outputTokens;
        }
      }
      if (inputTokens !== undefined || outputTokens !== undefined) {
        usage = {
          inputTokens,
          outputTokens,
          totalTokens: (inputTokens ?? 0) + (outputTokens ?? 0),
        };
      }
      raw = { streamed: true, provider: "anthropic" };
    } else {
      throw new Error(`unknown provider: ${req.provider}`);
    }

    const latencyMs = Math.round(performance.now() - t0);
    // Tool logs, the assistant reply and the llmCall finalization commit (or fail) together.
    await prisma.$transaction(async (tx) => {
      if (toolMessages.length) {
        await tx.message.createMany({
          data: toolMessages.map((message) => ({
            chatId: message.chatId,
            role: message.role as any,
            content: message.content,
            name: message.name,
            metadata: message.metadata as any,
          })),
        });
      }
      await tx.message.create({
        data: { chatId, role: "assistant" as any, content: text },
      });
      await tx.llmCall.update({
        where: { id: call.id },
        data: {
          response: raw as any,
          latencyMs,
          inputTokens: usage?.inputTokens,
          outputTokens: usage?.outputTokens,
          totalTokens: usage?.totalTokens,
        },
      });
    });

    yield { type: "done", text, usage };
  } catch (e: unknown) {
    const message = describeError(e);
    const latencyMs = Math.round(performance.now() - t0);
    try {
      await prisma.llmCall.update({
        where: { id: call.id },
        data: {
          error: message,
          latencyMs,
        },
      });
    } catch (dbErr: unknown) {
      // Bookkeeping failure must not hide the original error from the client.
      console.error(`failed to record error for llmCall ${call.id}:`, dbErr);
    }
    yield { type: "error", message };
  }
}