diff --git a/docs/api/streaming-chat.md b/docs/api/streaming-chat.md index c6335d7..55cf17c 100644 --- a/docs/api/streaming-chat.md +++ b/docs/api/streaming-chat.md @@ -108,8 +108,8 @@ Event order: Tool-enabled streaming notes (`openai`/`xai`): - Stream still emits standard `meta`, `delta`, `done|error` events. -- Stream may emit `tool_call` events before final assistant text. -- `delta` may arrive as one consolidated chunk after tool execution, rather than many token-level chunks. +- Stream may emit `tool_call` events while tool calls are executed. +- `delta` events stream incrementally as text is generated. ## Persistence + Consistency Model diff --git a/server/src/llm/chat-tools.ts b/server/src/llm/chat-tools.ts index e898ab5..91ae8c9 100644 --- a/server/src/llm/chat-tools.ts +++ b/server/src/llm/chat-tools.ts @@ -114,6 +114,11 @@ type ToolAwareCompletionResult = { toolEvents: ToolExecutionEvent[]; }; +export type ToolAwareStreamingEvent = + | { type: "delta"; text: string } + | { type: "tool_call"; event: ToolExecutionEvent } + | { type: "done"; result: ToolAwareCompletionResult }; + type ToolAwareCompletionParams = { client: OpenAI; model: string; @@ -399,6 +404,67 @@ function mergeUsage(acc: Required, usage: any) { return true; } +type NormalizedToolCall = { + id: string; + name: string; + arguments: string; +}; + +function normalizeModelToolCalls(toolCalls: any[], round: number): NormalizedToolCall[] { + return toolCalls.map((call: any, index: number) => ({ + id: call?.id ?? `tool_call_${round}_${index}`, + name: call?.function?.name ?? "unknown_tool", + arguments: call?.function?.arguments ?? "{}", + })); +} + +async function executeToolCallAndBuildEvent( + call: NormalizedToolCall, + params: ToolAwareCompletionParams +): Promise<{ event: ToolExecutionEvent; toolResult: ToolRunOutcome }> { + const startedAtMs = Date.now(); + const startedAt = new Date(startedAtMs).toISOString(); + let toolResult: ToolRunOutcome; + let parsedArgs: Record = {}; + try { + parsedArgs = toRecord(parseToolArgs(call.arguments)); + toolResult = await executeTool(call.name, parsedArgs); + } catch (err: any) { + toolResult = { + ok: false, + error: err?.message ?? String(err), + }; + } + + const status: "completed" | "failed" = toolResult.ok ? "completed" : "failed"; + const error = + status === "failed" + ? typeof toolResult.error === "string" + ? toolResult.error + : "Tool execution failed." + : undefined; + + const completedAtMs = Date.now(); + const event: ToolExecutionEvent = { + toolCallId: call.id, + name: call.name, + status, + summary: buildToolSummary(call.name, parsedArgs, status, error), + args: parsedArgs, + startedAt, + completedAt: new Date(completedAtMs).toISOString(), + durationMs: completedAtMs - startedAtMs, + error, + resultPreview: buildResultPreview(toolResult), + }; + logToolEvent(event, params.logContext); + if (params.onToolEvent) { + await params.onToolEvent(event); + } + + return { event, toolResult }; +} + export async function runToolAwareOpenAIChat(params: ToolAwareCompletionParams): Promise { const conversation: any[] = normalizeIncomingMessages(params.messages); const rawResponses: unknown[] = []; @@ -439,16 +505,17 @@ export async function runToolAwareOpenAIChat(params: ToolAwareCompletionParams): }; } - totalToolCalls += toolCalls.length; + const normalizedToolCalls = normalizeModelToolCalls(toolCalls, round); + totalToolCalls += normalizedToolCalls.length; const assistantToolCallMessage: any = { role: "assistant", - tool_calls: toolCalls.map((call: any, index: number) => ({ - id: call?.id ?? `tool_call_${round}_${index}`, + tool_calls: normalizedToolCalls.map((call) => ({ + id: call.id, type: "function", function: { - name: call?.function?.name ?? "unknown_tool", - arguments: call?.function?.arguments ?? "{}", + name: call.name, + arguments: call.arguments, }, })), }; @@ -457,52 +524,13 @@ export async function runToolAwareOpenAIChat(params: ToolAwareCompletionParams): } conversation.push(assistantToolCallMessage); - for (let index = 0; index < toolCalls.length; index += 1) { - const call: any = toolCalls[index]; - const toolName = call?.function?.name ?? "unknown_tool"; - const toolCallId = call?.id ?? `tool_call_${round}_${index}`; - const startedAtMs = Date.now(); - const startedAt = new Date(startedAtMs).toISOString(); - let toolResult: ToolRunOutcome; - let parsedArgs: Record = {}; - try { - parsedArgs = toRecord(parseToolArgs(call?.function?.arguments)); - toolResult = await executeTool(toolName, parsedArgs); - } catch (err: any) { - toolResult = { - ok: false, - error: err?.message ?? String(err), - }; - } - const status: "completed" | "failed" = toolResult.ok ? "completed" : "failed"; - const error = - status === "failed" - ? typeof toolResult.error === "string" - ? toolResult.error - : "Tool execution failed." - : undefined; - const completedAtMs = Date.now(); - const event: ToolExecutionEvent = { - toolCallId, - name: toolName, - status, - summary: buildToolSummary(toolName, parsedArgs, status, error), - args: parsedArgs, - startedAt, - completedAt: new Date(completedAtMs).toISOString(), - durationMs: completedAtMs - startedAtMs, - error, - resultPreview: buildResultPreview(toolResult), - }; + for (const call of normalizedToolCalls) { + const { event, toolResult } = await executeToolCallAndBuildEvent(call, params); toolEvents.push(event); - logToolEvent(event, params.logContext); - if (params.onToolEvent) { - await params.onToolEvent(event); - } conversation.push({ role: "tool", - tool_call_id: toolCallId, + tool_call_id: call.id, content: JSON.stringify(toolResult), }); } @@ -515,3 +543,115 @@ export async function runToolAwareOpenAIChat(params: ToolAwareCompletionParams): toolEvents, }; } + +export async function* runToolAwareOpenAIChatStream( + params: ToolAwareCompletionParams +): AsyncGenerator { + const conversation: any[] = normalizeIncomingMessages(params.messages); + const rawResponses: unknown[] = []; + const toolEvents: ToolExecutionEvent[] = []; + const usageAcc: Required = { inputTokens: 0, outputTokens: 0, totalTokens: 0 }; + let sawUsage = false; + let totalToolCalls = 0; + + for (let round = 0; round < MAX_TOOL_ROUNDS; round += 1) { + const stream = await params.client.chat.completions.create({ + model: params.model, + messages: conversation, + temperature: params.temperature, + max_tokens: params.maxTokens, + tools: CHAT_TOOLS, + tool_choice: "auto", + stream: true, + stream_options: { include_usage: true }, + } as any); + + let roundText = ""; + const roundToolCalls = new Map(); + + for await (const chunk of stream as any as AsyncIterable) { + rawResponses.push(chunk); + sawUsage = mergeUsage(usageAcc, chunk?.usage) || sawUsage; + + const choice = chunk?.choices?.[0]; + const deltaText = choice?.delta?.content ?? ""; + if (typeof deltaText === "string" && deltaText.length) { + roundText += deltaText; + if (roundToolCalls.size === 0) { + yield { type: "delta", text: deltaText }; + } + } + + const deltaToolCalls = Array.isArray(choice?.delta?.tool_calls) ? choice.delta.tool_calls : []; + for (const toolCall of deltaToolCalls) { + const idx = typeof toolCall?.index === "number" ? toolCall.index : 0; + const entry = roundToolCalls.get(idx) ?? { arguments: "" }; + if (typeof toolCall?.id === "string" && toolCall.id.length) { + entry.id = toolCall.id; + } + if (typeof toolCall?.function?.name === "string" && toolCall.function.name.length) { + entry.name = toolCall.function.name; + } + if (typeof toolCall?.function?.arguments === "string" && toolCall.function.arguments.length) { + entry.arguments += toolCall.function.arguments; + } + roundToolCalls.set(idx, entry); + } + } + + const normalizedToolCalls: NormalizedToolCall[] = [...roundToolCalls.entries()] + .sort((a, b) => a[0] - b[0]) + .map(([_, call], index) => ({ + id: call.id ?? `tool_call_${round}_${index}`, + name: call.name ?? "unknown_tool", + arguments: call.arguments || "{}", + })); + + if (!normalizedToolCalls.length) { + yield { + type: "done", + result: { + text: roundText, + usage: sawUsage ? usageAcc : undefined, + raw: { streamed: true, responses: rawResponses, toolCallsUsed: totalToolCalls }, + toolEvents, + }, + }; + return; + } + + totalToolCalls += normalizedToolCalls.length; + conversation.push({ + role: "assistant", + tool_calls: normalizedToolCalls.map((call) => ({ + id: call.id, + type: "function", + function: { + name: call.name, + arguments: call.arguments, + }, + })), + }); + + for (const call of normalizedToolCalls) { + const { event, toolResult } = await executeToolCallAndBuildEvent(call, params); + toolEvents.push(event); + yield { type: "tool_call", event }; + conversation.push({ + role: "tool", + tool_call_id: call.id, + content: JSON.stringify(toolResult), + }); + } + } + + yield { + type: "done", + result: { + text: "I reached the tool-call limit while gathering information. Please narrow the request and try again.", + usage: sawUsage ? usageAcc : undefined, + raw: { streamed: true, responses: rawResponses, toolCallsUsed: totalToolCalls, toolCallLimitReached: true }, + toolEvents, + }, + }; +} diff --git a/server/src/llm/streaming.ts b/server/src/llm/streaming.ts index 1422048..fb54c68 100644 --- a/server/src/llm/streaming.ts +++ b/server/src/llm/streaming.ts @@ -1,7 +1,7 @@ import { performance } from "node:perf_hooks"; import { prisma } from "../db.js"; import { anthropicClient, openaiClient, xaiClient } from "./providers.js"; -import { buildToolLogMessageData, runToolAwareOpenAIChat, type ToolExecutionEvent } from "./chat-tools.js"; +import { buildToolLogMessageData, runToolAwareOpenAIChatStream, type ToolExecutionEvent } from "./chat-tools.js"; import type { MultiplexRequest, Provider } from "./types.js"; export type StreamEvent = @@ -57,31 +57,33 @@ export async function* runMultiplexStream(req: MultiplexRequest): AsyncGenerator try { if (req.provider === "openai" || req.provider === "xai") { const client = req.provider === "openai" ? openaiClient() : xaiClient(); - const toolEvents: ToolExecutionEvent[] = []; - const r = await runToolAwareOpenAIChat({ + for await (const ev of runToolAwareOpenAIChatStream({ client, model: req.model, messages: req.messages, temperature: req.temperature, maxTokens: req.maxTokens, - onToolEvent: (event) => { - toolEvents.push(event); - }, logContext: { provider: req.provider, model: req.model, chatId, }, - }); - raw = r.raw; - text = r.text; - usage = r.usage; - toolMessages = toolEvents.map((event) => buildToolLogMessageData(chatId, event)); - for (const event of toolEvents) { - yield { type: "tool_call", event }; - } - if (text) { - yield { type: "delta", text }; + })) { + if (ev.type === "delta") { + text += ev.text; + yield { type: "delta", text: ev.text }; + continue; + } + + if (ev.type === "tool_call") { + toolMessages.push(buildToolLogMessageData(chatId, ev.event)); + yield { type: "tool_call", event: ev.event }; + continue; + } + + raw = ev.result.raw; + usage = ev.result.usage; + text = ev.result.text; } } else if (req.provider === "anthropic") { const client = anthropicClient();