Enable streaming for tool call logs

This commit is contained in:
2026-03-02 16:39:05 -08:00
parent 991316e692
commit e3253d1741
3 changed files with 207 additions and 65 deletions

View File

@@ -108,8 +108,8 @@ Event order:
Tool-enabled streaming notes (`openai`/`xai`):
- Stream still emits standard `meta`, `delta`, `done|error` events.
- Stream may emit `tool_call` events before final assistant text.
- `delta` may arrive as one consolidated chunk after tool execution, rather than many token-level chunks.
- Stream may emit `tool_call` events as each tool call is executed.
- `delta` events stream incrementally as text is generated.
## Persistence + Consistency Model

View File

@@ -114,6 +114,11 @@ type ToolAwareCompletionResult = {
toolEvents: ToolExecutionEvent[];
};
// Events yielded by the tool-aware streaming chat generator: incremental
// assistant text deltas, one event per executed tool call, and a terminal
// "done" event carrying the aggregate completion result.
export type ToolAwareStreamingEvent =
  | { type: "delta"; text: string }
  | { type: "tool_call"; event: ToolExecutionEvent }
  | { type: "done"; result: ToolAwareCompletionResult };
type ToolAwareCompletionParams = {
client: OpenAI;
model: string;
@@ -399,6 +404,67 @@ function mergeUsage(acc: Required<ToolAwareUsage>, usage: any) {
return true;
}
/** Uniform shape of a model-requested tool call after fallback defaults are applied. */
type NormalizedToolCall = {
  id: string;
  name: string;
  arguments: string;
};

/**
 * Normalize raw model tool-call payloads into NormalizedToolCall records,
 * substituting a deterministic fallback id (`tool_call_<round>_<index>`),
 * a placeholder tool name, and empty-object arguments for missing fields.
 */
function normalizeModelToolCalls(toolCalls: any[], round: number): NormalizedToolCall[] {
  const normalized: NormalizedToolCall[] = [];
  let index = 0;
  for (const call of toolCalls) {
    normalized.push({
      id: call?.id ?? `tool_call_${round}_${index}`,
      name: call?.function?.name ?? "unknown_tool",
      arguments: call?.function?.arguments ?? "{}",
    });
    index += 1;
  }
  return normalized;
}
/**
 * Execute one normalized tool call, capture its timing, and build the
 * corresponding ToolExecutionEvent. The event is logged and, when present,
 * forwarded to params.onToolEvent before being returned to the caller.
 */
async function executeToolCallAndBuildEvent(
  call: NormalizedToolCall,
  params: ToolAwareCompletionParams
): Promise<{ event: ToolExecutionEvent; toolResult: ToolRunOutcome }> {
  const startedAtMs = Date.now();
  let parsedArgs: Record<string, unknown> = {};
  let toolResult: ToolRunOutcome;
  try {
    // Argument parsing sits inside the try so malformed argument JSON is
    // reported as a failed tool run rather than thrown to the caller.
    parsedArgs = toRecord(parseToolArgs(call.arguments));
    toolResult = await executeTool(call.name, parsedArgs);
  } catch (err: any) {
    toolResult = { ok: false, error: err?.message ?? String(err) };
  }
  const failed = !toolResult.ok;
  const status: "completed" | "failed" = failed ? "failed" : "completed";
  let error: string | undefined;
  if (failed) {
    error = typeof toolResult.error === "string" ? toolResult.error : "Tool execution failed.";
  }
  const completedAtMs = Date.now();
  const event: ToolExecutionEvent = {
    toolCallId: call.id,
    name: call.name,
    status,
    summary: buildToolSummary(call.name, parsedArgs, status, error),
    args: parsedArgs,
    startedAt: new Date(startedAtMs).toISOString(),
    completedAt: new Date(completedAtMs).toISOString(),
    durationMs: completedAtMs - startedAtMs,
    error,
    resultPreview: buildResultPreview(toolResult),
  };
  logToolEvent(event, params.logContext);
  if (params.onToolEvent) {
    await params.onToolEvent(event);
  }
  return { event, toolResult };
}
export async function runToolAwareOpenAIChat(params: ToolAwareCompletionParams): Promise<ToolAwareCompletionResult> {
const conversation: any[] = normalizeIncomingMessages(params.messages);
const rawResponses: unknown[] = [];
@@ -439,16 +505,17 @@ export async function runToolAwareOpenAIChat(params: ToolAwareCompletionParams):
};
}
totalToolCalls += toolCalls.length;
const normalizedToolCalls = normalizeModelToolCalls(toolCalls, round);
totalToolCalls += normalizedToolCalls.length;
const assistantToolCallMessage: any = {
role: "assistant",
tool_calls: toolCalls.map((call: any, index: number) => ({
id: call?.id ?? `tool_call_${round}_${index}`,
tool_calls: normalizedToolCalls.map((call) => ({
id: call.id,
type: "function",
function: {
name: call?.function?.name ?? "unknown_tool",
arguments: call?.function?.arguments ?? "{}",
name: call.name,
arguments: call.arguments,
},
})),
};
@@ -457,52 +524,13 @@ export async function runToolAwareOpenAIChat(params: ToolAwareCompletionParams):
}
conversation.push(assistantToolCallMessage);
for (let index = 0; index < toolCalls.length; index += 1) {
const call: any = toolCalls[index];
const toolName = call?.function?.name ?? "unknown_tool";
const toolCallId = call?.id ?? `tool_call_${round}_${index}`;
const startedAtMs = Date.now();
const startedAt = new Date(startedAtMs).toISOString();
let toolResult: ToolRunOutcome;
let parsedArgs: Record<string, unknown> = {};
try {
parsedArgs = toRecord(parseToolArgs(call?.function?.arguments));
toolResult = await executeTool(toolName, parsedArgs);
} catch (err: any) {
toolResult = {
ok: false,
error: err?.message ?? String(err),
};
}
const status: "completed" | "failed" = toolResult.ok ? "completed" : "failed";
const error =
status === "failed"
? typeof toolResult.error === "string"
? toolResult.error
: "Tool execution failed."
: undefined;
const completedAtMs = Date.now();
const event: ToolExecutionEvent = {
toolCallId,
name: toolName,
status,
summary: buildToolSummary(toolName, parsedArgs, status, error),
args: parsedArgs,
startedAt,
completedAt: new Date(completedAtMs).toISOString(),
durationMs: completedAtMs - startedAtMs,
error,
resultPreview: buildResultPreview(toolResult),
};
for (const call of normalizedToolCalls) {
const { event, toolResult } = await executeToolCallAndBuildEvent(call, params);
toolEvents.push(event);
logToolEvent(event, params.logContext);
if (params.onToolEvent) {
await params.onToolEvent(event);
}
conversation.push({
role: "tool",
tool_call_id: toolCallId,
tool_call_id: call.id,
content: JSON.stringify(toolResult),
});
}
@@ -515,3 +543,115 @@ export async function runToolAwareOpenAIChat(params: ToolAwareCompletionParams):
toolEvents,
};
}
/**
 * Streaming counterpart to runToolAwareOpenAIChat.
 *
 * Yields ToolAwareStreamingEvent values:
 *  - "delta": incremental assistant text chunks,
 *  - "tool_call": one ToolExecutionEvent per executed tool call,
 *  - "done": terminal event with the aggregate result (always the last event).
 *
 * Runs up to MAX_TOOL_ROUNDS model rounds. A round that requests no tool
 * calls ends the stream with that round's text; otherwise the requested
 * tools are executed and their results are appended to the conversation
 * before the next round.
 */
export async function* runToolAwareOpenAIChatStream(
  params: ToolAwareCompletionParams
): AsyncGenerator<ToolAwareStreamingEvent> {
  const conversation: any[] = normalizeIncomingMessages(params.messages);
  // Every raw chunk is collected verbatim and returned in the final `raw` payload.
  const rawResponses: unknown[] = [];
  const toolEvents: ToolExecutionEvent[] = [];
  const usageAcc: Required<ToolAwareUsage> = { inputTokens: 0, outputTokens: 0, totalTokens: 0 };
  let sawUsage = false;
  let totalToolCalls = 0;
  for (let round = 0; round < MAX_TOOL_ROUNDS; round += 1) {
    const stream = await params.client.chat.completions.create({
      model: params.model,
      messages: conversation,
      temperature: params.temperature,
      max_tokens: params.maxTokens,
      tools: CHAT_TOOLS,
      tool_choice: "auto",
      stream: true,
      // Ask the provider to attach token usage to the stream's final chunk.
      stream_options: { include_usage: true },
    } as any);
    let roundText = "";
    // Tool-call payloads arrive fragmented across chunks; accumulate them
    // keyed by the tool call's index within this round.
    const roundToolCalls = new Map<number, { id?: string; name?: string; arguments: string }>();
    for await (const chunk of stream as any as AsyncIterable<any>) {
      rawResponses.push(chunk);
      sawUsage = mergeUsage(usageAcc, chunk?.usage) || sawUsage;
      const choice = chunk?.choices?.[0];
      const deltaText = choice?.delta?.content ?? "";
      if (typeof deltaText === "string" && deltaText.length) {
        roundText += deltaText;
        // Stop forwarding deltas once tool-call fragments have started
        // arriving in this round; text from a tool-call round is never used
        // in the final result (only a tool-free round's text is).
        if (roundToolCalls.size === 0) {
          yield { type: "delta", text: deltaText };
        }
      }
      const deltaToolCalls = Array.isArray(choice?.delta?.tool_calls) ? choice.delta.tool_calls : [];
      for (const toolCall of deltaToolCalls) {
        const idx = typeof toolCall?.index === "number" ? toolCall.index : 0;
        const entry = roundToolCalls.get(idx) ?? { arguments: "" };
        if (typeof toolCall?.id === "string" && toolCall.id.length) {
          entry.id = toolCall.id;
        }
        if (typeof toolCall?.function?.name === "string" && toolCall.function.name.length) {
          entry.name = toolCall.function.name;
        }
        if (typeof toolCall?.function?.arguments === "string" && toolCall.function.arguments.length) {
          // Argument JSON streams piecewise; concatenate the fragments.
          entry.arguments += toolCall.function.arguments;
        }
        roundToolCalls.set(idx, entry);
      }
    }
    // Assemble completed tool calls in index order, filling fallbacks for any
    // fields the stream never supplied.
    const normalizedToolCalls: NormalizedToolCall[] = [...roundToolCalls.entries()]
      .sort((a, b) => a[0] - b[0])
      .map(([_, call], index) => ({
        id: call.id ?? `tool_call_${round}_${index}`,
        name: call.name ?? "unknown_tool",
        arguments: call.arguments || "{}",
      }));
    if (!normalizedToolCalls.length) {
      // No tool calls requested this round: the streamed text is final.
      yield {
        type: "done",
        result: {
          text: roundText,
          usage: sawUsage ? usageAcc : undefined,
          raw: { streamed: true, responses: rawResponses, toolCallsUsed: totalToolCalls },
          toolEvents,
        },
      };
      return;
    }
    totalToolCalls += normalizedToolCalls.length;
    // Echo the assistant's tool_calls message so the provider can correlate
    // the tool-role results pushed below.
    conversation.push({
      role: "assistant",
      tool_calls: normalizedToolCalls.map((call) => ({
        id: call.id,
        type: "function",
        function: {
          name: call.name,
          arguments: call.arguments,
        },
      })),
    });
    for (const call of normalizedToolCalls) {
      const { event, toolResult } = await executeToolCallAndBuildEvent(call, params);
      toolEvents.push(event);
      yield { type: "tool_call", event };
      conversation.push({
        role: "tool",
        tool_call_id: call.id,
        content: JSON.stringify(toolResult),
      });
    }
  }
  // All rounds consumed without a tool-free response: report the limit hit.
  yield {
    type: "done",
    result: {
      text: "I reached the tool-call limit while gathering information. Please narrow the request and try again.",
      usage: sawUsage ? usageAcc : undefined,
      raw: { streamed: true, responses: rawResponses, toolCallsUsed: totalToolCalls, toolCallLimitReached: true },
      toolEvents,
    },
  };
}

View File

@@ -1,7 +1,7 @@
import { performance } from "node:perf_hooks";
import { prisma } from "../db.js";
import { anthropicClient, openaiClient, xaiClient } from "./providers.js";
import { buildToolLogMessageData, runToolAwareOpenAIChat, type ToolExecutionEvent } from "./chat-tools.js";
import { buildToolLogMessageData, runToolAwareOpenAIChatStream, type ToolExecutionEvent } from "./chat-tools.js";
import type { MultiplexRequest, Provider } from "./types.js";
export type StreamEvent =
@@ -57,31 +57,33 @@ export async function* runMultiplexStream(req: MultiplexRequest): AsyncGenerator
try {
if (req.provider === "openai" || req.provider === "xai") {
const client = req.provider === "openai" ? openaiClient() : xaiClient();
const toolEvents: ToolExecutionEvent[] = [];
const r = await runToolAwareOpenAIChat({
for await (const ev of runToolAwareOpenAIChatStream({
client,
model: req.model,
messages: req.messages,
temperature: req.temperature,
maxTokens: req.maxTokens,
onToolEvent: (event) => {
toolEvents.push(event);
},
logContext: {
provider: req.provider,
model: req.model,
chatId,
},
});
raw = r.raw;
text = r.text;
usage = r.usage;
toolMessages = toolEvents.map((event) => buildToolLogMessageData(chatId, event));
for (const event of toolEvents) {
yield { type: "tool_call", event };
}
if (text) {
yield { type: "delta", text };
})) {
if (ev.type === "delta") {
text += ev.text;
yield { type: "delta", text: ev.text };
continue;
}
if (ev.type === "tool_call") {
toolMessages.push(buildToolLogMessageData(chatId, ev.event));
yield { type: "tool_call", event: ev.event };
continue;
}
raw = ev.result.raw;
usage = ev.result.usage;
text = ev.result.text;
}
} else if (req.provider === "anthropic") {
const client = anthropicClient();