big backend refactor
This commit is contained in:
470
server/src/llm/protocols/messages-api.ts
Normal file
470
server/src/llm/protocols/messages-api.ts
Normal file
@@ -0,0 +1,470 @@
|
||||
import {
|
||||
buildChatToolSystemPrompt,
|
||||
executeToolCallAndBuildEvent,
|
||||
getEnabledChatTools,
|
||||
looksLikeDanglingToolIntent,
|
||||
MAX_DANGLING_TOOL_INTENT_RETRIES,
|
||||
MAX_TOOL_ROUNDS,
|
||||
parseToolArgs,
|
||||
prepareToolCallExecution,
|
||||
type NormalizedToolCall,
|
||||
type ToolAwareCompletionParams,
|
||||
type ToolAwareCompletionResult,
|
||||
type ToolAwareStreamingEvent,
|
||||
type ToolAwareUsage,
|
||||
type ToolExecutionEvent,
|
||||
type ToolRunOutcome,
|
||||
} from "../chat-tools.js";
|
||||
import {
|
||||
buildImageSummaryText,
|
||||
buildTextAttachmentPrompt,
|
||||
buildTopLevelSystemPrompt,
|
||||
getImageAttachments,
|
||||
getTextAttachments,
|
||||
parseImageDataUrl,
|
||||
} from "../message-content.js";
|
||||
import type { ChatMessage } from "../types.js";
|
||||
|
||||
const INTERNAL_CORRECTION =
|
||||
"Internal correction: the previous assistant message claimed it would run a tool, but no tool call was made. If the task needs an available tool, call it now. Otherwise provide the final answer directly without saying you will run a tool.";
|
||||
|
||||
function toTools(tools: any[]) {
|
||||
return tools
|
||||
.map((tool) => {
|
||||
if (tool?.type !== "function") return null;
|
||||
return {
|
||||
name: tool.function.name,
|
||||
description: tool.function.description,
|
||||
input_schema: tool.function.parameters,
|
||||
};
|
||||
})
|
||||
.filter(Boolean);
|
||||
}
|
||||
|
||||
function toContentBlocks(message: ChatMessage) {
|
||||
const imageAttachments = getImageAttachments(message);
|
||||
const textAttachments = getTextAttachments(message);
|
||||
if (!imageAttachments.length && !textAttachments.length) {
|
||||
return message.content;
|
||||
}
|
||||
|
||||
const blocks: Array<Record<string, unknown>> = [];
|
||||
for (const attachment of imageAttachments) {
|
||||
const source = parseImageDataUrl(attachment);
|
||||
blocks.push({
|
||||
type: "image",
|
||||
source: {
|
||||
type: "base64",
|
||||
media_type: source.mediaType,
|
||||
data: source.data,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
const imageSummary = buildImageSummaryText(imageAttachments);
|
||||
if (imageSummary) {
|
||||
blocks.push({ type: "text", text: imageSummary });
|
||||
}
|
||||
|
||||
for (const attachment of textAttachments) {
|
||||
blocks.push({ type: "text", text: buildTextAttachmentPrompt(attachment) });
|
||||
}
|
||||
|
||||
if (message.content.trim()) {
|
||||
blocks.push({ type: "text", text: message.content });
|
||||
}
|
||||
|
||||
if (blocks.length === 1 && blocks[0]?.type === "text" && typeof blocks[0].text === "string") {
|
||||
return blocks[0].text;
|
||||
}
|
||||
|
||||
return blocks;
|
||||
}
|
||||
|
||||
function buildConversationMessage(message: ChatMessage) {
|
||||
if (message.role === "system") {
|
||||
throw new Error("System messages must be handled separately for top-level-system protocols.");
|
||||
}
|
||||
|
||||
if (message.role === "tool") {
|
||||
const name = message.name?.trim() || "tool";
|
||||
return {
|
||||
role: "user",
|
||||
content: `Tool output (${name}):\n${message.content}`,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
role: message.role === "assistant" ? "assistant" : "user",
|
||||
content: toContentBlocks(message),
|
||||
};
|
||||
}
|
||||
|
||||
function buildBaseMessages(params: ToolAwareCompletionParams) {
|
||||
return params.messages.filter((message) => message.role !== "system").map((message) => buildConversationMessage(message));
|
||||
}
|
||||
|
||||
function stringifyToolInput(input: unknown) {
|
||||
if (typeof input === "string") return input;
|
||||
try {
|
||||
return JSON.stringify(input ?? {});
|
||||
} catch {
|
||||
return "{}";
|
||||
}
|
||||
}
|
||||
|
||||
function normalizeToolCalls(content: any[], round: number): NormalizedToolCall[] {
|
||||
return content
|
||||
.filter((item) => item?.type === "tool_use")
|
||||
.map((call: any, index: number) => ({
|
||||
id: call?.id ?? `tool_call_${round}_${index}`,
|
||||
name: call?.name ?? "unknown_tool",
|
||||
arguments: stringifyToolInput(call?.input),
|
||||
}));
|
||||
}
|
||||
|
||||
function extractText(response: any) {
|
||||
if (!Array.isArray(response?.content)) return "";
|
||||
return response.content
|
||||
.map((content: any) => (content?.type === "text" && typeof content.text === "string" ? content.text : ""))
|
||||
.join("")
|
||||
.trim();
|
||||
}
|
||||
|
||||
function buildToolResultBlock(call: NormalizedToolCall, toolResult: ToolRunOutcome) {
|
||||
return {
|
||||
type: "tool_result",
|
||||
tool_use_id: call.id,
|
||||
content: JSON.stringify(toolResult),
|
||||
is_error: !toolResult.ok,
|
||||
};
|
||||
}
|
||||
|
||||
function appendCorrection(conversation: any[], text: string) {
|
||||
conversation.push({ role: "assistant", content: text });
|
||||
conversation.push({
|
||||
role: "user",
|
||||
content: INTERNAL_CORRECTION,
|
||||
});
|
||||
}
|
||||
|
||||
function mergeUsage(acc: Required<ToolAwareUsage>, usage: any) {
|
||||
if (!usage) return false;
|
||||
const inputTokens = usage.input_tokens ?? 0;
|
||||
const outputTokens = usage.output_tokens ?? 0;
|
||||
acc.inputTokens += inputTokens;
|
||||
acc.outputTokens += outputTokens;
|
||||
acc.totalTokens += inputTokens + outputTokens;
|
||||
return true;
|
||||
}
|
||||
|
||||
export async function completeWithMessagesApi(params: ToolAwareCompletionParams): Promise<ToolAwareCompletionResult> {
|
||||
const enabledTools = getEnabledChatTools(params);
|
||||
if (!enabledTools.length) {
|
||||
const response = await params.client.messages.create({
|
||||
model: params.model,
|
||||
system: buildTopLevelSystemPrompt(params.messages, params.userLocation),
|
||||
max_tokens: params.maxTokens ?? 1024,
|
||||
temperature: params.temperature,
|
||||
messages: buildBaseMessages(params),
|
||||
} as any);
|
||||
|
||||
const usageAcc: Required<ToolAwareUsage> = { inputTokens: 0, outputTokens: 0, totalTokens: 0 };
|
||||
const sawUsage = mergeUsage(usageAcc, response?.usage);
|
||||
|
||||
return {
|
||||
text: extractText(response),
|
||||
usage: sawUsage ? usageAcc : undefined,
|
||||
raw: { response, api: "messages" },
|
||||
toolEvents: [],
|
||||
};
|
||||
}
|
||||
|
||||
const conversation: any[] = buildBaseMessages(params);
|
||||
const rawResponses: unknown[] = [];
|
||||
const toolEvents: ToolExecutionEvent[] = [];
|
||||
const usageAcc: Required<ToolAwareUsage> = { inputTokens: 0, outputTokens: 0, totalTokens: 0 };
|
||||
let sawUsage = false;
|
||||
let totalToolCalls = 0;
|
||||
let danglingToolIntentRetries = 0;
|
||||
|
||||
for (let round = 0; round < MAX_TOOL_ROUNDS; round += 1) {
|
||||
const response = await params.client.messages.create({
|
||||
model: params.model,
|
||||
system: buildTopLevelSystemPrompt(params.messages, params.userLocation, buildChatToolSystemPrompt(params)),
|
||||
max_tokens: params.maxTokens ?? 1024,
|
||||
temperature: params.temperature,
|
||||
messages: conversation,
|
||||
tools: toTools(enabledTools),
|
||||
tool_choice: { type: "auto" },
|
||||
} as any);
|
||||
rawResponses.push(response);
|
||||
sawUsage = mergeUsage(usageAcc, response?.usage) || sawUsage;
|
||||
|
||||
const content = Array.isArray(response?.content) ? response.content : [];
|
||||
const normalizedToolCalls = normalizeToolCalls(content, round);
|
||||
if (!normalizedToolCalls.length) {
|
||||
const text = extractText(response);
|
||||
if (danglingToolIntentRetries < MAX_DANGLING_TOOL_INTENT_RETRIES && looksLikeDanglingToolIntent(text)) {
|
||||
danglingToolIntentRetries += 1;
|
||||
appendCorrection(conversation, text);
|
||||
continue;
|
||||
}
|
||||
return {
|
||||
text,
|
||||
usage: sawUsage ? usageAcc : undefined,
|
||||
raw: { responses: rawResponses, toolCallsUsed: totalToolCalls, api: "messages" },
|
||||
toolEvents,
|
||||
};
|
||||
}
|
||||
|
||||
totalToolCalls += normalizedToolCalls.length;
|
||||
conversation.push({
|
||||
role: "assistant",
|
||||
content,
|
||||
});
|
||||
|
||||
const toolResultBlocks: any[] = [];
|
||||
for (const call of normalizedToolCalls) {
|
||||
const { execution } = prepareToolCallExecution(call);
|
||||
const { event, toolResult } = await executeToolCallAndBuildEvent(call, execution, params);
|
||||
toolEvents.push(event);
|
||||
toolResultBlocks.push(buildToolResultBlock(call, toolResult));
|
||||
}
|
||||
|
||||
conversation.push({
|
||||
role: "user",
|
||||
content: toolResultBlocks,
|
||||
});
|
||||
}
|
||||
|
||||
return {
|
||||
text: "I reached the tool-call limit while gathering information. Please narrow the request and try again.",
|
||||
usage: sawUsage ? usageAcc : undefined,
|
||||
raw: { responses: rawResponses, toolCallsUsed: totalToolCalls, toolCallLimitReached: true, api: "messages" },
|
||||
toolEvents,
|
||||
};
|
||||
}
|
||||
|
||||
export async function* streamWithMessagesApi(params: ToolAwareCompletionParams): AsyncGenerator<ToolAwareStreamingEvent> {
|
||||
const enabledTools = getEnabledChatTools(params);
|
||||
if (!enabledTools.length) {
|
||||
const rawResponses: unknown[] = [];
|
||||
const usageAcc: Required<ToolAwareUsage> = { inputTokens: 0, outputTokens: 0, totalTokens: 0 };
|
||||
let sawUsage = false;
|
||||
let roundInputTokens = 0;
|
||||
let roundOutputTokens = 0;
|
||||
let text = "";
|
||||
|
||||
const stream = await params.client.messages.create({
|
||||
model: params.model,
|
||||
system: buildTopLevelSystemPrompt(params.messages, params.userLocation),
|
||||
max_tokens: params.maxTokens ?? 1024,
|
||||
temperature: params.temperature,
|
||||
messages: buildBaseMessages(params),
|
||||
stream: true,
|
||||
} as any);
|
||||
|
||||
for await (const ev of stream as any as AsyncIterable<any>) {
|
||||
rawResponses.push(ev);
|
||||
if (ev?.type === "message_start" && ev?.message?.usage) {
|
||||
roundInputTokens = ev.message.usage.input_tokens ?? roundInputTokens;
|
||||
sawUsage = true;
|
||||
}
|
||||
if (ev?.type === "content_block_delta" && ev?.delta?.type === "text_delta") {
|
||||
const delta = ev.delta.text ?? "";
|
||||
if (delta) {
|
||||
text += delta;
|
||||
yield { type: "delta", text: delta };
|
||||
}
|
||||
}
|
||||
if (ev?.type === "message_delta" && ev.usage) {
|
||||
roundInputTokens = ev.usage.input_tokens ?? roundInputTokens;
|
||||
roundOutputTokens = ev.usage.output_tokens ?? roundOutputTokens;
|
||||
sawUsage = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (sawUsage) {
|
||||
usageAcc.inputTokens += roundInputTokens;
|
||||
usageAcc.outputTokens += roundOutputTokens;
|
||||
usageAcc.totalTokens += roundInputTokens + roundOutputTokens;
|
||||
}
|
||||
|
||||
yield {
|
||||
type: "done",
|
||||
result: {
|
||||
text,
|
||||
usage: sawUsage ? usageAcc : undefined,
|
||||
raw: { streamed: true, responses: rawResponses, toolCallsUsed: 0, api: "messages" },
|
||||
toolEvents: [],
|
||||
},
|
||||
};
|
||||
return;
|
||||
}
|
||||
|
||||
const conversation: any[] = buildBaseMessages(params);
|
||||
const rawResponses: unknown[] = [];
|
||||
const toolEvents: ToolExecutionEvent[] = [];
|
||||
const usageAcc: Required<ToolAwareUsage> = { inputTokens: 0, outputTokens: 0, totalTokens: 0 };
|
||||
let sawUsage = false;
|
||||
let totalToolCalls = 0;
|
||||
let danglingToolIntentRetries = 0;
|
||||
|
||||
for (let round = 0; round < MAX_TOOL_ROUNDS; round += 1) {
|
||||
const stream = await params.client.messages.create({
|
||||
model: params.model,
|
||||
system: buildTopLevelSystemPrompt(params.messages, params.userLocation, buildChatToolSystemPrompt(params)),
|
||||
max_tokens: params.maxTokens ?? 1024,
|
||||
temperature: params.temperature,
|
||||
messages: conversation,
|
||||
tools: toTools(enabledTools),
|
||||
tool_choice: { type: "auto" },
|
||||
stream: true,
|
||||
} as any);
|
||||
|
||||
const contentByIndex = new Map<number, any>();
|
||||
const toolArgumentByIndex = new Map<number, string>();
|
||||
let roundText = "";
|
||||
let roundHasToolCalls = false;
|
||||
let roundInputTokens = 0;
|
||||
let roundOutputTokens = 0;
|
||||
let sawRoundUsage = false;
|
||||
|
||||
for await (const ev of stream as any as AsyncIterable<any>) {
|
||||
rawResponses.push(ev);
|
||||
|
||||
if (ev?.type === "message_start" && ev?.message?.usage) {
|
||||
roundInputTokens = ev.message.usage.input_tokens ?? roundInputTokens;
|
||||
sawRoundUsage = true;
|
||||
}
|
||||
|
||||
if (ev?.type === "content_block_start" && typeof ev.index === "number") {
|
||||
const block = ev.content_block ?? {};
|
||||
if (block.type === "tool_use") {
|
||||
roundHasToolCalls = true;
|
||||
contentByIndex.set(ev.index, {
|
||||
type: "tool_use",
|
||||
id: block.id,
|
||||
name: block.name,
|
||||
input: block.input ?? {},
|
||||
});
|
||||
toolArgumentByIndex.set(ev.index, "");
|
||||
} else if (block.type === "text") {
|
||||
contentByIndex.set(ev.index, {
|
||||
type: "text",
|
||||
text: typeof block.text === "string" ? block.text : "",
|
||||
});
|
||||
} else if (block.type) {
|
||||
contentByIndex.set(ev.index, block);
|
||||
}
|
||||
}
|
||||
|
||||
if (ev?.type === "content_block_delta" && typeof ev.index === "number") {
|
||||
if (ev.delta?.type === "text_delta") {
|
||||
const delta = typeof ev.delta.text === "string" ? ev.delta.text : "";
|
||||
if (delta) {
|
||||
const block = contentByIndex.get(ev.index) ?? { type: "text", text: "" };
|
||||
if (block.type === "text") {
|
||||
block.text = `${typeof block.text === "string" ? block.text : ""}${delta}`;
|
||||
contentByIndex.set(ev.index, block);
|
||||
}
|
||||
roundText += delta;
|
||||
}
|
||||
} else if (ev.delta?.type === "input_json_delta") {
|
||||
roundHasToolCalls = true;
|
||||
const partialJson = typeof ev.delta.partial_json === "string" ? ev.delta.partial_json : "";
|
||||
toolArgumentByIndex.set(ev.index, `${toolArgumentByIndex.get(ev.index) ?? ""}${partialJson}`);
|
||||
}
|
||||
}
|
||||
|
||||
if (ev?.type === "content_block_stop" && typeof ev.index === "number") {
|
||||
const block = contentByIndex.get(ev.index);
|
||||
if (block?.type === "tool_use") {
|
||||
const rawArguments = toolArgumentByIndex.get(ev.index) || stringifyToolInput(block.input);
|
||||
try {
|
||||
block.input = parseToolArgs(rawArguments);
|
||||
} catch {
|
||||
block.input = {};
|
||||
}
|
||||
contentByIndex.set(ev.index, block);
|
||||
}
|
||||
}
|
||||
|
||||
if (ev?.type === "message_delta" && ev.usage) {
|
||||
roundInputTokens = ev.usage.input_tokens ?? roundInputTokens;
|
||||
roundOutputTokens = ev.usage.output_tokens ?? roundOutputTokens;
|
||||
sawRoundUsage = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (sawRoundUsage) {
|
||||
usageAcc.inputTokens += roundInputTokens;
|
||||
usageAcc.outputTokens += roundOutputTokens;
|
||||
usageAcc.totalTokens += roundInputTokens + roundOutputTokens;
|
||||
sawUsage = true;
|
||||
}
|
||||
|
||||
const indexedContent = [...contentByIndex.entries()].sort((a, b) => a[0] - b[0]);
|
||||
const assistantContent = indexedContent.map(([, block]) => block);
|
||||
const normalizedToolCalls: NormalizedToolCall[] = indexedContent
|
||||
.filter(([, block]) => block?.type === "tool_use")
|
||||
.map(([index, block], callIndex) => ({
|
||||
id: block.id ?? `tool_call_${round}_${callIndex}`,
|
||||
name: block.name ?? "unknown_tool",
|
||||
arguments: toolArgumentByIndex.get(index) || stringifyToolInput(block.input),
|
||||
}));
|
||||
|
||||
if (!normalizedToolCalls.length) {
|
||||
if (danglingToolIntentRetries < MAX_DANGLING_TOOL_INTENT_RETRIES && looksLikeDanglingToolIntent(roundText)) {
|
||||
danglingToolIntentRetries += 1;
|
||||
appendCorrection(conversation, roundText);
|
||||
continue;
|
||||
}
|
||||
if (roundText) {
|
||||
yield { type: "delta", text: roundText };
|
||||
}
|
||||
yield {
|
||||
type: "done",
|
||||
result: {
|
||||
text: roundText,
|
||||
usage: sawUsage ? usageAcc : undefined,
|
||||
raw: { streamed: true, responses: rawResponses, toolCallsUsed: totalToolCalls, api: "messages" },
|
||||
toolEvents,
|
||||
},
|
||||
};
|
||||
return;
|
||||
}
|
||||
|
||||
totalToolCalls += normalizedToolCalls.length;
|
||||
conversation.push({
|
||||
role: "assistant",
|
||||
content: assistantContent,
|
||||
});
|
||||
|
||||
const toolResultBlocks: any[] = [];
|
||||
for (const call of normalizedToolCalls) {
|
||||
const { event: initiatedEvent, execution } = prepareToolCallExecution(call);
|
||||
yield { type: "tool_call", event: initiatedEvent };
|
||||
const { event, toolResult } = await executeToolCallAndBuildEvent(call, execution, params);
|
||||
toolEvents.push(event);
|
||||
yield { type: "tool_call", event };
|
||||
toolResultBlocks.push(buildToolResultBlock(call, toolResult));
|
||||
}
|
||||
|
||||
conversation.push({
|
||||
role: "user",
|
||||
content: toolResultBlocks,
|
||||
});
|
||||
}
|
||||
|
||||
yield {
|
||||
type: "done",
|
||||
result: {
|
||||
text: "I reached the tool-call limit while gathering information. Please narrow the request and try again.",
|
||||
usage: sawUsage ? usageAcc : undefined,
|
||||
raw: { streamed: true, responses: rawResponses, toolCallsUsed: totalToolCalls, toolCallLimitReached: true, api: "messages" },
|
||||
toolEvents,
|
||||
},
|
||||
};
|
||||
}
|
||||
Reference in New Issue
Block a user