big backend refactor
This commit is contained in:
332
server/src/llm/protocols/responses-api.ts
Normal file
332
server/src/llm/protocols/responses-api.ts
Normal file
@@ -0,0 +1,332 @@
|
||||
import {
|
||||
appendDanglingToolIntentCorrection,
|
||||
buildChatToolSystemPrompt,
|
||||
executeToolCallAndBuildEvent,
|
||||
getEnabledChatTools,
|
||||
getUnstreamedText,
|
||||
looksLikeDanglingToolIntent,
|
||||
MAX_DANGLING_TOOL_INTENT_RETRIES,
|
||||
MAX_TOOL_ROUNDS,
|
||||
prepareToolCallExecution,
|
||||
type NormalizedToolCall,
|
||||
type ToolAwareCompletionParams,
|
||||
type ToolAwareCompletionResult,
|
||||
type ToolAwareStreamingEvent,
|
||||
type ToolAwareUsage,
|
||||
type ToolExecutionEvent,
|
||||
} from "../chat-tools.js";
|
||||
import {
|
||||
buildImageSummaryText,
|
||||
buildSystemPromptAugmentationMessage,
|
||||
buildTextAttachmentPrompt,
|
||||
getImageAttachments,
|
||||
getTextAttachments,
|
||||
} from "../message-content.js";
|
||||
import type { ChatMessage } from "../types.js";
|
||||
|
||||
function toResponsesTools(tools: any[]) {
|
||||
return tools.map((tool) => {
|
||||
if (tool?.type !== "function") return tool;
|
||||
return {
|
||||
type: "function",
|
||||
name: tool.function.name,
|
||||
description: tool.function.description,
|
||||
parameters: tool.function.parameters,
|
||||
strict: false,
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
function toContentParts(message: ChatMessage) {
|
||||
const imageAttachments = getImageAttachments(message);
|
||||
const textAttachments = getTextAttachments(message);
|
||||
if (!imageAttachments.length && !textAttachments.length) {
|
||||
return message.content;
|
||||
}
|
||||
|
||||
const parts: Array<Record<string, unknown>> = [];
|
||||
for (const attachment of imageAttachments) {
|
||||
parts.push({
|
||||
type: "input_image",
|
||||
image_url: attachment.dataUrl,
|
||||
detail: "auto",
|
||||
});
|
||||
}
|
||||
|
||||
const imageSummary = buildImageSummaryText(imageAttachments);
|
||||
if (imageSummary) {
|
||||
parts.push({ type: "input_text", text: imageSummary });
|
||||
}
|
||||
|
||||
for (const attachment of textAttachments) {
|
||||
parts.push({ type: "input_text", text: buildTextAttachmentPrompt(attachment) });
|
||||
}
|
||||
|
||||
if (message.content.trim()) {
|
||||
parts.push({ type: "input_text", text: message.content });
|
||||
}
|
||||
|
||||
if (parts.length === 1 && parts[0]?.type === "input_text" && typeof parts[0].text === "string") {
|
||||
return parts[0].text;
|
||||
}
|
||||
|
||||
return parts;
|
||||
}
|
||||
|
||||
function buildInputMessage(message: ChatMessage) {
|
||||
if (message.role === "tool") {
|
||||
const name = message.name?.trim() || "tool";
|
||||
return {
|
||||
role: "user",
|
||||
content: `Tool output (${name}):\n${message.content}`,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
role: message.role,
|
||||
content: toContentParts(message),
|
||||
};
|
||||
}
|
||||
|
||||
function normalizeInput(messages: ChatMessage[], userLocation?: string, params: Pick<ToolAwareCompletionParams, "enabledTools"> = {}) {
|
||||
const normalized = messages.map((message) => buildInputMessage(message));
|
||||
return [{ role: "system", content: buildChatToolSystemPrompt(params) }, buildSystemPromptAugmentationMessage(userLocation), ...normalized];
|
||||
}
|
||||
|
||||
function mergeUsage(acc: Required<ToolAwareUsage>, usage: any) {
|
||||
if (!usage) return false;
|
||||
acc.inputTokens += usage.input_tokens ?? 0;
|
||||
acc.outputTokens += usage.output_tokens ?? 0;
|
||||
acc.totalTokens += usage.total_tokens ?? 0;
|
||||
return true;
|
||||
}
|
||||
|
||||
function getOutputItems(response: any) {
|
||||
return Array.isArray(response?.output) ? response.output : [];
|
||||
}
|
||||
|
||||
function extractText(response: any, fallback = "") {
|
||||
if (typeof response?.output_text === "string") return response.output_text;
|
||||
|
||||
const parts: string[] = [];
|
||||
for (const item of getOutputItems(response)) {
|
||||
if (item?.type !== "message" || !Array.isArray(item.content)) continue;
|
||||
for (const content of item.content) {
|
||||
if (content?.type === "output_text" && typeof content.text === "string") {
|
||||
parts.push(content.text);
|
||||
} else if (content?.type === "refusal" && typeof content.refusal === "string") {
|
||||
parts.push(content.refusal);
|
||||
}
|
||||
}
|
||||
}
|
||||
return parts.join("") || fallback;
|
||||
}
|
||||
|
||||
function getFailureMessage(response: any) {
|
||||
if (response?.status !== "failed" && response?.status !== "incomplete") return null;
|
||||
const errorMessage = typeof response?.error?.message === "string" ? response.error.message : null;
|
||||
const incompleteReason = typeof response?.incomplete_details?.reason === "string" ? response.incomplete_details.reason : null;
|
||||
return errorMessage ?? (incompleteReason ? `Response incomplete: ${incompleteReason}` : `Response ${response.status}.`);
|
||||
}
|
||||
|
||||
function normalizeToolCalls(outputItems: any[], round: number): NormalizedToolCall[] {
|
||||
return outputItems
|
||||
.filter((item) => item?.type === "function_call")
|
||||
.map((call: any, index: number) => ({
|
||||
id: call.call_id ?? call.id ?? `tool_call_${round}_${index}`,
|
||||
name: call.name ?? "unknown_tool",
|
||||
arguments: call.arguments ?? "{}",
|
||||
}));
|
||||
}
|
||||
|
||||
export async function completeWithResponsesApi(params: ToolAwareCompletionParams): Promise<ToolAwareCompletionResult> {
|
||||
const enabledTools = getEnabledChatTools(params);
|
||||
const input: any[] = normalizeInput(params.messages, params.userLocation, params);
|
||||
const rawResponses: unknown[] = [];
|
||||
const toolEvents: ToolExecutionEvent[] = [];
|
||||
const usageAcc: Required<ToolAwareUsage> = { inputTokens: 0, outputTokens: 0, totalTokens: 0 };
|
||||
let sawUsage = false;
|
||||
let totalToolCalls = 0;
|
||||
let danglingToolIntentRetries = 0;
|
||||
|
||||
for (let round = 0; round < MAX_TOOL_ROUNDS; round += 1) {
|
||||
const response = await params.client.responses.create({
|
||||
model: params.model,
|
||||
input,
|
||||
temperature: params.temperature,
|
||||
max_output_tokens: params.maxTokens,
|
||||
tools: toResponsesTools(enabledTools),
|
||||
tool_choice: "auto",
|
||||
parallel_tool_calls: true,
|
||||
store: true,
|
||||
} as any);
|
||||
rawResponses.push(response);
|
||||
sawUsage = mergeUsage(usageAcc, response?.usage) || sawUsage;
|
||||
|
||||
const failureMessage = getFailureMessage(response);
|
||||
if (failureMessage) {
|
||||
throw new Error(failureMessage);
|
||||
}
|
||||
|
||||
const outputItems = getOutputItems(response);
|
||||
const normalizedToolCalls = normalizeToolCalls(outputItems, round);
|
||||
if (!normalizedToolCalls.length) {
|
||||
const text = extractText(response);
|
||||
if (danglingToolIntentRetries < MAX_DANGLING_TOOL_INTENT_RETRIES && looksLikeDanglingToolIntent(text)) {
|
||||
danglingToolIntentRetries += 1;
|
||||
appendDanglingToolIntentCorrection(input, text);
|
||||
continue;
|
||||
}
|
||||
return {
|
||||
text,
|
||||
usage: sawUsage ? usageAcc : undefined,
|
||||
raw: { responses: rawResponses, toolCallsUsed: totalToolCalls, api: "responses" },
|
||||
toolEvents,
|
||||
};
|
||||
}
|
||||
|
||||
totalToolCalls += normalizedToolCalls.length;
|
||||
input.push(...outputItems);
|
||||
|
||||
for (const call of normalizedToolCalls) {
|
||||
const { execution } = prepareToolCallExecution(call);
|
||||
const { event, toolResult } = await executeToolCallAndBuildEvent(call, execution, params);
|
||||
toolEvents.push(event);
|
||||
|
||||
input.push({
|
||||
type: "function_call_output",
|
||||
call_id: call.id,
|
||||
output: JSON.stringify(toolResult),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
text: "I reached the tool-call limit while gathering information. Please narrow the request and try again.",
|
||||
usage: sawUsage ? usageAcc : undefined,
|
||||
raw: { responses: rawResponses, toolCallsUsed: totalToolCalls, toolCallLimitReached: true, api: "responses" },
|
||||
toolEvents,
|
||||
};
|
||||
}
|
||||
|
||||
export async function* streamWithResponsesApi(params: ToolAwareCompletionParams): AsyncGenerator<ToolAwareStreamingEvent> {
|
||||
const enabledTools = getEnabledChatTools(params);
|
||||
const input: any[] = normalizeInput(params.messages, params.userLocation, params);
|
||||
const rawResponses: unknown[] = [];
|
||||
const toolEvents: ToolExecutionEvent[] = [];
|
||||
const usageAcc: Required<ToolAwareUsage> = { inputTokens: 0, outputTokens: 0, totalTokens: 0 };
|
||||
let sawUsage = false;
|
||||
let totalToolCalls = 0;
|
||||
let danglingToolIntentRetries = 0;
|
||||
|
||||
for (let round = 0; round < MAX_TOOL_ROUNDS; round += 1) {
|
||||
const stream = await params.client.responses.create({
|
||||
model: params.model,
|
||||
input,
|
||||
temperature: params.temperature,
|
||||
max_output_tokens: params.maxTokens,
|
||||
tools: toResponsesTools(enabledTools),
|
||||
tool_choice: "auto",
|
||||
parallel_tool_calls: true,
|
||||
store: true,
|
||||
stream: true,
|
||||
} as any);
|
||||
|
||||
let roundText = "";
|
||||
let streamedRoundText = "";
|
||||
let roundHasToolCalls = false;
|
||||
let canStreamRoundText = false;
|
||||
let completedResponse: any | null = null;
|
||||
const completedOutputItems: any[] = [];
|
||||
|
||||
for await (const event of stream as any as AsyncIterable<any>) {
|
||||
rawResponses.push(event);
|
||||
|
||||
if (event?.type === "response.output_text.delta" && typeof event.delta === "string") {
|
||||
roundText += event.delta;
|
||||
if (canStreamRoundText && !roundHasToolCalls && event.delta.length) {
|
||||
streamedRoundText += event.delta;
|
||||
yield { type: "delta", text: event.delta };
|
||||
}
|
||||
} else if (event?.type === "response.output_item.added" && event.item) {
|
||||
if (event.item.type === "function_call") {
|
||||
roundHasToolCalls = true;
|
||||
canStreamRoundText = false;
|
||||
} else if (event.item.type === "message" && !roundHasToolCalls) {
|
||||
canStreamRoundText = true;
|
||||
}
|
||||
} else if (event?.type === "response.output_item.done" && event.item) {
|
||||
completedOutputItems[event.output_index ?? completedOutputItems.length] = event.item;
|
||||
if (event.item.type === "function_call") {
|
||||
roundHasToolCalls = true;
|
||||
canStreamRoundText = false;
|
||||
}
|
||||
} else if (event?.type === "response.completed") {
|
||||
completedResponse = event.response;
|
||||
sawUsage = mergeUsage(usageAcc, event.response?.usage) || sawUsage;
|
||||
} else if (event?.type === "response.failed" || event?.type === "response.incomplete") {
|
||||
completedResponse = event.response;
|
||||
sawUsage = mergeUsage(usageAcc, event.response?.usage) || sawUsage;
|
||||
} else if (event?.type === "error") {
|
||||
throw new Error(event.message ?? "Responses stream failed.");
|
||||
}
|
||||
}
|
||||
|
||||
const failureMessage = getFailureMessage(completedResponse);
|
||||
if (failureMessage) {
|
||||
throw new Error(failureMessage);
|
||||
}
|
||||
|
||||
const outputItems = getOutputItems(completedResponse);
|
||||
const responseOutputItems = outputItems.length ? outputItems : completedOutputItems.filter(Boolean);
|
||||
const normalizedToolCalls = normalizeToolCalls(responseOutputItems, round);
|
||||
if (!normalizedToolCalls.length) {
|
||||
const text = extractText(completedResponse, roundText);
|
||||
if (!streamedRoundText && danglingToolIntentRetries < MAX_DANGLING_TOOL_INTENT_RETRIES && looksLikeDanglingToolIntent(text)) {
|
||||
danglingToolIntentRetries += 1;
|
||||
appendDanglingToolIntentCorrection(input, text);
|
||||
continue;
|
||||
}
|
||||
const unstreamedText = getUnstreamedText(text, streamedRoundText);
|
||||
if (unstreamedText) {
|
||||
yield { type: "delta", text: unstreamedText };
|
||||
}
|
||||
yield {
|
||||
type: "done",
|
||||
result: {
|
||||
text,
|
||||
usage: sawUsage ? usageAcc : undefined,
|
||||
raw: { streamed: true, responses: rawResponses, toolCallsUsed: totalToolCalls, api: "responses" },
|
||||
toolEvents,
|
||||
},
|
||||
};
|
||||
return;
|
||||
}
|
||||
|
||||
totalToolCalls += normalizedToolCalls.length;
|
||||
input.push(...responseOutputItems);
|
||||
|
||||
for (const call of normalizedToolCalls) {
|
||||
const { event: initiatedEvent, execution } = prepareToolCallExecution(call);
|
||||
yield { type: "tool_call", event: initiatedEvent };
|
||||
const { event, toolResult } = await executeToolCallAndBuildEvent(call, execution, params);
|
||||
toolEvents.push(event);
|
||||
yield { type: "tool_call", event };
|
||||
input.push({
|
||||
type: "function_call_output",
|
||||
call_id: call.id,
|
||||
output: JSON.stringify(toolResult),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
yield {
|
||||
type: "done",
|
||||
result: {
|
||||
text: "I reached the tool-call limit while gathering information. Please narrow the request and try again.",
|
||||
usage: sawUsage ? usageAcc : undefined,
|
||||
raw: { streamed: true, responses: rawResponses, toolCallsUsed: totalToolCalls, toolCallLimitReached: true, api: "responses" },
|
||||
toolEvents,
|
||||
},
|
||||
};
|
||||
}
|
||||
Reference in New Issue
Block a user