Files
Sybil-2/server/src/llm/protocols/messages-api.ts
2026-06-13 12:02:22 -07:00

471 lines
16 KiB
TypeScript

import {
buildChatToolSystemPrompt,
executeToolCallAndBuildEvent,
getEnabledChatTools,
looksLikeDanglingToolIntent,
MAX_DANGLING_TOOL_INTENT_RETRIES,
MAX_TOOL_ROUNDS,
parseToolArgs,
prepareToolCallExecution,
type NormalizedToolCall,
type ToolAwareCompletionParams,
type ToolAwareCompletionResult,
type ToolAwareStreamingEvent,
type ToolAwareUsage,
type ToolExecutionEvent,
type ToolRunOutcome,
} from "../chat-tools.js";
import {
buildImageSummaryText,
buildTextAttachmentPrompt,
buildTopLevelSystemPrompt,
getImageAttachments,
getTextAttachments,
parseImageDataUrl,
} from "../message-content.js";
import type { ChatMessage } from "../types.js";
/**
 * Follow-up user turn injected when the model said it would run a tool but emitted no tool call.
 * It asks the model to either make the tool call now or give its final answer directly.
 */
const INTERNAL_CORRECTION = [
  "Internal correction: the previous assistant message claimed it would run a tool, but no tool call was made.",
  "If the task needs an available tool, call it now.",
  "Otherwise provide the final answer directly without saying you will run a tool.",
].join(" ");
/**
 * Converts OpenAI-style `{ type: "function", function: {...} }` tool definitions into the
 * Messages API tool shape (`name`, `description`, `input_schema`).
 * Entries whose `type` is not `"function"` (including null/undefined) are skipped.
 */
function toTools(tools: any[]) {
  const converted: Array<{ name: unknown; description: unknown; input_schema: unknown }> = [];
  for (const tool of tools) {
    if (tool?.type !== "function") continue;
    const fn = tool.function;
    converted.push({
      name: fn.name,
      description: fn.description,
      input_schema: fn.parameters,
    });
  }
  return converted;
}
/**
 * Builds the `content` value for a user/assistant message.
 *
 * Without image or text attachments the message's plain string content is returned unchanged.
 * Otherwise a block list is produced in this order:
 * 1. one base64 image block per image attachment;
 * 2. the image summary text, when non-empty;
 * 3. one text block per text attachment;
 * 4. the message text, when it is not blank.
 * A list that ends up holding a single text block is collapsed back to that block's string.
 */
function toContentBlocks(message: ChatMessage) {
  const images = getImageAttachments(message);
  const texts = getTextAttachments(message);
  if (images.length === 0 && texts.length === 0) return message.content;

  const blocks: Array<Record<string, unknown>> = images.map((attachment) => {
    const { mediaType, data } = parseImageDataUrl(attachment);
    return {
      type: "image",
      source: { type: "base64", media_type: mediaType, data },
    };
  });

  const summary = buildImageSummaryText(images);
  if (summary) blocks.push({ type: "text", text: summary });

  blocks.push(...texts.map((attachment) => ({ type: "text", text: buildTextAttachmentPrompt(attachment) })));

  if (message.content.trim()) blocks.push({ type: "text", text: message.content });

  const [only] = blocks;
  if (blocks.length === 1 && only?.type === "text" && typeof only.text === "string") {
    return only.text;
  }
  return blocks;
}
/**
 * Maps one non-system chat message onto a Messages API conversation entry.
 * - `tool` messages become a plain-text user turn labelled with the tool name (default "tool").
 * - `assistant` messages keep the assistant role; every other role is sent as `user`.
 * @throws Error when given a system message; those belong in the top-level `system` field.
 */
function buildConversationMessage(message: ChatMessage) {
  switch (message.role) {
    case "system":
      throw new Error("System messages must be handled separately for top-level-system protocols.");
    case "tool": {
      const toolName = message.name?.trim() || "tool";
      return { role: "user", content: `Tool output (${toolName}):\n${message.content}` };
    }
    default:
      return {
        role: message.role === "assistant" ? "assistant" : "user",
        content: toContentBlocks(message),
      };
  }
}
/**
 * Converts the request's chat history into Messages API conversation entries, skipping system
 * messages (they are sent separately as the top-level system prompt).
 */
function buildBaseMessages(params: ToolAwareCompletionParams) {
  const entries: Array<ReturnType<typeof buildConversationMessage>> = [];
  for (const message of params.messages) {
    if (message.role === "system") continue;
    entries.push(buildConversationMessage(message));
  }
  return entries;
}
/**
 * Serializes a tool-call input to a JSON string.
 * Strings are passed through untouched, null/undefined serialize as `{}`, and any value
 * `JSON.stringify` throws on (e.g. cyclic objects) also falls back to `"{}"`.
 */
function stringifyToolInput(input: unknown) {
  if (typeof input === "string") {
    return input;
  }
  const payload = input === null || input === undefined ? {} : input;
  try {
    return JSON.stringify(payload);
  } catch {
    return "{}";
  }
}
/**
 * Extracts the `tool_use` blocks from a response's content as normalized tool calls.
 * A missing id is replaced with `tool_call_<round>_<n>`, where `n` counts tool_use blocks only.
 * A missing name becomes `"unknown_tool"`. The input is serialized with `stringifyToolInput`.
 */
function normalizeToolCalls(content: any[], round: number): NormalizedToolCall[] {
  const calls: NormalizedToolCall[] = [];
  for (const item of content) {
    if (item?.type !== "tool_use") continue;
    const position = calls.length;
    calls.push({
      id: item.id ?? `tool_call_${round}_${position}`,
      name: item.name ?? "unknown_tool",
      arguments: stringifyToolInput(item.input),
    });
  }
  return calls;
}
/**
 * Concatenates every string-valued `text` block of a Messages API response and trims the result.
 * Returns "" when the response has no content array.
 */
function extractText(response: any) {
  const blocks = response?.content;
  if (!Array.isArray(blocks)) return "";
  let text = "";
  for (const block of blocks) {
    if (block?.type === "text" && typeof block.text === "string") {
      text += block.text;
    }
  }
  return text.trim();
}
/**
 * Wraps a tool outcome as a Messages API `tool_result` block answering `call.id`.
 * The whole outcome is JSON-serialized as the content, and `is_error` is set when `ok` is falsy.
 */
function buildToolResultBlock(call: NormalizedToolCall, toolResult: ToolRunOutcome) {
  const serialized = JSON.stringify(toolResult);
  const failed = !toolResult.ok;
  return { type: "tool_result", tool_use_id: call.id, content: serialized, is_error: failed };
}
/**
 * Records the assistant's tool-less reply and appends the internal correction as the next user
 * turn, so the following request asks the model to call the tool or answer directly.
 * Mutates `conversation` in place.
 */
function appendCorrection(conversation: any[], text: string) {
  conversation.push(
    { role: "assistant", content: text },
    { role: "user", content: INTERNAL_CORRECTION },
  );
}
/**
 * Adds a response's `input_tokens`/`output_tokens` (null/undefined count as 0) into `acc`.
 * @returns true when a usage object was present (even if it held no counts), false otherwise.
 */
function mergeUsage(acc: Required<ToolAwareUsage>, usage: any) {
  if (usage) {
    const input = usage.input_tokens ?? 0;
    const output = usage.output_tokens ?? 0;
    acc.inputTokens += input;
    acc.outputTokens += output;
    acc.totalTokens += input + output;
    return true;
  }
  return false;
}
/**
 * Runs a non-streaming chat completion against the Messages API, executing any enabled chat
 * tools the model requests until it produces a final text answer.
 *
 * With no enabled tools, a single request is made without tool definitions. Otherwise the model
 * is called for at most `MAX_TOOL_ROUNDS` rounds. Each round does one of three things:
 * - returns the model's text answer;
 * - re-prompts with `INTERNAL_CORRECTION` when the text looks like an announced-but-missing tool
 *   call (at most `MAX_DANGLING_TOOL_INTENT_RETRIES` times);
 * - executes the requested tools sequentially and feeds their results back as `tool_result` blocks.
 *
 * @param params - Client, model, chat history and tool configuration for the request.
 * @returns The final text, summed token usage (undefined when no response reported usage), raw
 *   responses and executed tool events. When the round limit is hit, a fixed fallback message is
 *   returned with `raw.toolCallLimitReached` set.
 */
export async function completeWithMessagesApi(params: ToolAwareCompletionParams): Promise<ToolAwareCompletionResult> {
  const enabledTools = getEnabledChatTools(params);
  if (!enabledTools.length) {
    // Plain completion: no tool definitions are sent, so one request is enough.
    const response = await params.client.messages.create({
      model: params.model,
      system: buildTopLevelSystemPrompt(params.messages, params.userLocation),
      max_tokens: params.maxTokens ?? 1024,
      temperature: params.temperature,
      messages: buildBaseMessages(params),
    } as any);
    const usageAcc: Required<ToolAwareUsage> = { inputTokens: 0, outputTokens: 0, totalTokens: 0 };
    const sawUsage = mergeUsage(usageAcc, response?.usage);
    return {
      text: extractText(response),
      usage: sawUsage ? usageAcc : undefined,
      raw: { response, api: "messages" },
      toolEvents: [],
    };
  }

  // The tool definitions depend only on the enabled tool list, so convert them once for all rounds.
  const tools = toTools(enabledTools);
  const conversation: any[] = buildBaseMessages(params);
  const rawResponses: unknown[] = [];
  const toolEvents: ToolExecutionEvent[] = [];
  const usageAcc: Required<ToolAwareUsage> = { inputTokens: 0, outputTokens: 0, totalTokens: 0 };
  let sawUsage = false;
  let totalToolCalls = 0;
  let danglingToolIntentRetries = 0;

  for (let round = 0; round < MAX_TOOL_ROUNDS; round += 1) {
    const response = await params.client.messages.create({
      model: params.model,
      system: buildTopLevelSystemPrompt(params.messages, params.userLocation, buildChatToolSystemPrompt(params)),
      max_tokens: params.maxTokens ?? 1024,
      temperature: params.temperature,
      messages: conversation,
      tools,
      tool_choice: { type: "auto" },
    } as any);
    rawResponses.push(response);
    sawUsage = mergeUsage(usageAcc, response?.usage) || sawUsage;

    const content = Array.isArray(response?.content) ? response.content : [];
    const normalizedToolCalls = normalizeToolCalls(content, round);

    if (!normalizedToolCalls.length) {
      const text = extractText(response);
      // The model said it would use a tool but did not call one: ask it (a bounded number of
      // times) to either make the call or answer directly. Each retry consumes a round.
      if (danglingToolIntentRetries < MAX_DANGLING_TOOL_INTENT_RETRIES && looksLikeDanglingToolIntent(text)) {
        danglingToolIntentRetries += 1;
        appendCorrection(conversation, text);
        continue;
      }
      return {
        text,
        usage: sawUsage ? usageAcc : undefined,
        raw: { responses: rawResponses, toolCallsUsed: totalToolCalls, api: "messages" },
        toolEvents,
      };
    }

    totalToolCalls += normalizedToolCalls.length;
    // Echo the assistant turn back so the tool_result blocks below can reference its tool_use ids.
    // Empty or whitespace-only text blocks are dropped: the Messages API rejects them in request
    // history, which would otherwise fail the next round.
    conversation.push({
      role: "assistant",
      content: content.filter(
        (block: any) => block?.type !== "text" || (typeof block.text === "string" && block.text.trim() !== ""),
      ),
    });

    const toolResultBlocks: any[] = [];
    for (const call of normalizedToolCalls) {
      const { execution } = prepareToolCallExecution(call);
      const { event, toolResult } = await executeToolCallAndBuildEvent(call, execution, params);
      toolEvents.push(event);
      toolResultBlocks.push(buildToolResultBlock(call, toolResult));
    }
    conversation.push({
      role: "user",
      content: toolResultBlocks,
    });
  }

  return {
    text: "I reached the tool-call limit while gathering information. Please narrow the request and try again.",
    usage: sawUsage ? usageAcc : undefined,
    raw: { responses: rawResponses, toolCallsUsed: totalToolCalls, toolCallLimitReached: true, api: "messages" },
    toolEvents,
  };
}
/** Content and usage reassembled from one streamed, tool-enabled Messages API round. */
interface StreamedToolRound {
  /** Content blocks keyed by their stream `index`, rebuilt from start/delta/stop events. */
  contentByIndex: Map<number, any>;
  /** Concatenated `input_json_delta` fragments for each tool_use block index. */
  toolArgumentByIndex: Map<number, string>;
  /** Concatenation of every text_delta seen in the round. */
  text: string;
  /** Last reported input token count for the round. */
  inputTokens: number;
  /** Last reported output token count for the round. */
  outputTokens: number;
  /** True when message_start or message_delta carried a usage object. */
  sawUsage: boolean;
}

/**
 * Consumes one streamed Messages API response and reassembles its content blocks.
 * Every raw event is appended to `rawResponses`. Nothing is yielded to the caller while
 * streaming, because the round's text may turn out to be a dangling tool intent that gets retried.
 */
async function collectStreamedToolRound(stream: AsyncIterable<any>, rawResponses: unknown[]): Promise<StreamedToolRound> {
  const contentByIndex = new Map<number, any>();
  const toolArgumentByIndex = new Map<number, string>();
  let text = "";
  let inputTokens = 0;
  let outputTokens = 0;
  let sawUsage = false;

  for await (const ev of stream) {
    rawResponses.push(ev);

    if (ev?.type === "message_start" && ev?.message?.usage) {
      inputTokens = ev.message.usage.input_tokens ?? inputTokens;
      sawUsage = true;
    }

    if (ev?.type === "content_block_start" && typeof ev.index === "number") {
      const block = ev.content_block ?? {};
      if (block.type === "tool_use") {
        contentByIndex.set(ev.index, {
          type: "tool_use",
          id: block.id,
          name: block.name,
          input: block.input ?? {},
        });
        toolArgumentByIndex.set(ev.index, "");
      } else if (block.type === "text") {
        contentByIndex.set(ev.index, {
          type: "text",
          text: typeof block.text === "string" ? block.text : "",
        });
      } else if (block.type) {
        contentByIndex.set(ev.index, block);
      }
    }

    if (ev?.type === "content_block_delta" && typeof ev.index === "number") {
      if (ev.delta?.type === "text_delta") {
        const delta = typeof ev.delta.text === "string" ? ev.delta.text : "";
        if (delta) {
          // A delta for an index that never got a start event creates a fresh text block.
          const block = contentByIndex.get(ev.index) ?? { type: "text", text: "" };
          if (block.type === "text") {
            block.text = `${typeof block.text === "string" ? block.text : ""}${delta}`;
            contentByIndex.set(ev.index, block);
          }
          text += delta;
        }
      } else if (ev.delta?.type === "input_json_delta") {
        const partialJson = typeof ev.delta.partial_json === "string" ? ev.delta.partial_json : "";
        toolArgumentByIndex.set(ev.index, `${toolArgumentByIndex.get(ev.index) ?? ""}${partialJson}`);
      }
    }

    if (ev?.type === "content_block_stop" && typeof ev.index === "number") {
      const block = contentByIndex.get(ev.index);
      if (block?.type === "tool_use") {
        // Parse the accumulated JSON into the block's `input` so the echoed assistant turn is
        // well-formed. Best effort: unparseable arguments are echoed as an empty object.
        const rawArguments = toolArgumentByIndex.get(ev.index) || stringifyToolInput(block.input);
        try {
          block.input = parseToolArgs(rawArguments);
        } catch {
          block.input = {};
        }
      }
    }

    if (ev?.type === "message_delta" && ev.usage) {
      inputTokens = ev.usage.input_tokens ?? inputTokens;
      outputTokens = ev.usage.output_tokens ?? outputTokens;
      sawUsage = true;
    }
  }

  return { contentByIndex, toolArgumentByIndex, text, inputTokens, outputTokens, sawUsage };
}

/**
 * Streaming counterpart of `completeWithMessagesApi`.
 *
 * With no enabled tools, text deltas are yielded as they arrive, followed by a single `done`
 * event. With tools enabled, each round's text is buffered and yielded as one `delta` only once
 * the round ends without tool calls. Buffering keeps a dangling-intent reply that gets retried
 * from reaching the client. Each executed tool yields two `tool_call` events: the initiated
 * event, then the completed one. The stream always ends with a `done` event carrying the final
 * result, or the tool-call-limit fallback message when `MAX_TOOL_ROUNDS` is exhausted.
 */
export async function* streamWithMessagesApi(params: ToolAwareCompletionParams): AsyncGenerator<ToolAwareStreamingEvent> {
  const enabledTools = getEnabledChatTools(params);
  if (!enabledTools.length) {
    const rawResponses: unknown[] = [];
    const usageAcc: Required<ToolAwareUsage> = { inputTokens: 0, outputTokens: 0, totalTokens: 0 };
    let sawUsage = false;
    let roundInputTokens = 0;
    let roundOutputTokens = 0;
    let text = "";
    const stream = await params.client.messages.create({
      model: params.model,
      system: buildTopLevelSystemPrompt(params.messages, params.userLocation),
      max_tokens: params.maxTokens ?? 1024,
      temperature: params.temperature,
      messages: buildBaseMessages(params),
      stream: true,
    } as any);
    for await (const ev of stream as any as AsyncIterable<any>) {
      rawResponses.push(ev);
      if (ev?.type === "message_start" && ev?.message?.usage) {
        roundInputTokens = ev.message.usage.input_tokens ?? roundInputTokens;
        sawUsage = true;
      }
      if (ev?.type === "content_block_delta" && ev?.delta?.type === "text_delta") {
        const delta = ev.delta.text ?? "";
        if (delta) {
          text += delta;
          yield { type: "delta", text: delta };
        }
      }
      if (ev?.type === "message_delta" && ev.usage) {
        roundInputTokens = ev.usage.input_tokens ?? roundInputTokens;
        roundOutputTokens = ev.usage.output_tokens ?? roundOutputTokens;
        sawUsage = true;
      }
    }
    if (sawUsage) {
      usageAcc.inputTokens += roundInputTokens;
      usageAcc.outputTokens += roundOutputTokens;
      usageAcc.totalTokens += roundInputTokens + roundOutputTokens;
    }
    yield {
      type: "done",
      result: {
        text,
        usage: sawUsage ? usageAcc : undefined,
        raw: { streamed: true, responses: rawResponses, toolCallsUsed: 0, api: "messages" },
        toolEvents: [],
      },
    };
    return;
  }

  // The tool definitions depend only on the enabled tool list, so convert them once for all rounds.
  const tools = toTools(enabledTools);
  const conversation: any[] = buildBaseMessages(params);
  const rawResponses: unknown[] = [];
  const toolEvents: ToolExecutionEvent[] = [];
  const usageAcc: Required<ToolAwareUsage> = { inputTokens: 0, outputTokens: 0, totalTokens: 0 };
  let sawUsage = false;
  let totalToolCalls = 0;
  let danglingToolIntentRetries = 0;

  for (let round = 0; round < MAX_TOOL_ROUNDS; round += 1) {
    const stream = await params.client.messages.create({
      model: params.model,
      system: buildTopLevelSystemPrompt(params.messages, params.userLocation, buildChatToolSystemPrompt(params)),
      max_tokens: params.maxTokens ?? 1024,
      temperature: params.temperature,
      messages: conversation,
      tools,
      tool_choice: { type: "auto" },
      stream: true,
    } as any);
    const streamed = await collectStreamedToolRound(stream as any as AsyncIterable<any>, rawResponses);

    if (streamed.sawUsage) {
      usageAcc.inputTokens += streamed.inputTokens;
      usageAcc.outputTokens += streamed.outputTokens;
      usageAcc.totalTokens += streamed.inputTokens + streamed.outputTokens;
      sawUsage = true;
    }

    // Stream indices define the block order the model produced.
    const indexedContent = [...streamed.contentByIndex.entries()].sort((a, b) => a[0] - b[0]);
    const normalizedToolCalls: NormalizedToolCall[] = indexedContent
      .filter(([, block]) => block?.type === "tool_use")
      .map(([index, block], callIndex) => ({
        id: block.id ?? `tool_call_${round}_${callIndex}`,
        name: block.name ?? "unknown_tool",
        arguments: streamed.toolArgumentByIndex.get(index) || stringifyToolInput(block.input),
      }));

    if (!normalizedToolCalls.length) {
      const roundText = streamed.text;
      // The model said it would use a tool but did not call one: ask it (a bounded number of
      // times) to either make the call or answer directly. Each retry consumes a round.
      if (danglingToolIntentRetries < MAX_DANGLING_TOOL_INTENT_RETRIES && looksLikeDanglingToolIntent(roundText)) {
        danglingToolIntentRetries += 1;
        appendCorrection(conversation, roundText);
        continue;
      }
      if (roundText) {
        yield { type: "delta", text: roundText };
      }
      yield {
        type: "done",
        result: {
          text: roundText,
          usage: sawUsage ? usageAcc : undefined,
          raw: { streamed: true, responses: rawResponses, toolCallsUsed: totalToolCalls, api: "messages" },
          toolEvents,
        },
      };
      return;
    }

    totalToolCalls += normalizedToolCalls.length;
    // Echo the assistant turn back so the tool_result blocks below can reference its tool_use ids.
    // Empty or whitespace-only text blocks are dropped: the Messages API rejects them in request
    // history, which would otherwise fail the next round.
    const assistantContent = indexedContent
      .map(([, block]) => block)
      .filter((block) => block?.type !== "text" || (typeof block.text === "string" && block.text.trim() !== ""));
    conversation.push({
      role: "assistant",
      content: assistantContent,
    });

    const toolResultBlocks: any[] = [];
    for (const call of normalizedToolCalls) {
      const { event: initiatedEvent, execution } = prepareToolCallExecution(call);
      yield { type: "tool_call", event: initiatedEvent };
      const { event, toolResult } = await executeToolCallAndBuildEvent(call, execution, params);
      toolEvents.push(event);
      yield { type: "tool_call", event };
      toolResultBlocks.push(buildToolResultBlock(call, toolResult));
    }
    conversation.push({
      role: "user",
      content: toolResultBlocks,
    });
  }

  yield {
    type: "done",
    result: {
      text: "I reached the tool-call limit while gathering information. Please narrow the request and try again.",
      usage: sawUsage ? usageAcc : undefined,
      raw: { streamed: true, responses: rawResponses, toolCallsUsed: totalToolCalls, toolCallLimitReached: true, api: "messages" },
      toolEvents,
    },
  };
}