333 lines
12 KiB
TypeScript
333 lines
12 KiB
TypeScript
|
|
import {
|
||
|
|
appendDanglingToolIntentCorrection,
|
||
|
|
buildChatToolSystemPrompt,
|
||
|
|
executeToolCallAndBuildEvent,
|
||
|
|
getEnabledChatTools,
|
||
|
|
getUnstreamedText,
|
||
|
|
looksLikeDanglingToolIntent,
|
||
|
|
MAX_DANGLING_TOOL_INTENT_RETRIES,
|
||
|
|
MAX_TOOL_ROUNDS,
|
||
|
|
prepareToolCallExecution,
|
||
|
|
type NormalizedToolCall,
|
||
|
|
type ToolAwareCompletionParams,
|
||
|
|
type ToolAwareCompletionResult,
|
||
|
|
type ToolAwareStreamingEvent,
|
||
|
|
type ToolAwareUsage,
|
||
|
|
type ToolExecutionEvent,
|
||
|
|
} from "../chat-tools.js";
|
||
|
|
import {
|
||
|
|
buildImageSummaryText,
|
||
|
|
buildSystemPromptAugmentationMessage,
|
||
|
|
buildTextAttachmentPrompt,
|
||
|
|
getImageAttachments,
|
||
|
|
getTextAttachments,
|
||
|
|
} from "../message-content.js";
|
||
|
|
import type { ChatMessage } from "../types.js";
|
||
|
|
|
||
|
|
/**
 * Converts Chat Completions-style tool definitions
 * (`{ type: "function", function: { name, description, parameters } }`) into the flat
 * shape used by the Responses API. Any entry whose `type` is not "function" (including
 * null/undefined entries) is copied through untouched.
 *
 * @param tools - Tool definitions to convert.
 * @returns A new array; every function tool is flattened and marked `strict: false`.
 */
function toResponsesTools(tools: any[]) {
  const converted: any[] = [];
  for (const tool of tools) {
    const isFunctionTool = tool?.type === "function";
    if (!isFunctionTool) {
      converted.push(tool);
      continue;
    }
    const { name, description, parameters } = tool.function;
    converted.push({ type: "function", name, description, parameters, strict: false });
  }
  return converted;
}
|
||
|
|
|
||
|
|
/**
 * Converts a chat message into Responses API `content`.
 *
 * When the message has neither image nor text attachments, `message.content` is returned
 * unchanged. Otherwise a parts list is assembled in this order:
 *   1. one `input_image` part (detail "auto") per image attachment,
 *   2. the image summary text, when `buildImageSummaryText` returns a truthy value,
 *   3. one `input_text` part per text attachment (via `buildTextAttachmentPrompt`),
 *   4. the message's own text, when it is not blank.
 * A list that ends up holding exactly one `input_text` part collapses back to that string.
 */
function toContentParts(message: ChatMessage) {
  const images = getImageAttachments(message);
  const textFiles = getTextAttachments(message);
  if (!(images.length || textFiles.length)) {
    return message.content;
  }

  const parts: Array<Record<string, unknown>> = [];
  for (const image of images) {
    parts.push({ type: "input_image", image_url: image.dataUrl, detail: "auto" });
  }

  const summary = buildImageSummaryText(images);
  if (summary) parts.push({ type: "input_text", text: summary });

  for (const file of textFiles) {
    parts.push({ type: "input_text", text: buildTextAttachmentPrompt(file) });
  }

  if (message.content.trim()) parts.push({ type: "input_text", text: message.content });

  // A single plain-text part is sent as a bare string rather than a one-element list.
  if (parts.length === 1) {
    const [single] = parts;
    if (single?.type === "input_text" && typeof single.text === "string") {
      return single.text;
    }
  }
  return parts;
}
|
||
|
|
|
||
|
|
/**
 * Maps a stored chat message onto a Responses API input message.
 *
 * Non-tool messages keep their role and get their content from `toContentParts`.
 * Tool-role messages are re-expressed as a user message containing
 * "Tool output (<name>):\n<content>", where the name falls back to "tool" when it is
 * missing or blank.
 */
function buildInputMessage(message: ChatMessage) {
  if (message.role !== "tool") {
    return { role: message.role, content: toContentParts(message) };
  }

  const toolName = message.name?.trim() || "tool";
  return { role: "user", content: `Tool output (${toolName}):\n${message.content}` };
}
|
||
|
|
|
||
|
|
/**
 * Assembles the initial Responses API `input` list.
 *
 * The list holds three things in order: a system message carrying
 * `buildChatToolSystemPrompt(params)`, the message returned by
 * `buildSystemPromptAugmentationMessage(userLocation)`, and every conversation message
 * converted with `buildInputMessage`.
 */
function normalizeInput(messages: ChatMessage[], userLocation?: string, params: Pick<ToolAwareCompletionParams, "enabledTools"> = {}) {
  const conversation = messages.map((entry) => buildInputMessage(entry));
  const systemMessage = { role: "system", content: buildChatToolSystemPrompt(params) };
  const augmentation = buildSystemPromptAugmentationMessage(userLocation);
  return [systemMessage, augmentation, ...conversation];
}
|
||
|
|
|
||
|
|
/**
 * Adds one Responses API `usage` payload (snake_case counters) into `acc`, mutating it.
 * Counters that are null or undefined contribute zero.
 *
 * @returns `true` when a truthy usage object was merged, `false` when `usage` was absent.
 */
function mergeUsage(acc: Required<ToolAwareUsage>, usage: any) {
  const present = Boolean(usage);
  if (present) {
    acc.inputTokens = acc.inputTokens + (usage.input_tokens ?? 0);
    acc.outputTokens = acc.outputTokens + (usage.output_tokens ?? 0);
    acc.totalTokens = acc.totalTokens + (usage.total_tokens ?? 0);
  }
  return present;
}
|
||
|
|
|
||
|
|
/**
 * Returns `response.output` when it is an array, otherwise an empty array.
 * Safe to call with null/undefined responses.
 */
function getOutputItems(response: any) {
  const output = response?.output;
  if (!Array.isArray(output)) return [];
  return output;
}
|
||
|
|
|
||
|
|
/**
 * Extracts the assistant-visible text from a Responses API response object.
 *
 * If `output_text` is a non-empty string, it is returned directly. Otherwise the function
 * walks the `message` items in `output` and concatenates their `output_text` and `refusal`
 * parts. If nothing is found, `fallback` is returned.
 *
 * An empty `output_text` must not short-circuit the walk. The openai-node SDK fills
 * `output_text` by joining only `output_text` content parts, so it is "" for a refusal-only
 * response. Returning early on "" would drop the refusal text and ignore `fallback`.
 *
 * @param response - A Responses API response (may be null/undefined).
 * @param fallback - Text returned when the response yields none (e.g. accumulated stream deltas).
 * @returns The extracted text, or `fallback`.
 */
function extractText(response: any, fallback = ""): string {
  const convenienceText = response?.output_text;
  if (typeof convenienceText === "string" && convenienceText.length > 0) return convenienceText;

  const parts: string[] = [];
  for (const item of getOutputItems(response)) {
    if (item?.type !== "message" || !Array.isArray(item.content)) continue;
    for (const content of item.content) {
      if (content?.type === "output_text" && typeof content.text === "string") {
        parts.push(content.text);
      } else if (content?.type === "refusal" && typeof content.refusal === "string") {
        parts.push(content.refusal);
      }
    }
  }
  return parts.join("") || fallback;
}
|
||
|
|
|
||
|
|
/**
 * Describes why a response did not finish normally.
 *
 * @returns `null` unless `response.status` is "failed" or "incomplete". Otherwise it returns,
 *   in order of preference: the string in `error.message` (even if empty),
 *   "Response incomplete: <reason>" when `incomplete_details.reason` is a non-empty string,
 *   or "Response <status>.".
 */
function getFailureMessage(response: any) {
  const status = response?.status;
  const didNotComplete = status === "failed" || status === "incomplete";
  if (!didNotComplete) return null;

  const apiError = response?.error?.message;
  if (typeof apiError === "string") return apiError;

  const reason = response?.incomplete_details?.reason;
  if (typeof reason === "string" && reason) return `Response incomplete: ${reason}`;

  return `Response ${status}.`;
}
|
||
|
|
|
||
|
|
/**
 * Collects the `function_call` items of a response's output as normalized tool calls.
 *
 * Each call's id is `call_id`, else `id`, else a synthetic `tool_call_<round>_<n>`, where
 * `n` is the call's position among the function calls only. A missing name becomes
 * "unknown_tool" and missing arguments become "{}".
 */
function normalizeToolCalls(outputItems: any[], round: number): NormalizedToolCall[] {
  const calls: NormalizedToolCall[] = [];
  for (const item of outputItems) {
    if (item?.type !== "function_call") continue;
    const position = calls.length;
    calls.push({
      id: item.call_id ?? item.id ?? `tool_call_${round}_${position}`,
      name: item.name ?? "unknown_tool",
      arguments: item.arguments ?? "{}",
    });
  }
  return calls;
}
|
||
|
|
|
||
|
|
/**
 * Runs a non-streaming, tool-aware completion against the Responses API.
 *
 * Each round sends the accumulated `input` to `params.client.responses.create`.
 * - If the model requests function calls, they are executed one after another with
 *   `executeToolCallAndBuildEvent`. The model's output items and one
 *   `function_call_output` per call are then appended to `input` for the next round.
 * - If the model answers with text only, that text is returned. The exception is text that
 *   looks like a dangling tool intent: then a correction is appended and the request is
 *   retried, at most `MAX_DANGLING_TOOL_INTENT_RETRIES` times. Retries also consume rounds.
 *
 * @throws Error when a response reports status "failed" or "incomplete".
 * @returns The final text with accumulated usage (undefined if no response reported any),
 *   raw responses and tool events. After `MAX_TOOL_ROUNDS` rounds without a text answer,
 *   it returns a fixed tool-call-limit message.
 */
export async function completeWithResponsesApi(params: ToolAwareCompletionParams): Promise<ToolAwareCompletionResult> {
  const enabledTools = getEnabledChatTools(params);
  const input: any[] = normalizeInput(params.messages, params.userLocation, params);
  const rawResponses: unknown[] = [];
  const toolEvents: ToolExecutionEvent[] = [];
  const usageAcc: Required<ToolAwareUsage> = { inputTokens: 0, outputTokens: 0, totalTokens: 0 };
  let sawUsage = false;
  let totalToolCalls = 0;
  let danglingRetries = 0;

  for (let round = 0; round < MAX_TOOL_ROUNDS; round++) {
    const request = {
      model: params.model,
      input,
      temperature: params.temperature,
      max_output_tokens: params.maxTokens,
      tools: toResponsesTools(enabledTools),
      tool_choice: "auto",
      parallel_tool_calls: true,
      store: true,
    };
    const response = await params.client.responses.create(request as any);
    rawResponses.push(response);
    if (mergeUsage(usageAcc, response?.usage)) sawUsage = true;

    const failure = getFailureMessage(response);
    if (failure) throw new Error(failure);

    const items = getOutputItems(response);
    const toolCalls = normalizeToolCalls(items, round);

    if (toolCalls.length > 0) {
      totalToolCalls += toolCalls.length;
      // Echo the model's own output items back so each function_call_output has its call.
      input.push(...items);
      for (const call of toolCalls) {
        const { execution } = prepareToolCallExecution(call);
        const { event, toolResult } = await executeToolCallAndBuildEvent(call, execution, params);
        toolEvents.push(event);
        input.push({ type: "function_call_output", call_id: call.id, output: JSON.stringify(toolResult) });
      }
      continue;
    }

    const text = extractText(response);
    const shouldRetry = danglingRetries < MAX_DANGLING_TOOL_INTENT_RETRIES && looksLikeDanglingToolIntent(text);
    if (shouldRetry) {
      danglingRetries++;
      appendDanglingToolIntentCorrection(input, text);
      continue;
    }

    return {
      text,
      usage: sawUsage ? usageAcc : undefined,
      raw: { responses: rawResponses, toolCallsUsed: totalToolCalls, api: "responses" },
      toolEvents,
    };
  }

  return {
    text: "I reached the tool-call limit while gathering information. Please narrow the request and try again.",
    usage: sawUsage ? usageAcc : undefined,
    raw: { responses: rawResponses, toolCallsUsed: totalToolCalls, toolCallLimitReached: true, api: "responses" },
    toolEvents,
  };
}
|
||
|
|
|
||
|
|
/**
 * Streaming, tool-aware completion against the Responses API.
 *
 * The generator yields three kinds of events:
 * - `delta` events carrying assistant text,
 * - a pair of `tool_call` events around each tool execution (one from
 *   `prepareToolCallExecution`, one after `executeToolCallAndBuildEvent`),
 * - a final `done` event carrying the result, unless an error is thrown first.
 *
 * Within a round, text is streamed live only after a `message` output item has been
 * announced. Live streaming stops for the rest of the round once a `function_call` item is
 * seen. Text that was not streamed live is emitted as one trailing `delta` (computed by
 * `getUnstreamedText`) before `done`. The dangling-tool-intent retry applies only when
 * nothing was streamed live in that round, because already-yielded text cannot be taken
 * back.
 *
 * @throws Error on a stream `error` event, or when the final response reports status
 *   "failed" or "incomplete".
 */
export async function* streamWithResponsesApi(params: ToolAwareCompletionParams): AsyncGenerator<ToolAwareStreamingEvent> {
  const enabledTools = getEnabledChatTools(params);
  const input: any[] = normalizeInput(params.messages, params.userLocation, params);
  const rawResponses: unknown[] = [];
  const toolEvents: ToolExecutionEvent[] = [];
  const usageAcc: Required<ToolAwareUsage> = { inputTokens: 0, outputTokens: 0, totalTokens: 0 };
  let sawUsage = false;
  let totalToolCalls = 0;
  let danglingRetries = 0;

  for (let round = 0; round < MAX_TOOL_ROUNDS; round++) {
    const stream = await params.client.responses.create({
      model: params.model,
      input,
      temperature: params.temperature,
      max_output_tokens: params.maxTokens,
      tools: toResponsesTools(enabledTools),
      tool_choice: "auto",
      parallel_tool_calls: true,
      store: true,
      stream: true,
    } as any);

    // Per-round state.
    let bufferedText = ""; // every output_text delta seen this round
    let liveText = ""; // the part of bufferedText already yielded to the consumer
    let sawFunctionCall = false;
    let liveStreamingAllowed = false;
    let finalResponse: any | null = null;
    const finishedItems: any[] = [];

    for await (const event of stream as any as AsyncIterable<any>) {
      rawResponses.push(event);

      switch (event?.type) {
        case "response.output_text.delta": {
          if (typeof event.delta !== "string") break;
          bufferedText += event.delta;
          if (liveStreamingAllowed && !sawFunctionCall && event.delta.length) {
            liveText += event.delta;
            yield { type: "delta", text: event.delta };
          }
          break;
        }
        case "response.output_item.added": {
          const item = event.item;
          if (!item) break;
          if (item.type === "function_call") {
            sawFunctionCall = true;
            liveStreamingAllowed = false;
          } else if (item.type === "message" && !sawFunctionCall) {
            liveStreamingAllowed = true;
          }
          break;
        }
        case "response.output_item.done": {
          const item = event.item;
          if (!item) break;
          // Fallback copy of the output in case the terminal event lacks `output`.
          finishedItems[event.output_index ?? finishedItems.length] = item;
          if (item.type === "function_call") {
            sawFunctionCall = true;
            liveStreamingAllowed = false;
          }
          break;
        }
        case "response.completed":
        case "response.failed":
        case "response.incomplete":
          finalResponse = event.response;
          if (mergeUsage(usageAcc, event.response?.usage)) sawUsage = true;
          break;
        case "error":
          throw new Error(event.message ?? "Responses stream failed.");
        default:
          break;
      }
    }

    const failure = getFailureMessage(finalResponse);
    if (failure) throw new Error(failure);

    const reportedItems = getOutputItems(finalResponse);
    const roundItems = reportedItems.length ? reportedItems : finishedItems.filter(Boolean);
    const toolCalls = normalizeToolCalls(roundItems, round);

    if (toolCalls.length > 0) {
      totalToolCalls += toolCalls.length;
      input.push(...roundItems);
      for (const call of toolCalls) {
        const { event: initiatedEvent, execution } = prepareToolCallExecution(call);
        yield { type: "tool_call", event: initiatedEvent };
        const { event: completedEvent, toolResult } = await executeToolCallAndBuildEvent(call, execution, params);
        toolEvents.push(completedEvent);
        yield { type: "tool_call", event: completedEvent };
        input.push({ type: "function_call_output", call_id: call.id, output: JSON.stringify(toolResult) });
      }
      continue;
    }

    const text = extractText(finalResponse, bufferedText);
    if (!liveText && danglingRetries < MAX_DANGLING_TOOL_INTENT_RETRIES && looksLikeDanglingToolIntent(text)) {
      danglingRetries++;
      appendDanglingToolIntentCorrection(input, text);
      continue;
    }

    const remainder = getUnstreamedText(text, liveText);
    if (remainder) {
      yield { type: "delta", text: remainder };
    }
    yield {
      type: "done",
      result: {
        text,
        usage: sawUsage ? usageAcc : undefined,
        raw: { streamed: true, responses: rawResponses, toolCallsUsed: totalToolCalls, api: "responses" },
        toolEvents,
      },
    };
    return;
  }

  yield {
    type: "done",
    result: {
      text: "I reached the tool-call limit while gathering information. Please narrow the request and try again.",
      usage: sawUsage ? usageAcc : undefined,
      raw: { streamed: true, responses: rawResponses, toolCallsUsed: totalToolCalls, toolCallLimitReached: true, api: "responses" },
      toolEvents,
    },
  };
}
|