diff --git a/apps/client/src/services/llm_chat.ts b/apps/client/src/services/llm_chat.ts index fa0a0279d38..fdfa8170bb0 100644 --- a/apps/client/src/services/llm_chat.ts +++ b/apps/client/src/services/llm_chat.ts @@ -13,7 +13,7 @@ export async function getAvailableModels(): Promise { export interface StreamCallbacks { onChunk: (text: string) => void; onThinking?: (text: string) => void; - onToolUse?: (toolName: string, input: Record) => void; + onToolUse?: (toolName: string, input: Record, requiresApproval?: boolean) => void; onToolResult?: (toolName: string, result: string, isError?: boolean) => void; onCitation?: (citation: LlmCitation) => void; onUsage?: (usage: LlmUsage) => void; @@ -76,7 +76,7 @@ export async function streamChatCompletion( callbacks.onThinking?.(data.content); break; case "tool_use": - callbacks.onToolUse?.(data.toolName, data.toolInput); + callbacks.onToolUse?.(data.toolName, data.toolInput, data.requiresApproval); // Yield to force Preact to commit the pending tool call // state before we process the result. await new Promise((r) => setTimeout(r, 1)); @@ -112,3 +112,18 @@ export async function streamChatCompletion( reader.releaseLock(); } } + +/** + * Execute a mutating tool call after user approval. + */ +export async function executeToolCall(toolName: string, toolInput: Record): Promise<{ result: string; isError?: boolean }> { + const response = await server.post<{ result?: object; error?: string }>("llm-chat/execute-tool", { toolName, toolInput }); + + if (response.error) { + return { result: response.error, isError: true }; + } + + return { + result: typeof response.result === "string" ? response.result : JSON.stringify(response.result) + }; +} diff --git a/apps/client/src/translations/en/translation.json b/apps/client/src/translations/en/translation.json index cd0523abc86..ed31839a5bf 100644 --- a/apps/client/src/translations/en/translation.json +++ b/apps/client/src/translations/en/translation.json @@ -1664,7 +1664,11 @@ "note_context_enabled": "Click to disable note context: {{title}}", "note_context_disabled": "Click to include current note in context", "no_provider_message": "No AI provider configured. Add one to start chatting.", - "add_provider": "Add AI Provider" + "add_provider": "Add AI Provider", + "approve": "Approve", + "reject": "Reject", + "pending_approval": "Pending approval", + "rejected_by_user": "Rejected by user" }, "sidebar_chat": { "title": "AI Chat", diff --git a/apps/client/src/widgets/type_widgets/llm_chat/ChatMessage.tsx b/apps/client/src/widgets/type_widgets/llm_chat/ChatMessage.tsx index 26a7f36c464..e0d0b6d9523 100644 --- a/apps/client/src/widgets/type_widgets/llm_chat/ChatMessage.tsx +++ b/apps/client/src/widgets/type_widgets/llm_chat/ChatMessage.tsx @@ -62,6 +62,8 @@ function MarkdownContent({ html, isStreaming }: { html: string; isStreaming?: bo interface Props { message: StoredMessage; isStreaming?: boolean; + onApproveToolCall?: (toolCallId: string) => Promise; + onRejectToolCall?: (toolCallId: string) => void; } type ContentGroup = @@ -127,7 +129,7 @@ function CitationsSection({ citations }: { citations: LlmCitation[] }) { ); } -export default function ChatMessage({ message, isStreaming }: Props) { +export default function ChatMessage({ message, isStreaming, onApproveToolCall, onRejectToolCall }: Props) { const isError = message.type === "error"; const isThinking = message.type === "thinking"; const textContent = typeof message.content === "string" ? message.content : getMessageText(message.content); @@ -172,7 +174,7 @@ export default function ChatMessage({ message, isStreaming }: Props) {
{message.role === "assistant" && !isError ? ( hasBlockContent ? ( - renderContentBlocks(message.content as ContentBlock[], isStreaming) + renderContentBlocks(message.content as ContentBlock[], isStreaming, onApproveToolCall, onRejectToolCall) ) : ( ) @@ -244,7 +246,12 @@ function groupContentBlocks(blocks: ContentBlock[]): ContentGroup[] { return groups; } -function renderContentBlocks(blocks: ContentBlock[], isStreaming?: boolean) { +function renderContentBlocks( + blocks: ContentBlock[], + isStreaming?: boolean, + onApproveToolCall?: (toolCallId: string) => Promise, + onRejectToolCall?: (toolCallId: string) => void +) { return groupContentBlocks(blocks).map((group) => { if (group.type === "text") { const html = renderMarkdown(group.block.content); @@ -256,6 +263,13 @@ function renderContentBlocks(blocks: ContentBlock[], isStreaming?: boolean) { ); } - return b.toolCall)} />; + return ( + b.toolCall)} + onApprove={onApproveToolCall} + onReject={onRejectToolCall} + /> + ); }); } diff --git a/apps/client/src/widgets/type_widgets/llm_chat/ExpandableCard.tsx b/apps/client/src/widgets/type_widgets/llm_chat/ExpandableCard.tsx index 2e12c08e644..d9e5476cc2a 100644 --- a/apps/client/src/widgets/type_widgets/llm_chat/ExpandableCard.tsx +++ b/apps/client/src/widgets/type_widgets/llm_chat/ExpandableCard.tsx @@ -7,12 +7,14 @@ interface ExpandableSectionProps { label: ComponentChildren; className?: string; children: ComponentChildren; + /** Whether the section should be expanded by default */ + defaultExpanded?: boolean; } /** A collapsible section within an ExpandableCard. */ -export function ExpandableSection({ icon, label, className, children }: ExpandableSectionProps) { +export function ExpandableSection({ icon, label, className, children, defaultExpanded }: ExpandableSectionProps) { return ( -
+
{label} diff --git a/apps/client/src/widgets/type_widgets/llm_chat/LlmChat.tsx b/apps/client/src/widgets/type_widgets/llm_chat/LlmChat.tsx index 301141f0d88..3f210714524 100644 --- a/apps/client/src/widgets/type_widgets/llm_chat/LlmChat.tsx +++ b/apps/client/src/widgets/type_widgets/llm_chat/LlmChat.tsx @@ -63,7 +63,12 @@ export default function LlmChat({ note, ntxId, noteContext }: TypeWidgetProps) { /> )} {chat.messages.map(msg => ( - + ))} {chat.isStreaming && chat.streamingThinking && ( Promise; + onReject?: (toolCallId: string) => void; +}) { const hasError = toolCall.isError; + const isPendingApproval = toolCall.requiresApproval && !toolCall.result && !toolCall.rejected; return ( } className={hasError ? "llm-chat-tool-call-error" : ""} + defaultExpanded={isPendingApproval} >
{t("llm_chat.input")}
+ {isPendingApproval && onApprove && onReject && ( +
+ {t("llm_chat.pending_approval")} +
+ + +
+
+ )} {toolCall.result && (
{hasError ? t("llm_chat.error") : t("llm_chat.result")} @@ -202,11 +228,15 @@ function ToolCallSection({ toolCall }: { toolCall: ToolCall }) { } /** A card that groups one or more sequential tool calls together. */ -export default function ToolCallCard({ toolCalls }: { toolCalls: ToolCall[] }) { +export default function ToolCallCard({ toolCalls, onApprove, onReject }: { + toolCalls: ToolCall[]; + onApprove?: (toolCallId: string) => Promise; + onReject?: (toolCallId: string) => void; +}) { return ( {toolCalls.map((tc, idx) => ( - + ))} ); diff --git a/apps/client/src/widgets/type_widgets/llm_chat/llm_chat_types.ts b/apps/client/src/widgets/type_widgets/llm_chat/llm_chat_types.ts index d05bf291b5b..fec7f6e277c 100644 --- a/apps/client/src/widgets/type_widgets/llm_chat/llm_chat_types.ts +++ b/apps/client/src/widgets/type_widgets/llm_chat/llm_chat_types.ts @@ -8,6 +8,8 @@ export interface ToolCall { input: Record; result?: string; isError?: boolean; + requiresApproval?: boolean; + rejected?: boolean; } /** A block of text content (rendered as Markdown for assistant messages). */ diff --git a/apps/client/src/widgets/type_widgets/llm_chat/useLlmChat.ts b/apps/client/src/widgets/type_widgets/llm_chat/useLlmChat.ts index 63cbf4bbf42..e1f7e0ce593 100644 --- a/apps/client/src/widgets/type_widgets/llm_chat/useLlmChat.ts +++ b/apps/client/src/widgets/type_widgets/llm_chat/useLlmChat.ts @@ -2,7 +2,8 @@ import type { LlmCitation, LlmMessage, LlmModelInfo, LlmUsage } from "@triliumne import { RefObject } from "preact"; import { useCallback, useEffect, useRef, useState } from "preact/hooks"; -import { getAvailableModels, streamChatCompletion } from "../../../services/llm_chat.js"; +import { executeToolCall, getAvailableModels, streamChatCompletion } from "../../../services/llm_chat.js"; +import { t } from "../../../services/i18n.js"; import { randomString } from "../../../services/utils.js"; import type { ContentBlock, LlmChatContent, StoredMessage } from "./llm_chat_types.js"; @@ -62,6 +63,10 @@ export interface UseLlmChatReturn { clearMessages: () => void; /** Refresh the provider/models list */ refreshModels: () => void; + /** Approve a pending mutating tool call */ + approveToolCall: (toolCallId: string) => Promise; + /** Reject a pending mutating tool call */ + rejectToolCall: (toolCallId: string) => void; } export function useLlmChat( @@ -267,13 +272,14 @@ export function useLlmChat( thinkingContent += text; setStreamingThinking(thinkingContent); }, - onToolUse: (toolName, toolInput) => { + onToolUse: (toolName, toolInput, requiresApproval) => { contentBlocks.push({ type: "tool_call", toolCall: { id: randomString(), toolName, - input: toolInput + input: toolInput, + requiresApproval } }); setStreamingBlocks([...contentBlocks]); @@ -365,6 +371,44 @@ export function useLlmChat( } }, [handleSubmit]); + /** Approve a pending mutating tool call — execute it server-side and update the message. */ + const approveToolCall = useCallback(async (toolCallId: string) => { + // Find the tool call in messages + for (const msg of messages) { + if (!Array.isArray(msg.content)) continue; + for (const block of msg.content) { + if (block.type === "tool_call" && block.toolCall.id === toolCallId && block.toolCall.requiresApproval && !block.toolCall.result) { + const { result, isError } = await executeToolCall(block.toolCall.toolName, block.toolCall.input); + // Update the tool call block immutably + const updatedContent = msg.content.map(b => + b.type === "tool_call" && b.toolCall.id === toolCallId + ? { ...b, toolCall: { ...b.toolCall, result, isError } } + : b + ); + const updatedMessages = messages.map(m => + m.id === msg.id ? { ...m, content: updatedContent } : m + ); + setMessages(updatedMessages); + return; + } + } + } + }, [messages, setMessages]); + + /** Reject a pending mutating tool call. */ + const rejectToolCall = useCallback((toolCallId: string) => { + const updatedMessages = messages.map(msg => { + if (!Array.isArray(msg.content)) return msg; + const updatedContent = msg.content.map(b => + b.type === "tool_call" && b.toolCall.id === toolCallId + ? { ...b, toolCall: { ...b.toolCall, rejected: true, result: t("llm_chat.rejected_by_user"), isError: true } } + : b + ); + return { ...msg, content: updatedContent }; + }); + setMessages(updatedMessages); + }, [messages, setMessages]); + return { // State messages, @@ -402,6 +446,8 @@ export function useLlmChat( loadFromContent, getContent, clearMessages, - refreshModels + refreshModels, + approveToolCall, + rejectToolCall }; } diff --git a/apps/server/src/routes/api/llm_chat.ts b/apps/server/src/routes/api/llm_chat.ts index dd5bf149c8b..ca3d3c7e51a 100644 --- a/apps/server/src/routes/api/llm_chat.ts +++ b/apps/server/src/routes/api/llm_chat.ts @@ -4,7 +4,9 @@ import type { Request, Response } from "express"; import { generateChatTitle } from "../../services/llm/chat_title.js"; import { getAllModels, getProviderByType, hasConfiguredProviders, type LlmProviderConfig } from "../../services/llm/index.js"; import { streamToChunks } from "../../services/llm/stream.js"; +import { allToolRegistries } from "../../services/llm/tools/index.js"; import log from "../../services/log.js"; +import sql from "../../services/sql.js"; import { safeExtractMessageAndStackFromError } from "../../services/utils.js"; interface ChatRequest { @@ -51,6 +53,15 @@ async function streamChat(req: Request, res: Response) { } const provider = getProviderByType(config.provider || "anthropic"); + + // Collect names of tools that require human approval + const mutatingToolNames = new Set(); + for (const registry of allToolRegistries) { + for (const name of registry.getMutatingToolNames()) { + mutatingToolNames.add(name); + } + } + const result = provider.chat(messages, config); // Get pricing and display name for the model @@ -62,7 +73,7 @@ async function streamChat(req: Request, res: Response) { const pricing = provider.getModelPricing(modelId); const modelDisplayName = provider.getAvailableModels().find(m => m.id === modelId)?.name || modelId; - for await (const chunk of streamToChunks(result, { model: modelDisplayName, pricing })) { + for await (const chunk of streamToChunks(result, { model: modelDisplayName, pricing, mutatingToolNames })) { res.write(`data: ${JSON.stringify(chunk)}\n\n`); // Flush immediately to ensure real-time streaming if (typeof flushableRes.flush === "function") { @@ -98,7 +109,36 @@ function getModels(_req: Request, _res: Response) { return { models: getAllModels() }; } +/** + * Execute a single tool call after user approval. + * Used for mutating tools that require human-in-the-loop confirmation. + */ +function executeTool(req: Request, _res: Response) { + const { toolName, toolInput } = req.body as { toolName: string; toolInput: Record }; + + if (!toolName || typeof toolName !== "string") { + return { error: "toolName is required" }; + } + + // Find the tool definition across all registries + for (const registry of allToolRegistries) { + for (const [name, def] of registry) { + if (name === toolName) { + if (!def.mutates) { + return { error: "Only mutating tools can be executed via this endpoint" }; + } + + const result = sql.transactional(() => def.execute(toolInput)); + return { result }; + } + } + } + + return { error: `Tool '${toolName}' not found` }; +} + export default { streamChat, - getModels + getModels, + executeTool }; diff --git a/apps/server/src/routes/routes.ts b/apps/server/src/routes/routes.ts index 198fa5c22e4..58429e5a9ac 100644 --- a/apps/server/src/routes/routes.ts +++ b/apps/server/src/routes/routes.ts @@ -332,6 +332,7 @@ function register(app: express.Application) { // LLM chat endpoints asyncRoute(PST, "/api/llm-chat/stream", [auth.checkApiAuthOrElectron, csrfMiddleware], llmChatRoute.streamChat, null); apiRoute(GET, "/api/llm-chat/models", llmChatRoute.getModels); + apiRoute(PST, "/api/llm-chat/execute-tool", llmChatRoute.executeTool); // no CSRF since this is called from android app route(PST, "/api/sender/login", [loginRateLimiter], loginApiRoute.token, apiResultHandler); diff --git a/apps/server/src/services/llm/stream.ts b/apps/server/src/services/llm/stream.ts index e66da16a5b4..50d35166ec7 100644 --- a/apps/server/src/services/llm/stream.ts +++ b/apps/server/src/services/llm/stream.ts @@ -23,6 +23,8 @@ export interface StreamOptions { model?: string; /** Model pricing for cost calculation (from provider) */ pricing?: ModelPricing; + /** Names of mutating tools that require user approval before execution */ + mutatingToolNames?: Set; } /** @@ -45,7 +47,8 @@ export async function* streamToChunks(result: StreamResult, options: StreamOptio yield { type: "tool_use", toolName: part.toolName, - toolInput: part.input as Record + toolInput: part.input as Record, + ...(options.mutatingToolNames?.has(part.toolName) ? { requiresApproval: true } : {}) }; break; diff --git a/apps/server/src/services/llm/tools/tool_registry.ts b/apps/server/src/services/llm/tools/tool_registry.ts index 35fc225182c..e0a399d00bf 100644 --- a/apps/server/src/services/llm/tools/tool_registry.ts +++ b/apps/server/src/services/llm/tools/tool_registry.ts @@ -12,8 +12,6 @@ import { tool } from "ai"; import type { z } from "zod"; import type { ToolSet } from "ai"; -import sql from "../../sql.js"; - /** * Type constraint that rejects Promises at compile time. * Works by requiring `then` to be void if present - Promises have `then: Function`. @@ -56,24 +54,35 @@ export class ToolRegistry implements Iterable<[string, ToolDefinition]> { /** * Convert to an AI SDK ToolSet for use with the LLM chat providers. - * Mutating tools are wrapped in a transaction for consistency with MCP. + * Read-only tools are given an `execute` function so the AI SDK auto-runs them. + * Mutating tools are registered WITHOUT `execute` so the SDK emits a tool-call + * event but does NOT auto-execute — the client must approve first. * (CLS context is provided by the route handler.) */ toToolSet(): ToolSet { const set: ToolSet = {}; for (const [name, def] of this) { - const execute = def.mutates - ? (args: unknown) => sql.transactional(() => def.execute(args)) - : def.execute; - - set[name] = tool({ - description: def.description, - inputSchema: def.inputSchema, - execute - }); + if (def.mutates) { + // No execute → AI SDK emits tool-call but doesn't auto-execute (human-in-the-loop) + set[name] = tool({ + description: def.description, + inputSchema: def.inputSchema + }); + } else { + set[name] = tool({ + description: def.description, + inputSchema: def.inputSchema, + execute: def.execute + }); + } } return set; } + + /** Return the names of all mutating tools in this registry. */ + getMutatingToolNames(): string[] { + return [...this].filter(([, def]) => def.mutates).map(([name]) => name); + } } /**