// 471 lines · 16 KiB · TypeScript
import {
|
|
buildChatToolSystemPrompt,
|
|
executeToolCallAndBuildEvent,
|
|
getEnabledChatTools,
|
|
looksLikeDanglingToolIntent,
|
|
MAX_DANGLING_TOOL_INTENT_RETRIES,
|
|
MAX_TOOL_ROUNDS,
|
|
parseToolArgs,
|
|
prepareToolCallExecution,
|
|
type NormalizedToolCall,
|
|
type ToolAwareCompletionParams,
|
|
type ToolAwareCompletionResult,
|
|
type ToolAwareStreamingEvent,
|
|
type ToolAwareUsage,
|
|
type ToolExecutionEvent,
|
|
type ToolRunOutcome,
|
|
} from "../chat-tools.js";
|
|
import {
|
|
buildImageSummaryText,
|
|
buildTextAttachmentPrompt,
|
|
buildTopLevelSystemPrompt,
|
|
getImageAttachments,
|
|
getTextAttachments,
|
|
parseImageDataUrl,
|
|
} from "../message-content.js";
|
|
import type { ChatMessage } from "../types.js";
|
|
|
|
/**
 * User-role follow-up appended after an assistant reply that announced a tool
 * run without actually emitting a tool call. It asks the model to either call
 * the tool now or answer directly.
 */
const INTERNAL_CORRECTION =
  "Internal correction: the previous assistant message claimed it would run a tool, but no tool call was made. If the task needs an available tool, call it now. Otherwise provide the final answer directly without saying you will run a tool.";
|
|
|
|
/**
 * Converts function-style tool definitions
 * (`{ type: "function", function: { name, description, parameters } }`)
 * into the `{ name, description, input_schema }` shape sent as `tools`
 * to `messages.create`.
 *
 * Entries whose `type` is not `"function"`, or that have no usable
 * `function.name`, are skipped rather than throwing. When `parameters` is
 * missing, an empty object schema is substituted so `input_schema` is
 * always present in the request.
 *
 * @param tools Raw tool definitions; null/undefined entries are tolerated.
 * @returns One converted entry per valid function tool, in input order.
 */
function toTools(tools: any[]) {
  return tools.flatMap((tool) => {
    if (tool?.type !== "function") return [];

    const fn = tool.function;
    // A tool without a name cannot be referenced by the model; drop it.
    if (typeof fn?.name !== "string" || !fn.name) return [];

    return [
      {
        name: fn.name,
        description: fn.description,
        // Tools declared without parameters still need a schema object.
        input_schema: fn.parameters ?? { type: "object", properties: {} },
      },
    ];
  });
}
|
|
|
|
/**
 * Builds the `content` value for a non-tool, non-system chat message.
 *
 * Without attachments the message's plain `content` string is returned
 * unchanged. Otherwise a block list is assembled in this order: one base64
 * image block per image attachment, an optional image summary text block,
 * one text block per text attachment, and finally the message text when it
 * is not blank. A list holding a single text block collapses back to a
 * plain string.
 */
function toContentBlocks(message: ChatMessage) {
  const images = getImageAttachments(message);
  const textFiles = getTextAttachments(message);
  if (!(images.length || textFiles.length)) {
    return message.content;
  }

  const blocks: Array<Record<string, unknown>> = [];
  const addText = (text: string) => {
    blocks.push({ type: "text", text });
  };

  for (const image of images) {
    const { mediaType, data } = parseImageDataUrl(image);
    blocks.push({
      type: "image",
      source: { type: "base64", media_type: mediaType, data },
    });
  }

  const summary = buildImageSummaryText(images);
  if (summary) addText(summary);

  for (const file of textFiles) {
    addText(buildTextAttachmentPrompt(file));
  }

  if (message.content.trim()) addText(message.content);

  // A lone text block is equivalent to a plain string payload.
  const [first] = blocks;
  if (blocks.length === 1 && first?.type === "text" && typeof first.text === "string") {
    return first.text;
  }
  return blocks;
}
|
|
|
|
/**
 * Maps one chat message to a `{ role, content }` conversation entry.
 *
 * - `system` messages are rejected; they are passed separately as the
 *   top-level system prompt.
 * - `tool` messages become user turns prefixed with `Tool output (<name>):`,
 *   where a blank or missing name falls back to `"tool"`.
 * - `assistant` messages keep their role; every other role is sent as `user`.
 *
 * @throws Error when given a `system` message.
 */
function buildConversationMessage(message: ChatMessage) {
  switch (message.role) {
    case "system":
      throw new Error("System messages must be handled separately for top-level-system protocols.");

    case "tool": {
      const toolName = message.name?.trim() || "tool";
      return {
        role: "user",
        content: `Tool output (${toolName}):\n${message.content}`,
      };
    }

    case "assistant":
      return { role: "assistant", content: toContentBlocks(message) };

    default:
      return { role: "user", content: toContentBlocks(message) };
  }
}
|
|
|
|
/**
 * Produces the initial conversation for a request: every non-system message
 * in `params.messages`, converted with `buildConversationMessage`, in the
 * original order.
 */
function buildBaseMessages(params: ToolAwareCompletionParams) {
  const conversation: Array<ReturnType<typeof buildConversationMessage>> = [];
  for (const message of params.messages) {
    if (message.role === "system") continue;
    conversation.push(buildConversationMessage(message));
  }
  return conversation;
}
|
|
|
|
/**
 * Serializes a tool call's input into a JSON string.
 *
 * Strings are returned unchanged. `null`/`undefined` serialize as `"{}"`.
 * Anything `JSON.stringify` cannot represent — values that throw (circular
 * structures, BigInt) or that yield `undefined` (functions, symbols) — also
 * falls back to `"{}"`, so the result is always a string.
 *
 * @param input Raw tool input as received from the model.
 * @returns The JSON text for `input`, or `"{}"` when it cannot be serialized.
 */
function stringifyToolInput(input: unknown): string {
  if (typeof input === "string") return input;
  try {
    // JSON.stringify returns undefined (not a string) for functions and symbols.
    return JSON.stringify(input ?? {}) ?? "{}";
  } catch {
    return "{}";
  }
}
|
|
|
|
/**
 * Collects the `tool_use` blocks of a response into normalized tool calls.
 *
 * A missing id is replaced by `tool_call_<round>_<n>`, where `n` is the
 * call's position among the tool_use blocks. A missing name becomes
 * `"unknown_tool"`. The input is serialized with `stringifyToolInput`.
 *
 * @param content Content blocks of one assistant response.
 * @param round Zero-based tool round, used only for fallback ids.
 */
function normalizeToolCalls(content: any[], round: number): NormalizedToolCall[] {
  const calls: NormalizedToolCall[] = [];
  for (const block of content) {
    if (block?.type !== "tool_use") continue;
    const position = calls.length;
    calls.push({
      id: block.id ?? `tool_call_${round}_${position}`,
      name: block.name ?? "unknown_tool",
      arguments: stringifyToolInput(block.input),
    });
  }
  return calls;
}
|
|
|
|
/**
 * Concatenates the text of every `text` content block in a response and
 * trims the result. Returns `""` when `response.content` is not an array.
 * Blocks of other types, or with a non-string `text`, contribute nothing.
 */
function extractText(response: any) {
  const blocks = response?.content;
  if (!Array.isArray(blocks)) return "";

  let combined = "";
  for (const block of blocks) {
    if (block?.type === "text" && typeof block.text === "string") {
      combined += block.text;
    }
  }
  return combined.trim();
}
|
|
|
|
/**
 * Wraps a tool outcome in a `tool_result` block tied to the originating
 * call's id. The outcome is sent as JSON text, and the block is flagged as
 * an error whenever `toolResult.ok` is falsy.
 */
function buildToolResultBlock(call: NormalizedToolCall, toolResult: ToolRunOutcome) {
  const toolUseId = call.id;
  const serializedOutcome = JSON.stringify(toolResult);
  const failed = !toolResult.ok;
  return {
    type: "tool_result",
    tool_use_id: toolUseId,
    content: serializedOutcome,
    is_error: failed,
  };
}
|
|
|
|
/**
 * Appends two turns to `conversation` in place: the assistant's own `text`,
 * followed by a user turn carrying `INTERNAL_CORRECTION`.
 */
function appendCorrection(conversation: any[], text: string) {
  conversation.push(
    { role: "assistant", content: text },
    { role: "user", content: INTERNAL_CORRECTION },
  );
}
|
|
|
|
/**
 * Adds one response's token usage onto the running accumulator.
 *
 * `usage` comes from an untyped API payload, so each count is only added
 * when it is a finite number; anything else (missing, null, string, NaN)
 * counts as 0. This keeps the accumulator numeric.
 *
 * @param acc Accumulator mutated in place.
 * @param usage Raw `usage` object with `input_tokens` / `output_tokens`.
 * @returns `false` when `usage` is falsy (nothing merged), otherwise `true`.
 */
function mergeUsage(acc: Required<ToolAwareUsage>, usage: any): boolean {
  if (!usage) return false;

  const toCount = (value: unknown): number =>
    typeof value === "number" && Number.isFinite(value) ? value : 0;

  const inputTokens = toCount(usage.input_tokens);
  const outputTokens = toCount(usage.output_tokens);

  acc.inputTokens += inputTokens;
  acc.outputTokens += outputTokens;
  acc.totalTokens += inputTokens + outputTokens;
  return true;
}
|
|
|
|
/**
 * Runs a non-streaming completion through `params.client.messages.create`.
 *
 * With no enabled tools a single request is made and its text returned.
 * Otherwise the function loops for at most `MAX_TOOL_ROUNDS` rounds:
 *
 * 1. Send the conversation together with the tool definitions.
 * 2. If the response has no `tool_use` blocks, return its text. The
 *    exception is text that looks like an unfulfilled promise to run a tool:
 *    then a correction turn is appended and the round is retried, up to
 *    `MAX_DANGLING_TOOL_INTENT_RETRIES` times. Each retry consumes a round.
 * 3. Otherwise execute every tool call in order, append the assistant
 *    content and the `tool_result` blocks, and continue.
 *
 * When the round limit is reached, a fixed limit message is returned with
 * `toolCallLimitReached: true` in `raw`.
 *
 * @param params Model, client, messages and tool configuration.
 * @returns Final text, accumulated usage (`undefined` when no response
 *   reported usage), raw responses and the tool events that were executed.
 */
export async function completeWithMessagesApi(params: ToolAwareCompletionParams): Promise<ToolAwareCompletionResult> {
  const enabledTools = getEnabledChatTools(params);
  const usageAcc: Required<ToolAwareUsage> = { inputTokens: 0, outputTokens: 0, totalTokens: 0 };

  if (!enabledTools.length) {
    const response = await params.client.messages.create({
      model: params.model,
      system: buildTopLevelSystemPrompt(params.messages, params.userLocation),
      max_tokens: params.maxTokens ?? 1024,
      temperature: params.temperature,
      messages: buildBaseMessages(params),
    } as any);

    const sawSingleUsage = mergeUsage(usageAcc, response?.usage);

    return {
      text: extractText(response),
      usage: sawSingleUsage ? usageAcc : undefined,
      raw: { response, api: "messages" },
      toolEvents: [],
    };
  }

  const conversation: any[] = buildBaseMessages(params);
  // The tool definitions do not change between rounds; convert them once.
  const tools = toTools(enabledTools);
  const rawResponses: unknown[] = [];
  const toolEvents: ToolExecutionEvent[] = [];
  let sawUsage = false;
  let totalToolCalls = 0;
  let danglingToolIntentRetries = 0;

  for (let round = 0; round < MAX_TOOL_ROUNDS; round += 1) {
    const response = await params.client.messages.create({
      model: params.model,
      system: buildTopLevelSystemPrompt(params.messages, params.userLocation, buildChatToolSystemPrompt(params)),
      max_tokens: params.maxTokens ?? 1024,
      temperature: params.temperature,
      messages: conversation,
      tools,
      tool_choice: { type: "auto" },
    } as any);
    rawResponses.push(response);
    sawUsage = mergeUsage(usageAcc, response?.usage) || sawUsage;

    const content = Array.isArray(response?.content) ? response.content : [];
    const normalizedToolCalls = normalizeToolCalls(content, round);

    if (!normalizedToolCalls.length) {
      const text = extractText(response);

      // The model said it would use a tool but did not: nudge it and retry.
      if (danglingToolIntentRetries < MAX_DANGLING_TOOL_INTENT_RETRIES && looksLikeDanglingToolIntent(text)) {
        danglingToolIntentRetries += 1;
        appendCorrection(conversation, text);
        continue;
      }

      return {
        text,
        usage: sawUsage ? usageAcc : undefined,
        raw: { responses: rawResponses, toolCallsUsed: totalToolCalls, api: "messages" },
        toolEvents,
      };
    }

    totalToolCalls += normalizedToolCalls.length;
    // Echo the assistant turn verbatim so tool_result ids match its tool_use blocks.
    conversation.push({ role: "assistant", content });

    const toolResultBlocks: any[] = [];
    for (const call of normalizedToolCalls) {
      const { execution } = prepareToolCallExecution(call);
      const { event, toolResult } = await executeToolCallAndBuildEvent(call, execution, params);
      toolEvents.push(event);
      toolResultBlocks.push(buildToolResultBlock(call, toolResult));
    }

    conversation.push({ role: "user", content: toolResultBlocks });
  }

  return {
    text: "I reached the tool-call limit while gathering information. Please narrow the request and try again.",
    usage: sawUsage ? usageAcc : undefined,
    raw: { responses: rawResponses, toolCallsUsed: totalToolCalls, toolCallLimitReached: true, api: "messages" },
    toolEvents,
  };
}
|
|
|
|
/**
 * Streaming counterpart of the Messages-API completion.
 *
 * Yields `delta` events with text, `tool_call` events around each tool
 * execution (one when the call is prepared, one when it has finished), and
 * exactly one final `done` event carrying the aggregated result.
 *
 * With no enabled tools, text deltas are forwarded as they arrive.
 *
 * With tools enabled, each round's text is buffered instead of forwarded.
 * It is emitted as a single `delta` only once the round turns out to be
 * final, i.e. it has no tool calls and is not retried as a dangling tool
 * intent. Text from rounds that do contain tool calls is not emitted. The
 * loop runs for at most `MAX_TOOL_ROUNDS` rounds; dangling-intent retries
 * (at most `MAX_DANGLING_TOOL_INTENT_RETRIES`) each consume a round. When
 * the limit is reached, the `done` result carries a fixed limit message and
 * `toolCallLimitReached: true`.
 *
 * @param params Model, client, messages and tool configuration.
 */
export async function* streamWithMessagesApi(params: ToolAwareCompletionParams): AsyncGenerator<ToolAwareStreamingEvent> {
  const enabledTools = getEnabledChatTools(params);

  if (!enabledTools.length) {
    const rawResponses: unknown[] = [];
    const usageAcc: Required<ToolAwareUsage> = { inputTokens: 0, outputTokens: 0, totalTokens: 0 };
    let sawUsage = false;
    let roundInputTokens = 0;
    let roundOutputTokens = 0;
    let text = "";

    const stream = await params.client.messages.create({
      model: params.model,
      system: buildTopLevelSystemPrompt(params.messages, params.userLocation),
      max_tokens: params.maxTokens ?? 1024,
      temperature: params.temperature,
      messages: buildBaseMessages(params),
      stream: true,
    } as any);

    for await (const ev of stream as any as AsyncIterable<any>) {
      rawResponses.push(ev);

      if (ev?.type === "message_start" && ev?.message?.usage) {
        roundInputTokens = ev.message.usage.input_tokens ?? roundInputTokens;
        sawUsage = true;
      }

      if (ev?.type === "content_block_delta" && ev?.delta?.type === "text_delta") {
        const delta = ev.delta.text ?? "";
        if (delta) {
          text += delta;
          yield { type: "delta", text: delta };
        }
      }

      // Later usage events overwrite (not add to) the counts seen so far.
      if (ev?.type === "message_delta" && ev.usage) {
        roundInputTokens = ev.usage.input_tokens ?? roundInputTokens;
        roundOutputTokens = ev.usage.output_tokens ?? roundOutputTokens;
        sawUsage = true;
      }
    }

    if (sawUsage) {
      usageAcc.inputTokens += roundInputTokens;
      usageAcc.outputTokens += roundOutputTokens;
      usageAcc.totalTokens += roundInputTokens + roundOutputTokens;
    }

    yield {
      type: "done",
      result: {
        text,
        usage: sawUsage ? usageAcc : undefined,
        raw: { streamed: true, responses: rawResponses, toolCallsUsed: 0, api: "messages" },
        toolEvents: [],
      },
    };
    return;
  }

  const conversation: any[] = buildBaseMessages(params);
  // The tool definitions do not change between rounds; convert them once.
  const tools = toTools(enabledTools);
  const rawResponses: unknown[] = [];
  const toolEvents: ToolExecutionEvent[] = [];
  const usageAcc: Required<ToolAwareUsage> = { inputTokens: 0, outputTokens: 0, totalTokens: 0 };
  let sawUsage = false;
  let totalToolCalls = 0;
  let danglingToolIntentRetries = 0;

  for (let round = 0; round < MAX_TOOL_ROUNDS; round += 1) {
    const stream = await params.client.messages.create({
      model: params.model,
      system: buildTopLevelSystemPrompt(params.messages, params.userLocation, buildChatToolSystemPrompt(params)),
      max_tokens: params.maxTokens ?? 1024,
      temperature: params.temperature,
      messages: conversation,
      tools,
      tool_choice: { type: "auto" },
      stream: true,
    } as any);

    // Content blocks rebuilt from stream events, keyed by their stream index.
    const contentByIndex = new Map<number, any>();
    // Concatenated partial JSON of each tool_use block's input.
    const toolArgumentByIndex = new Map<number, string>();
    let roundText = "";
    let roundInputTokens = 0;
    let roundOutputTokens = 0;
    let sawRoundUsage = false;

    for await (const ev of stream as any as AsyncIterable<any>) {
      rawResponses.push(ev);

      if (ev?.type === "message_start" && ev?.message?.usage) {
        roundInputTokens = ev.message.usage.input_tokens ?? roundInputTokens;
        sawRoundUsage = true;
      }

      if (ev?.type === "content_block_start" && typeof ev.index === "number") {
        const block = ev.content_block ?? {};
        if (block.type === "tool_use") {
          contentByIndex.set(ev.index, {
            type: "tool_use",
            id: block.id,
            name: block.name,
            input: block.input ?? {},
          });
          toolArgumentByIndex.set(ev.index, "");
        } else if (block.type === "text") {
          contentByIndex.set(ev.index, {
            type: "text",
            text: typeof block.text === "string" ? block.text : "",
          });
        } else if (block.type) {
          // Other block types are kept as received.
          contentByIndex.set(ev.index, block);
        }
      }

      if (ev?.type === "content_block_delta" && typeof ev.index === "number") {
        if (ev.delta?.type === "text_delta") {
          const delta = typeof ev.delta.text === "string" ? ev.delta.text : "";
          if (delta) {
            const block = contentByIndex.get(ev.index) ?? { type: "text", text: "" };
            if (block.type === "text") {
              block.text = `${typeof block.text === "string" ? block.text : ""}${delta}`;
              contentByIndex.set(ev.index, block);
            }
            // Buffered, not yielded: the round may still be retried or lead to tool calls.
            roundText += delta;
          }
        } else if (ev.delta?.type === "input_json_delta") {
          const partialJson = typeof ev.delta.partial_json === "string" ? ev.delta.partial_json : "";
          toolArgumentByIndex.set(ev.index, `${toolArgumentByIndex.get(ev.index) ?? ""}${partialJson}`);
        }
      }

      if (ev?.type === "content_block_stop" && typeof ev.index === "number") {
        const block = contentByIndex.get(ev.index);
        if (block?.type === "tool_use") {
          // Prefer the streamed JSON; fall back to the input from block start.
          const rawArguments = toolArgumentByIndex.get(ev.index) || stringifyToolInput(block.input);
          try {
            block.input = parseToolArgs(rawArguments);
          } catch {
            // Unparseable arguments are replaced by an empty input object.
            block.input = {};
          }
          contentByIndex.set(ev.index, block);
        }
      }

      // Later usage events overwrite (not add to) the counts seen so far.
      if (ev?.type === "message_delta" && ev.usage) {
        roundInputTokens = ev.usage.input_tokens ?? roundInputTokens;
        roundOutputTokens = ev.usage.output_tokens ?? roundOutputTokens;
        sawRoundUsage = true;
      }
    }

    if (sawRoundUsage) {
      usageAcc.inputTokens += roundInputTokens;
      usageAcc.outputTokens += roundOutputTokens;
      usageAcc.totalTokens += roundInputTokens + roundOutputTokens;
      sawUsage = true;
    }

    const indexedContent = [...contentByIndex.entries()].sort((a, b) => a[0] - b[0]);
    const assistantContent = indexedContent.map(([, block]) => block);
    const normalizedToolCalls: NormalizedToolCall[] = indexedContent
      .filter(([, block]) => block?.type === "tool_use")
      .map(([index, block], callIndex) => ({
        id: block.id ?? `tool_call_${round}_${callIndex}`,
        name: block.name ?? "unknown_tool",
        arguments: toolArgumentByIndex.get(index) || stringifyToolInput(block.input),
      }));

    if (!normalizedToolCalls.length) {
      // The model said it would use a tool but did not: nudge it and retry.
      if (danglingToolIntentRetries < MAX_DANGLING_TOOL_INTENT_RETRIES && looksLikeDanglingToolIntent(roundText)) {
        danglingToolIntentRetries += 1;
        appendCorrection(conversation, roundText);
        continue;
      }

      // Final round: release the buffered text as one delta.
      if (roundText) {
        yield { type: "delta", text: roundText };
      }
      yield {
        type: "done",
        result: {
          text: roundText,
          usage: sawUsage ? usageAcc : undefined,
          raw: { streamed: true, responses: rawResponses, toolCallsUsed: totalToolCalls, api: "messages" },
          toolEvents,
        },
      };
      return;
    }

    totalToolCalls += normalizedToolCalls.length;
    conversation.push({ role: "assistant", content: assistantContent });

    const toolResultBlocks: any[] = [];
    for (const call of normalizedToolCalls) {
      const { event: initiatedEvent, execution } = prepareToolCallExecution(call);
      yield { type: "tool_call", event: initiatedEvent };

      const { event, toolResult } = await executeToolCallAndBuildEvent(call, execution, params);
      toolEvents.push(event);
      yield { type: "tool_call", event };

      toolResultBlocks.push(buildToolResultBlock(call, toolResult));
    }

    conversation.push({ role: "user", content: toolResultBlocks });
  }

  yield {
    type: "done",
    result: {
      text: "I reached the tool-call limit while gathering information. Please narrow the request and try again.",
      usage: sawUsage ? usageAcc : undefined,
      raw: { streamed: true, responses: rawResponses, toolCallsUsed: totalToolCalls, toolCallLimitReached: true, api: "messages" },
      toolEvents,
    },
  };
}
|