387 lines
13 KiB
TypeScript
387 lines
13 KiB
TypeScript
|
|
import {
|
||
|
|
appendDanglingToolIntentCorrection,
|
||
|
|
buildChatToolSystemPrompt,
|
||
|
|
executeToolCallAndBuildEvent,
|
||
|
|
getEnabledChatTools,
|
||
|
|
getUnstreamedText,
|
||
|
|
looksLikeDanglingToolIntent,
|
||
|
|
MAX_DANGLING_TOOL_INTENT_RETRIES,
|
||
|
|
MAX_TOOL_ROUNDS,
|
||
|
|
mergeUsage,
|
||
|
|
normalizeModelToolCalls,
|
||
|
|
prepareToolCallExecution,
|
||
|
|
type NormalizedToolCall,
|
||
|
|
type ToolAwareCompletionParams,
|
||
|
|
type ToolAwareCompletionResult,
|
||
|
|
type ToolAwareStreamingEvent,
|
||
|
|
type ToolExecutionEvent,
|
||
|
|
} from "../chat-tools.js";
|
||
|
|
import {
|
||
|
|
buildImageSummaryText,
|
||
|
|
buildSystemPromptAugmentationMessage,
|
||
|
|
buildTextAttachmentPrompt,
|
||
|
|
getImageAttachments,
|
||
|
|
getTextAttachments,
|
||
|
|
} from "../message-content.js";
|
||
|
|
import type { ChatMessage } from "../types.js";
|
||
|
|
|
||
|
|
/**
 * Converts a chat message into the `content` value sent to the Chat
 * Completions API.
 *
 * Messages without image or text attachments are passed through as their raw
 * `content`. Otherwise a list of content parts is assembled in this order:
 * one `image_url` part per image attachment, an optional text part holding the
 * image summary, one text part per text attachment, and finally the message's
 * own content when it is not blank. A list consisting of a single text part is
 * collapsed to that plain string.
 *
 * @param message - The chat message to convert.
 * @returns Either a plain string or an array of content-part objects.
 */
function toContentParts(message: ChatMessage) {
  const images = getImageAttachments(message);
  const texts = getTextAttachments(message);

  const hasAttachments = Boolean(images.length) || Boolean(texts.length);
  if (!hasAttachments) {
    return message.content;
  }

  const asTextPart = (text: unknown): Record<string, unknown> => ({ type: "text", text });
  const parts: Array<Record<string, unknown>> = [];

  // Images go first so that the summary and any prompts refer to content already present.
  parts.push(
    ...Array.from(images, (image: any): Record<string, unknown> => ({
      type: "image_url",
      image_url: {
        url: image.dataUrl,
        detail: "auto",
      },
    })),
  );

  const summary = buildImageSummaryText(images);
  if (summary) {
    parts.push(asTextPart(summary));
  }

  parts.push(...Array.from(texts, (file: any) => asTextPart(buildTextAttachmentPrompt(file))));

  // Whitespace-only content is dropped, but non-blank content is forwarded untrimmed.
  if (message.content.trim()) {
    parts.push(asTextPart(message.content));
  }

  // A lone text part is equivalent to a plain string payload.
  const sole = parts.length === 1 ? parts[0] : undefined;
  if (sole?.type === "text" && typeof sole.text === "string") {
    return sole.text;
  }

  return parts;
}
|
||
|
|
|
||
|
|
/**
 * Maps one stored chat message onto the message object placed in the request's
 * `messages` array.
 *
 * Stored `tool` messages are rewritten as `user` messages whose content is the
 * tool output prefixed with `Tool output (<name>):`; a missing or blank name
 * falls back to `"tool"`. Every other message keeps its role, gets its content
 * from `toContentParts`, and carries `name` only for `assistant` and `user`
 * roles.
 *
 * @param message - The stored chat message.
 * @returns The request-ready message object.
 */
function buildConversationMessage(message: ChatMessage) {
  if (message.role !== "tool") {
    const roleAcceptsName = message.role === "assistant" || message.role === "user";
    const converted: Record<string, unknown> = {
      role: message.role,
      content: toContentParts(message),
    };
    if (message.name && roleAcceptsName) {
      converted.name = message.name;
    }
    return converted;
  }

  // Historical tool output is replayed as plain user text rather than as a `tool` role message.
  const toolName = message.name?.trim() || "tool";
  return {
    role: "user",
    content: `Tool output (${toolName}):\n${message.content}`,
  };
}
|
||
|
|
|
||
|
|
/**
 * Builds the initial conversation for a tool-enabled request.
 *
 * The result starts with a `system` message whose content comes from
 * `buildChatToolSystemPrompt(params)`, followed by the message returned by
 * `buildSystemPromptAugmentationMessage(userLocation)`, followed by every
 * input message converted with `buildConversationMessage`.
 *
 * @param messages - Stored chat messages, in conversation order.
 * @param userLocation - Optional location forwarded to the augmentation message builder.
 * @param params - Tool selection forwarded to the tool system prompt builder.
 * @returns The ordered list of request messages.
 */
function normalizeMessages(messages: ChatMessage[], userLocation?: string, params: Pick<ToolAwareCompletionParams, "enabledTools"> = {}) {
  const history = messages.map((entry) => buildConversationMessage(entry));
  const toolSystemMessage = { role: "system", content: buildChatToolSystemPrompt(params) };
  const augmentationMessage = buildSystemPromptAugmentationMessage(userLocation);
  return [toolSystemMessage, augmentationMessage, ...history];
}
|
||
|
|
|
||
|
|
/**
 * Builds the conversation for a request that uses no tools: the message from
 * `buildSystemPromptAugmentationMessage(userLocation)` followed by every input
 * message converted with `buildConversationMessage`.
 *
 * @param messages - Stored chat messages, in conversation order.
 * @param userLocation - Optional location forwarded to the augmentation message builder.
 * @returns The ordered list of request messages.
 */
function normalizePlainMessages(messages: ChatMessage[], userLocation?: string) {
  const augmentationMessage = buildSystemPromptAugmentationMessage(userLocation);
  const history = messages.map((entry) => buildConversationMessage(entry));
  return [augmentationMessage, ...history];
}
|
||
|
|
|
||
|
|
/**
 * Pulls the textual content out of a response message.
 *
 * String content is returned unchanged. Array content is flattened by
 * concatenating, in order, each element that is a string, or otherwise its
 * string `text` property, or otherwise its string `content` property; elements
 * matching none of these contribute nothing. Any other shape (including a
 * missing message) yields the empty string.
 *
 * @param message - A response message object, possibly null or undefined.
 * @returns The extracted text.
 */
function extractContent(message: any) {
  const content = message?.content;
  if (typeof content === "string") {
    return content;
  }
  if (!Array.isArray(content)) {
    return "";
  }

  let combined = "";
  for (const piece of content) {
    if (typeof piece === "string") {
      combined += piece;
    } else if (typeof piece?.text === "string") {
      combined += piece.text;
    } else if (typeof piece?.content === "string") {
      combined += piece.content;
    }
  }
  return combined;
}
|
||
|
|
|
||
|
|
/**
 * Runs a non-streaming chat completion, executing any tools the model requests
 * until it produces a final answer.
 *
 * With no enabled tools a single request is sent and its text is returned.
 * With tools enabled, the conversation is sent for up to `MAX_TOOL_ROUNDS`
 * rounds. In each round the model's tool calls are executed one after another
 * and their JSON-serialised results are appended as `tool` messages before the
 * next round. A round without tool calls ends the loop, unless the text looks
 * like a dangling tool intent and the retry budget
 * (`MAX_DANGLING_TOOL_INTENT_RETRIES`) is not exhausted, in which case a
 * correction is appended and the model is asked again.
 *
 * Assistant text is always read through `extractContent`, so array-form
 * content is handled the same way whether or not tools are enabled.
 *
 * @param params - Client, model, messages, sampling options and tool selection.
 * @returns The final text, the accumulated usage (`undefined` when
 *   `mergeUsage` never reported usage), the raw responses and the tool events
 *   recorded along the way. If every round ends in tool calls, a fixed
 *   "tool-call limit" message is returned with `raw.toolCallLimitReached`.
 */
export async function completeWithChatCompletionsApi(params: ToolAwareCompletionParams): Promise<ToolAwareCompletionResult> {
  const enabledTools = getEnabledChatTools(params);
  if (!enabledTools.length) {
    const completion = await params.client.chat.completions.create({
      model: params.model,
      messages: normalizePlainMessages(params.messages, params.userLocation),
      temperature: params.temperature,
      max_tokens: params.maxTokens,
    } as any);

    const usageAcc: Required<NonNullable<ToolAwareCompletionResult["usage"]>> = { inputTokens: 0, outputTokens: 0, totalTokens: 0 };
    const sawUsage = mergeUsage(usageAcc, completion?.usage);
    const message = completion?.choices?.[0]?.message;

    return {
      text: extractContent(message),
      usage: sawUsage ? usageAcc : undefined,
      raw: { response: completion, api: "chat.completions" },
      toolEvents: [],
    };
  }

  const conversation: any[] = normalizeMessages(params.messages, params.userLocation, params);
  const rawResponses: unknown[] = [];
  const toolEvents: ToolExecutionEvent[] = [];
  const usageAcc: Required<NonNullable<ToolAwareCompletionResult["usage"]>> = { inputTokens: 0, outputTokens: 0, totalTokens: 0 };
  let sawUsage = false;
  let totalToolCalls = 0;
  let danglingToolIntentRetries = 0;

  for (let round = 0; round < MAX_TOOL_ROUNDS; round += 1) {
    const completion = await params.client.chat.completions.create({
      model: params.model,
      messages: conversation,
      temperature: params.temperature,
      max_tokens: params.maxTokens,
      tools: enabledTools,
      tool_choice: "auto",
    } as any);
    rawResponses.push(completion);
    // mergeUsage is evaluated first so it runs every round even once sawUsage is already true.
    sawUsage = mergeUsage(usageAcc, completion?.usage) || sawUsage;

    const message = completion?.choices?.[0]?.message;
    if (!message) {
      return {
        text: "",
        usage: sawUsage ? usageAcc : undefined,
        raw: { responses: rawResponses, toolCallsUsed: totalToolCalls, missingMessage: true },
        toolEvents,
      };
    }

    // Same extraction as the tool-less path, so array-form content is not silently dropped.
    const text = extractContent(message);

    const toolCalls = Array.isArray(message.tool_calls) ? message.tool_calls : [];
    if (!toolCalls.length) {
      if (danglingToolIntentRetries < MAX_DANGLING_TOOL_INTENT_RETRIES && looksLikeDanglingToolIntent(text)) {
        // The retry consumes one of the MAX_TOOL_ROUNDS iterations.
        danglingToolIntentRetries += 1;
        appendDanglingToolIntentCorrection(conversation, text);
        continue;
      }
      return {
        text,
        usage: sawUsage ? usageAcc : undefined,
        raw: { responses: rawResponses, toolCallsUsed: totalToolCalls },
        toolEvents,
      };
    }

    const normalizedToolCalls = normalizeModelToolCalls(toolCalls, round);
    totalToolCalls += normalizedToolCalls.length;

    // Echo the assistant turn back with the normalized calls so each tool result can reference call.id.
    const assistantToolCallMessage: any = {
      role: "assistant",
      tool_calls: normalizedToolCalls.map((call) => ({
        id: call.id,
        type: "function",
        function: {
          name: call.name,
          arguments: call.arguments,
        },
      })),
    };
    if (text.length) {
      assistantToolCallMessage.content = text;
    }
    conversation.push(assistantToolCallMessage);

    // Tools run sequentially, in the order the calls were normalized.
    for (const call of normalizedToolCalls) {
      const { execution } = prepareToolCallExecution(call);
      const { event, toolResult } = await executeToolCallAndBuildEvent(call, execution, params);
      toolEvents.push(event);

      conversation.push({
        role: "tool",
        tool_call_id: call.id,
        content: JSON.stringify(toolResult),
      });
    }
  }

  return {
    text: "I reached the tool-call limit while gathering information. Please narrow the request and try again.",
    usage: sawUsage ? usageAcc : undefined,
    raw: { responses: rawResponses, toolCallsUsed: totalToolCalls, toolCallLimitReached: true },
    toolEvents,
  };
}
|
||
|
|
|
||
|
|
/**
 * Streaming counterpart of the chat-completions call: yields text deltas and
 * tool-call events as they happen and finishes with a single `done` event
 * carrying the aggregated result.
 *
 * With no enabled tools one streamed request is made and every non-empty
 * content delta is yielded. With tools enabled, up to `MAX_TOOL_ROUNDS`
 * streamed requests are made. Within a round, content deltas are yielded until
 * the first tool-call delta arrives; later text in that round is only
 * buffered. Tool-call fragments are assembled per `index`, executed one after
 * another (yielding a `tool_call` event before and after each execution), and
 * their JSON-serialised results are appended as `tool` messages for the next
 * round.
 *
 * Both streaming requests send `stream_options.include_usage`, so usage can be
 * reported for tool-less requests as well.
 *
 * @param params - Client, model, messages, sampling options and tool selection.
 * @returns An async generator of `delta`, `tool_call` and a final `done`
 *   event. `done.result.usage` is `undefined` when `mergeUsage` never reported
 *   usage. If every round ends in tool calls, the result text is a fixed
 *   "tool-call limit" message and `raw.toolCallLimitReached` is set.
 */
export async function* streamWithChatCompletionsApi(params: ToolAwareCompletionParams): AsyncGenerator<ToolAwareStreamingEvent> {
  const enabledTools = getEnabledChatTools(params);
  if (!enabledTools.length) {
    const rawResponses: unknown[] = [];
    const usageAcc: Required<NonNullable<ToolAwareCompletionResult["usage"]>> = { inputTokens: 0, outputTokens: 0, totalTokens: 0 };
    let sawUsage = false;
    let text = "";

    const stream = await params.client.chat.completions.create({
      model: params.model,
      messages: normalizePlainMessages(params.messages, params.userLocation),
      temperature: params.temperature,
      max_tokens: params.maxTokens,
      stream: true,
      // Request the usage chunk, matching the tool-enabled request below;
      // without it the usage accumulator here never receives any data.
      stream_options: { include_usage: true },
    } as any);

    for await (const chunk of stream as any as AsyncIterable<any>) {
      rawResponses.push(chunk);
      sawUsage = mergeUsage(usageAcc, chunk?.usage) || sawUsage;

      // Chunks without choices or content fall through as "" and are skipped.
      const deltaText = chunk?.choices?.[0]?.delta?.content ?? "";
      if (typeof deltaText === "string" && deltaText.length) {
        text += deltaText;
        yield { type: "delta", text: deltaText };
      }
    }

    yield {
      type: "done",
      result: {
        text,
        usage: sawUsage ? usageAcc : undefined,
        raw: { streamed: true, responses: rawResponses, api: "chat.completions" },
        toolEvents: [],
      },
    };
    return;
  }

  const conversation: any[] = normalizeMessages(params.messages, params.userLocation, params);
  const rawResponses: unknown[] = [];
  const toolEvents: ToolExecutionEvent[] = [];
  const usageAcc: Required<NonNullable<ToolAwareCompletionResult["usage"]>> = { inputTokens: 0, outputTokens: 0, totalTokens: 0 };
  let sawUsage = false;
  let totalToolCalls = 0;
  let danglingToolIntentRetries = 0;

  for (let round = 0; round < MAX_TOOL_ROUNDS; round += 1) {
    const stream = await params.client.chat.completions.create({
      model: params.model,
      messages: conversation,
      temperature: params.temperature,
      max_tokens: params.maxTokens,
      tools: enabledTools,
      tool_choice: "auto",
      stream: true,
      stream_options: { include_usage: true },
    } as any);

    // roundText holds everything the model said this round; streamedRoundText
    // holds only the portion already yielded to the consumer.
    let roundText = "";
    let streamedRoundText = "";
    let roundHasToolCalls = false;
    const roundToolCalls = new Map<number, { id?: string; name?: string; arguments: string }>();

    for await (const chunk of stream as any as AsyncIterable<any>) {
      rawResponses.push(chunk);
      sawUsage = mergeUsage(usageAcc, chunk?.usage) || sawUsage;

      const choice = chunk?.choices?.[0];
      const deltaText = choice?.delta?.content ?? "";
      if (typeof deltaText === "string" && deltaText.length) {
        roundText += deltaText;
        // Once a tool call has started, further text is buffered but not streamed.
        if (!roundHasToolCalls) {
          streamedRoundText += deltaText;
          yield { type: "delta", text: deltaText };
        }
      }

      const deltaToolCalls = Array.isArray(choice?.delta?.tool_calls) ? choice.delta.tool_calls : [];
      if (deltaToolCalls.length) {
        roundHasToolCalls = true;
      }
      for (const toolCall of deltaToolCalls) {
        // Fragments of one call share an index; a missing index is treated as 0.
        const idx = typeof toolCall?.index === "number" ? toolCall.index : 0;
        const entry = roundToolCalls.get(idx) ?? { arguments: "" };
        if (typeof toolCall?.id === "string" && toolCall.id.length) {
          entry.id = toolCall.id;
        }
        if (typeof toolCall?.function?.name === "string" && toolCall.function.name.length) {
          entry.name = toolCall.function.name;
        }
        // Argument JSON arrives in pieces and is concatenated in arrival order.
        if (typeof toolCall?.function?.arguments === "string" && toolCall.function.arguments.length) {
          entry.arguments += toolCall.function.arguments;
        }
        roundToolCalls.set(idx, entry);
      }
    }

    // Order calls by stream index and fill in placeholders for any missing field.
    const normalizedToolCalls: NormalizedToolCall[] = [...roundToolCalls.entries()]
      .sort((a, b) => a[0] - b[0])
      .map(([, call], index) => ({
        id: call.id ?? `tool_call_${round}_${index}`,
        name: call.name ?? "unknown_tool",
        arguments: call.arguments || "{}",
      }));

    if (!normalizedToolCalls.length) {
      // Retry only when nothing from this round has been yielded yet.
      if (!streamedRoundText && danglingToolIntentRetries < MAX_DANGLING_TOOL_INTENT_RETRIES && looksLikeDanglingToolIntent(roundText)) {
        danglingToolIntentRetries += 1;
        appendDanglingToolIntentCorrection(conversation, roundText);
        continue;
      }
      const unstreamedText = getUnstreamedText(roundText, streamedRoundText);
      if (unstreamedText) {
        yield { type: "delta", text: unstreamedText };
      }
      yield {
        type: "done",
        result: {
          text: roundText,
          usage: sawUsage ? usageAcc : undefined,
          raw: { streamed: true, responses: rawResponses, toolCallsUsed: totalToolCalls },
          toolEvents,
        },
      };
      return;
    }

    totalToolCalls += normalizedToolCalls.length;
    // Echo the assistant turn back so each tool result can reference call.id.
    const assistantToolCallMessage: any = {
      role: "assistant",
      tool_calls: normalizedToolCalls.map((call) => ({
        id: call.id,
        type: "function",
        function: {
          name: call.name,
          arguments: call.arguments,
        },
      })),
    };
    if (roundText) {
      assistantToolCallMessage.content = roundText;
    }
    conversation.push(assistantToolCallMessage);

    // Tools run sequentially; only the post-execution event is recorded in toolEvents.
    for (const call of normalizedToolCalls) {
      const { event: initiatedEvent, execution } = prepareToolCallExecution(call);
      yield { type: "tool_call", event: initiatedEvent };
      const { event, toolResult } = await executeToolCallAndBuildEvent(call, execution, params);
      toolEvents.push(event);
      yield { type: "tool_call", event };
      conversation.push({
        role: "tool",
        tool_call_id: call.id,
        content: JSON.stringify(toolResult),
      });
    }
  }

  yield {
    type: "done",
    result: {
      text: "I reached the tool-call limit while gathering information. Please narrow the request and try again.",
      usage: sawUsage ? usageAcc : undefined,
      raw: { streamed: true, responses: rawResponses, toolCallsUsed: totalToolCalls, toolCallLimitReached: true },
      toolEvents,
    },
  };
}
|