Enable streaming for tool call logs
This commit is contained in:
@@ -108,8 +108,8 @@ Event order:
|
||||
|
||||
Tool-enabled streaming notes (`openai`/`xai`):
|
||||
- Stream still emits standard `meta`, `delta`, `done|error` events.
|
||||
- Stream may emit `tool_call` events before final assistant text.
|
||||
- `delta` may arrive as one consolidated chunk after tool execution, rather than many token-level chunks.
|
||||
- Stream may emit `tool_call` events as each tool call is executed.
|
||||
- `delta` events stream incrementally as text is generated.
|
||||
|
||||
## Persistence + Consistency Model
|
||||
|
||||
|
||||
@@ -114,6 +114,11 @@ type ToolAwareCompletionResult = {
|
||||
toolEvents: ToolExecutionEvent[];
|
||||
};
|
||||
|
||||
/**
 * Discriminated union of events yielded by the tool-aware streaming generator.
 *
 * - `delta`: an incremental chunk of assistant text.
 * - `tool_call`: a tool execution that has already run to completion
 *   (the embedded event carries its final status).
 * - `done`: the terminal event carrying the consolidated completion result;
 *   no further events follow it.
 */
export type ToolAwareStreamingEvent =
  | { type: "delta"; text: string }
  | { type: "tool_call"; event: ToolExecutionEvent }
  | { type: "done"; result: ToolAwareCompletionResult };
|
||||
|
||||
type ToolAwareCompletionParams = {
|
||||
client: OpenAI;
|
||||
model: string;
|
||||
@@ -399,6 +404,67 @@ function mergeUsage(acc: Required<ToolAwareUsage>, usage: any) {
|
||||
return true;
|
||||
}
|
||||
|
||||
type NormalizedToolCall = {
|
||||
id: string;
|
||||
name: string;
|
||||
arguments: string;
|
||||
};
|
||||
|
||||
function normalizeModelToolCalls(toolCalls: any[], round: number): NormalizedToolCall[] {
|
||||
return toolCalls.map((call: any, index: number) => ({
|
||||
id: call?.id ?? `tool_call_${round}_${index}`,
|
||||
name: call?.function?.name ?? "unknown_tool",
|
||||
arguments: call?.function?.arguments ?? "{}",
|
||||
}));
|
||||
}
|
||||
|
||||
async function executeToolCallAndBuildEvent(
|
||||
call: NormalizedToolCall,
|
||||
params: ToolAwareCompletionParams
|
||||
): Promise<{ event: ToolExecutionEvent; toolResult: ToolRunOutcome }> {
|
||||
const startedAtMs = Date.now();
|
||||
const startedAt = new Date(startedAtMs).toISOString();
|
||||
let toolResult: ToolRunOutcome;
|
||||
let parsedArgs: Record<string, unknown> = {};
|
||||
try {
|
||||
parsedArgs = toRecord(parseToolArgs(call.arguments));
|
||||
toolResult = await executeTool(call.name, parsedArgs);
|
||||
} catch (err: any) {
|
||||
toolResult = {
|
||||
ok: false,
|
||||
error: err?.message ?? String(err),
|
||||
};
|
||||
}
|
||||
|
||||
const status: "completed" | "failed" = toolResult.ok ? "completed" : "failed";
|
||||
const error =
|
||||
status === "failed"
|
||||
? typeof toolResult.error === "string"
|
||||
? toolResult.error
|
||||
: "Tool execution failed."
|
||||
: undefined;
|
||||
|
||||
const completedAtMs = Date.now();
|
||||
const event: ToolExecutionEvent = {
|
||||
toolCallId: call.id,
|
||||
name: call.name,
|
||||
status,
|
||||
summary: buildToolSummary(call.name, parsedArgs, status, error),
|
||||
args: parsedArgs,
|
||||
startedAt,
|
||||
completedAt: new Date(completedAtMs).toISOString(),
|
||||
durationMs: completedAtMs - startedAtMs,
|
||||
error,
|
||||
resultPreview: buildResultPreview(toolResult),
|
||||
};
|
||||
logToolEvent(event, params.logContext);
|
||||
if (params.onToolEvent) {
|
||||
await params.onToolEvent(event);
|
||||
}
|
||||
|
||||
return { event, toolResult };
|
||||
}
|
||||
|
||||
export async function runToolAwareOpenAIChat(params: ToolAwareCompletionParams): Promise<ToolAwareCompletionResult> {
|
||||
const conversation: any[] = normalizeIncomingMessages(params.messages);
|
||||
const rawResponses: unknown[] = [];
|
||||
@@ -439,16 +505,17 @@ export async function runToolAwareOpenAIChat(params: ToolAwareCompletionParams):
|
||||
};
|
||||
}
|
||||
|
||||
totalToolCalls += toolCalls.length;
|
||||
const normalizedToolCalls = normalizeModelToolCalls(toolCalls, round);
|
||||
totalToolCalls += normalizedToolCalls.length;
|
||||
|
||||
const assistantToolCallMessage: any = {
|
||||
role: "assistant",
|
||||
tool_calls: toolCalls.map((call: any, index: number) => ({
|
||||
id: call?.id ?? `tool_call_${round}_${index}`,
|
||||
tool_calls: normalizedToolCalls.map((call) => ({
|
||||
id: call.id,
|
||||
type: "function",
|
||||
function: {
|
||||
name: call?.function?.name ?? "unknown_tool",
|
||||
arguments: call?.function?.arguments ?? "{}",
|
||||
name: call.name,
|
||||
arguments: call.arguments,
|
||||
},
|
||||
})),
|
||||
};
|
||||
@@ -457,52 +524,13 @@ export async function runToolAwareOpenAIChat(params: ToolAwareCompletionParams):
|
||||
}
|
||||
conversation.push(assistantToolCallMessage);
|
||||
|
||||
for (let index = 0; index < toolCalls.length; index += 1) {
|
||||
const call: any = toolCalls[index];
|
||||
const toolName = call?.function?.name ?? "unknown_tool";
|
||||
const toolCallId = call?.id ?? `tool_call_${round}_${index}`;
|
||||
const startedAtMs = Date.now();
|
||||
const startedAt = new Date(startedAtMs).toISOString();
|
||||
let toolResult: ToolRunOutcome;
|
||||
let parsedArgs: Record<string, unknown> = {};
|
||||
try {
|
||||
parsedArgs = toRecord(parseToolArgs(call?.function?.arguments));
|
||||
toolResult = await executeTool(toolName, parsedArgs);
|
||||
} catch (err: any) {
|
||||
toolResult = {
|
||||
ok: false,
|
||||
error: err?.message ?? String(err),
|
||||
};
|
||||
}
|
||||
const status: "completed" | "failed" = toolResult.ok ? "completed" : "failed";
|
||||
const error =
|
||||
status === "failed"
|
||||
? typeof toolResult.error === "string"
|
||||
? toolResult.error
|
||||
: "Tool execution failed."
|
||||
: undefined;
|
||||
const completedAtMs = Date.now();
|
||||
const event: ToolExecutionEvent = {
|
||||
toolCallId,
|
||||
name: toolName,
|
||||
status,
|
||||
summary: buildToolSummary(toolName, parsedArgs, status, error),
|
||||
args: parsedArgs,
|
||||
startedAt,
|
||||
completedAt: new Date(completedAtMs).toISOString(),
|
||||
durationMs: completedAtMs - startedAtMs,
|
||||
error,
|
||||
resultPreview: buildResultPreview(toolResult),
|
||||
};
|
||||
for (const call of normalizedToolCalls) {
|
||||
const { event, toolResult } = await executeToolCallAndBuildEvent(call, params);
|
||||
toolEvents.push(event);
|
||||
logToolEvent(event, params.logContext);
|
||||
if (params.onToolEvent) {
|
||||
await params.onToolEvent(event);
|
||||
}
|
||||
|
||||
conversation.push({
|
||||
role: "tool",
|
||||
tool_call_id: toolCallId,
|
||||
tool_call_id: call.id,
|
||||
content: JSON.stringify(toolResult),
|
||||
});
|
||||
}
|
||||
@@ -515,3 +543,115 @@ export async function runToolAwareOpenAIChat(params: ToolAwareCompletionParams):
|
||||
toolEvents,
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * Streaming variant of the tool-aware chat loop.
 *
 * Runs up to MAX_TOOL_ROUNDS streamed completions. Within each round it
 * yields `delta` events for assistant text (only while no tool-call
 * fragments have been seen in that round), executes any tool calls the
 * model requested and yields a `tool_call` event per execution, then feeds
 * the tool results back into the conversation for the next round. Ends
 * with exactly one `done` event: either the round's final text, or a
 * limit-reached message if every round requested tools.
 *
 * @param params Same parameters as the non-streaming entry point.
 */
export async function* runToolAwareOpenAIChatStream(
  params: ToolAwareCompletionParams
): AsyncGenerator<ToolAwareStreamingEvent> {
  const conversation: any[] = normalizeIncomingMessages(params.messages);
  const rawResponses: unknown[] = [];
  const toolEvents: ToolExecutionEvent[] = [];
  // Usage is accumulated across all rounds; sawUsage records whether the
  // provider ever reported usage so we can omit it entirely when it did not.
  const usageAcc: Required<ToolAwareUsage> = { inputTokens: 0, outputTokens: 0, totalTokens: 0 };
  let sawUsage = false;
  let totalToolCalls = 0;

  for (let round = 0; round < MAX_TOOL_ROUNDS; round += 1) {
    const stream = await params.client.chat.completions.create({
      model: params.model,
      messages: conversation,
      temperature: params.temperature,
      max_tokens: params.maxTokens,
      tools: CHAT_TOOLS,
      tool_choice: "auto",
      stream: true,
      // Requests per-stream usage chunks; merged below via mergeUsage.
      stream_options: { include_usage: true },
    } as any);

    let roundText = "";
    // Tool-call fragments arrive incrementally, keyed by the provider's
    // per-choice index; arguments are concatenated across chunks.
    const roundToolCalls = new Map<number, { id?: string; name?: string; arguments: string }>();

    for await (const chunk of stream as any as AsyncIterable<any>) {
      rawResponses.push(chunk);
      sawUsage = mergeUsage(usageAcc, chunk?.usage) || sawUsage;

      const choice = chunk?.choices?.[0];
      const deltaText = choice?.delta?.content ?? "";
      if (typeof deltaText === "string" && deltaText.length) {
        roundText += deltaText;
        // Once any tool-call fragment has been seen this round, stop
        // forwarding deltas: the round's text will be superseded by a
        // later round (the accumulated roundText is still kept).
        if (roundToolCalls.size === 0) {
          yield { type: "delta", text: deltaText };
        }
      }

      const deltaToolCalls = Array.isArray(choice?.delta?.tool_calls) ? choice.delta.tool_calls : [];
      for (const toolCall of deltaToolCalls) {
        const idx = typeof toolCall?.index === "number" ? toolCall.index : 0;
        const entry = roundToolCalls.get(idx) ?? { arguments: "" };
        // id and name arrive once; arguments arrive as a JSON string
        // spread across chunks, so append rather than overwrite.
        if (typeof toolCall?.id === "string" && toolCall.id.length) {
          entry.id = toolCall.id;
        }
        if (typeof toolCall?.function?.name === "string" && toolCall.function.name.length) {
          entry.name = toolCall.function.name;
        }
        if (typeof toolCall?.function?.arguments === "string" && toolCall.function.arguments.length) {
          entry.arguments += toolCall.function.arguments;
        }
        roundToolCalls.set(idx, entry);
      }
    }

    // Order by provider index and apply the same fallbacks as the
    // non-streaming path (deterministic ids, "unknown_tool", "{}").
    const normalizedToolCalls: NormalizedToolCall[] = [...roundToolCalls.entries()]
      .sort((a, b) => a[0] - b[0])
      .map(([_, call], index) => ({
        id: call.id ?? `tool_call_${round}_${index}`,
        name: call.name ?? "unknown_tool",
        arguments: call.arguments || "{}",
      }));

    // No tool calls this round: the model produced its final answer.
    if (!normalizedToolCalls.length) {
      yield {
        type: "done",
        result: {
          text: roundText,
          usage: sawUsage ? usageAcc : undefined,
          raw: { streamed: true, responses: rawResponses, toolCallsUsed: totalToolCalls },
          toolEvents,
        },
      };
      return;
    }

    totalToolCalls += normalizedToolCalls.length;
    // Echo the assistant's tool calls into the conversation so the
    // follow-up tool messages have a matching assistant turn.
    conversation.push({
      role: "assistant",
      tool_calls: normalizedToolCalls.map((call) => ({
        id: call.id,
        type: "function",
        function: {
          name: call.name,
          arguments: call.arguments,
        },
      })),
    });

    // Execute tools sequentially; each execution is surfaced as a
    // tool_call event and its result fed back as a tool message.
    for (const call of normalizedToolCalls) {
      const { event, toolResult } = await executeToolCallAndBuildEvent(call, params);
      toolEvents.push(event);
      yield { type: "tool_call", event };
      conversation.push({
        role: "tool",
        tool_call_id: call.id,
        content: JSON.stringify(toolResult),
      });
    }
  }

  // Every round requested tools: give up with an explanatory message
  // rather than looping forever.
  yield {
    type: "done",
    result: {
      text: "I reached the tool-call limit while gathering information. Please narrow the request and try again.",
      usage: sawUsage ? usageAcc : undefined,
      raw: { streamed: true, responses: rawResponses, toolCallsUsed: totalToolCalls, toolCallLimitReached: true },
      toolEvents,
    },
  };
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { performance } from "node:perf_hooks";
|
||||
import { prisma } from "../db.js";
|
||||
import { anthropicClient, openaiClient, xaiClient } from "./providers.js";
|
||||
import { buildToolLogMessageData, runToolAwareOpenAIChat, type ToolExecutionEvent } from "./chat-tools.js";
|
||||
import { buildToolLogMessageData, runToolAwareOpenAIChatStream, type ToolExecutionEvent } from "./chat-tools.js";
|
||||
import type { MultiplexRequest, Provider } from "./types.js";
|
||||
|
||||
export type StreamEvent =
|
||||
@@ -57,31 +57,33 @@ export async function* runMultiplexStream(req: MultiplexRequest): AsyncGenerator
|
||||
try {
|
||||
if (req.provider === "openai" || req.provider === "xai") {
|
||||
const client = req.provider === "openai" ? openaiClient() : xaiClient();
|
||||
const toolEvents: ToolExecutionEvent[] = [];
|
||||
const r = await runToolAwareOpenAIChat({
|
||||
for await (const ev of runToolAwareOpenAIChatStream({
|
||||
client,
|
||||
model: req.model,
|
||||
messages: req.messages,
|
||||
temperature: req.temperature,
|
||||
maxTokens: req.maxTokens,
|
||||
onToolEvent: (event) => {
|
||||
toolEvents.push(event);
|
||||
},
|
||||
logContext: {
|
||||
provider: req.provider,
|
||||
model: req.model,
|
||||
chatId,
|
||||
},
|
||||
});
|
||||
raw = r.raw;
|
||||
text = r.text;
|
||||
usage = r.usage;
|
||||
toolMessages = toolEvents.map((event) => buildToolLogMessageData(chatId, event));
|
||||
for (const event of toolEvents) {
|
||||
yield { type: "tool_call", event };
|
||||
}
|
||||
if (text) {
|
||||
yield { type: "delta", text };
|
||||
})) {
|
||||
if (ev.type === "delta") {
|
||||
text += ev.text;
|
||||
yield { type: "delta", text: ev.text };
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ev.type === "tool_call") {
|
||||
toolMessages.push(buildToolLogMessageData(chatId, ev.event));
|
||||
yield { type: "tool_call", event: ev.event };
|
||||
continue;
|
||||
}
|
||||
|
||||
raw = ev.result.raw;
|
||||
usage = ev.result.usage;
|
||||
text = ev.result.text;
|
||||
}
|
||||
} else if (req.provider === "anthropic") {
|
||||
const client = anthropicClient();
|
||||
|
||||
Reference in New Issue
Block a user