diff --git a/backend/agents/create_agent_info.py b/backend/agents/create_agent_info.py index b8d1ae101..ad4cde7b7 100644 --- a/backend/agents/create_agent_info.py +++ b/backend/agents/create_agent_info.py @@ -310,6 +310,7 @@ async def create_agent_config( allow_memory_search: bool = True, version_no: int = 0, override_model_id: int | None = None, + conversation_id: int = None, ): agent_info = search_agent_info_by_agent_id( agent_id=agent_id, tenant_id=tenant_id, version_no=version_no) @@ -331,13 +332,14 @@ async def create_agent_config( allow_memory_search=allow_memory_search, version_no=sub_agent_version_no, override_model_id=None, + conversation_id=None, ) managed_agents.append(sub_agent_config) # create external A2A agents (synchronous function, no await needed) external_a2a_agents = _get_external_a2a_agents(agent_id, tenant_id, version_no) - tool_list = await create_tool_config_list(agent_id, tenant_id, user_id, version_no=version_no) + tool_list = await create_tool_config_list(agent_id, tenant_id, user_id, version_no=version_no, conversation_id=conversation_id) # Build system prompt: prioritize segmented fields, fallback to original prompt field if not available duty_prompt = agent_info.get("duty_prompt", "") @@ -562,7 +564,7 @@ async def create_agent_config( return agent_config -async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int = 0): +async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int = 0, conversation_id: int = None): # create tool tool_config_list = [] langchain_tools = await discover_langchain_tools() @@ -665,6 +667,17 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int "storage_client": minio_client, "validate_url_access": lambda urls: validate_urls_access(urls, user_id) } + elif tool_config.class_name == "ScheduledTaskTool": + from database.scheduled_task_db import create_scheduled_task, query_tasks_by_agent, cancel_task + tool_config.metadata = { + "db_create": create_scheduled_task, + "db_list": query_tasks_by_agent, + "db_cancel": cancel_task, + "agent_id": agent_id, + "tenant_id": tenant_id, + "user_id": user_id, + "conversation_id": conversation_id, + } tool_config_list.append(tool_config) @@ -929,6 +942,7 @@ async def create_agent_run_info( is_debug: bool = False, override_version_no: int | None = None, override_model_id: int | None = None, + conversation_id: int = None, ): # Determine which version_no to use based on is_debug flag # If is_debug=false, use the current published version (current_version_no) @@ -957,6 +971,7 @@ async def create_agent_run_info( "last_user_query": final_query, "allow_memory_search": allow_memory_search, "version_no": version_no, + "conversation_id": conversation_id, } if override_model_id is not None: create_config_kwargs["override_model_id"] = override_model_id diff --git a/backend/apps/conversation_management_app.py b/backend/apps/conversation_management_app.py index 9beeedf2e..b53b4d76a 100644 --- a/backend/apps/conversation_management_app.py +++ b/backend/apps/conversation_management_app.py @@ -3,6 +3,7 @@ from typing import Any, Dict, Optional from fastapi import APIRouter, Header, HTTPException, Request +from starlette.responses import JSONResponse from consts.model import ( ConversationRequest, @@ -18,6 +19,7 @@ generate_conversation_title_service, get_conversation_history_service, get_conversation_list_service, + get_new_messages_service, get_sources_service, rename_conversation_service, update_message_opinion_service, get_message_id_by_index_impl, @@ -240,3 +242,52 @@ async def get_message_id_endpoint(request: MessageIdRequest): except Exception as e: logging.error(f"Failed to get message ID: {str(e)}") raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=str(e)) + + +@router.get("/{conversation_id}/new_messages", response_model=Dict[str, Any]) +async def check_new_messages_endpoint(conversation_id: int, since_index: int = 0, authorization: Optional[str] = Header(None)): + """ + Lightweight polling: check if new messages exist for a single conversation. + + Args: + conversation_id: Conversation ID + since_index: Last known message index on the client side + authorization: Authorization header + + Returns: + Dict with has_new, max_index, since_index + """ + try: + user_id, tenant_id = get_current_user_id(authorization) + result = get_new_messages_service(conversation_id, user_id, since_index) + return JSONResponse(status_code=HTTPStatus.OK, content=result) + except Exception as e: + logging.error(f"Failed to check new messages: {str(e)}") + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=str(e)) + + +@router.post("/batch_new_messages", response_model=Dict[str, Any]) +async def batch_check_new_messages_endpoint(request: Dict[str, Any], authorization: Optional[str] = Header(None)): + """ + Batch check for new messages across multiple conversations. + + Args: + request: Body with "checks" list of {"conversation_id": int, "since_index": int} + authorization: Authorization header + + Returns: + Dict mapping conversation_id to {has_new, max_index, since_index} + """ + try: + user_id, tenant_id = get_current_user_id(authorization) + checks = request.get("checks", []) + results = {} + for check in checks: + cid = check.get("conversation_id") + since = check.get("since_index", 0) + if cid is not None: + results[str(cid)] = get_new_messages_service(cid, user_id, since) + return JSONResponse(status_code=HTTPStatus.OK, content={"results": results}) + except Exception as e: + logging.error(f"Failed to batch check new messages: {str(e)}") + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=str(e)) diff --git a/backend/apps/runtime_app.py b/backend/apps/runtime_app.py index ba856b3ce..2e00825c7 100644 --- a/backend/apps/runtime_app.py +++ b/backend/apps/runtime_app.py @@ -8,6 +8,7 @@ from apps.file_management_app import file_management_runtime_router as file_management_router from apps.skill_app import skill_creator_router from middleware.exception_handler import ExceptionHandlerMiddleware +from services.scheduled_task_scheduler import scheduled_task_scheduler # Create logger instance logger = logging.getLogger("runtime_app") @@ -24,3 +25,13 @@ app.include_router(file_management_router) app.include_router(voice_router) app.include_router(skill_creator_router) + + +@app.on_event("startup") +async def start_scheduled_task_scheduler(): + scheduled_task_scheduler.start() + + +@app.on_event("shutdown") +async def stop_scheduled_task_scheduler(): + scheduled_task_scheduler.stop() diff --git a/backend/database/conversation_db.py b/backend/database/conversation_db.py index 2d06bb9be..dc1d22505 100644 --- a/backend/database/conversation_db.py +++ b/backend/database/conversation_db.py @@ -956,6 +956,17 @@ def get_source_searches_by_conversation(conversation_id: int, user_id: Optional[ return [as_dict(record) for record in search_records] +def get_max_message_index(conversation_id: int) -> int: + """Return the maximum message_index for a conversation, or -1 if empty.""" + with get_db_session() as session: + conversation_id = int(conversation_id) + stmt = select(func.coalesce(func.max(ConversationMessage.message_index), -1)).where( + ConversationMessage.conversation_id == conversation_id, + ConversationMessage.delete_flag == 'N', + ) + return session.execute(stmt).scalar() + + def get_message(message_id: int, user_id: Optional[str] = None) -> Dict[str, Any]: """ Get message details by message ID diff --git a/backend/database/db_models.py b/backend/database/db_models.py index 8a20e9003..8d1ec5c3b 100644 --- a/backend/database/db_models.py +++ b/backend/database/db_models.py @@ -1164,6 +1164,35 @@ class A2AMessage(SimpleTableBase): timezone=False), server_default=func.now(), doc="Message creation timestamp") +class ScheduledTaskRecord(TableBase): + """ + Scheduled task records for deferred / recurring agent execution. + """ + __tablename__ = "scheduled_tasks_t" + __table_args__ = ( + Index("ix_scheduled_task_status_next_fire", "status", "next_fire_time"), + Index("ix_scheduled_task_agent_delete", "agent_id", "delete_flag"), + {"schema": SCHEMA}, + ) + + task_id = Column(Integer, Sequence("scheduled_tasks_t_task_id_seq", schema=SCHEMA), + primary_key=True, nullable=False, autoincrement=True, doc="Primary key") + task_uuid = Column(String(36), unique=True, nullable=False, doc="Unique task identifier (UUID)") + task_name = Column(String(200), doc="Human-readable task name") + task_prompt = Column(Text, nullable=False, doc="The prompt to execute when the task fires") + task_type = Column(String(10), nullable=False, doc="Task type: oneshot or cron") + cron_expression = Column(String(100), doc="Cron expression for recurring tasks") + delay_seconds = Column(Integer, doc="Delay in seconds for oneshot tasks") + status = Column(String(20), default="pending", doc="Task status: pending, fired, cancelled, error") + next_fire_time = Column(TIMESTAMP(timezone=False), doc="Next scheduled execution time") + fire_count = Column(Integer, default=0, doc="Number of times this task has fired") + max_fires = Column(Integer, nullable=True, doc="Maximum number of fires (NULL = unlimited)") + agent_id = Column(Integer, nullable=False, doc="Agent ID that owns this task") + conversation_id = Column(Integer, nullable=True, doc="Conversation ID associated with this task") + tenant_id = Column(String(100), nullable=False, doc="Tenant ID for multi-tenancy isolation") + user_id = Column(String(100), nullable=False, doc="User ID who created this task") + + class A2AArtifact(SimpleTableBase): """ A2A artifacts. Stores the output/artifacts produced by a task. diff --git a/backend/database/scheduled_task_db.py b/backend/database/scheduled_task_db.py new file mode 100644 index 000000000..4e903e968 --- /dev/null +++ b/backend/database/scheduled_task_db.py @@ -0,0 +1,82 @@ +import logging +from datetime import datetime +from typing import Optional + +from sqlalchemy import select, update + +from .client import as_dict, get_db_session +from .db_models import ScheduledTaskRecord + +logger = logging.getLogger("scheduled_task_db") + + +def create_scheduled_task(data: dict) -> dict: + """Insert a new scheduled task record and return it as a dict.""" + with get_db_session() as session: + record = ScheduledTaskRecord(**data) + session.add(record) + session.flush() + return as_dict(record) + + +def query_tasks_by_agent(agent_id: int, tenant_id: str, user_id: str = None) -> list[dict]: + """Return pending tasks for a given agent and tenant, optionally filtered by user.""" + with get_db_session() as session: + stmt = select(ScheduledTaskRecord).where( + ScheduledTaskRecord.agent_id == agent_id, + ScheduledTaskRecord.tenant_id == tenant_id, + ScheduledTaskRecord.status == "pending", + ScheduledTaskRecord.delete_flag == "N", + ) + if user_id: + stmt = stmt.where(ScheduledTaskRecord.user_id == user_id) + stmt = stmt.order_by(ScheduledTaskRecord.task_id.desc()) + records = session.scalars(stmt).all() + return [as_dict(r) for r in records] + + +def query_pending_tasks_due(now: datetime) -> list[dict]: + """Return all pending tasks whose next_fire_time <= now (global, no tenant filter).""" + with get_db_session() as session: + stmt = select(ScheduledTaskRecord).where( + ScheduledTaskRecord.status == "pending", + ScheduledTaskRecord.next_fire_time <= now, + ScheduledTaskRecord.delete_flag == "N", + ) + records = session.scalars(stmt).all() + return [as_dict(r) for r in records] + + +def cancel_task(task_uuid: str, agent_id: int, tenant_id: str, user_id: str = None) -> bool: + """Soft-cancel a task. Optionally restrict to a specific user for isolation.""" + with get_db_session() as session: + conditions = [ + ScheduledTaskRecord.task_uuid == task_uuid, + ScheduledTaskRecord.agent_id == agent_id, + ScheduledTaskRecord.tenant_id == tenant_id, + ScheduledTaskRecord.delete_flag == "N", + ScheduledTaskRecord.status == "pending", + ] + if user_id: + conditions.append(ScheduledTaskRecord.user_id == user_id) + stmt = ( + update(ScheduledTaskRecord) + .where(*conditions) + .values(status="cancelled") + ) + result = session.execute(stmt) + return result.rowcount > 0 + + +def update_task_status(task_uuid: str, updates: dict) -> None: + """Update arbitrary columns on a task record identified by task_uuid.""" + with get_db_session() as session: + stmt = ( + update(ScheduledTaskRecord) + .where( + ScheduledTaskRecord.task_uuid == task_uuid, + ScheduledTaskRecord.delete_flag == "N", + ) + .values(**updates) + ) + session.execute(stmt) diff --git a/backend/services/conversation_management_service.py b/backend/services/conversation_management_service.py index 81c2bfa98..bb0406f50 100644 --- a/backend/services/conversation_management_service.py +++ b/backend/services/conversation_management_service.py @@ -384,6 +384,17 @@ def delete_conversation_service(conversation_id: int, user_id: str) -> bool: raise Exception(str(e)) +def get_new_messages_service(conversation_id: int, user_id: str, since_index: int) -> Dict[str, Any]: + """Lightweight polling: check for new messages after since_index.""" + from database.conversation_db import get_conversation, get_max_message_index + conv = get_conversation(conversation_id, user_id) + if conv is None: + return {"has_new": False, "max_index": since_index, "since_index": since_index} + max_idx = get_max_message_index(conversation_id) + has_new = max_idx > since_index + return {"has_new": has_new, "max_index": max_idx, "since_index": since_index} + + def get_conversation_history_service(conversation_id: int, user_id: str) -> List[Dict[str, Any]]: """ Get complete history of specified conversation diff --git a/backend/services/scheduled_task_scheduler.py b/backend/services/scheduled_task_scheduler.py new file mode 100644 index 000000000..384f1376a --- /dev/null +++ b/backend/services/scheduled_task_scheduler.py @@ -0,0 +1,205 @@ +"""Global scheduler singleton that periodically polls for due scheduled tasks and executes them.""" + +import asyncio +import logging +import threading +from datetime import datetime, timezone + +logger = logging.getLogger("scheduled_task_scheduler") + + +def _save_simple_message(conversation_id, msg_idx, role, content, user_id, tenant_id): + """Save a single text message to the conversation.""" + from services.conversation_management_service import save_message + from consts.model import MessageRequest, MessageUnit + save_message( + MessageRequest( + conversation_id=conversation_id, + message_idx=msg_idx, + role=role, + message=[MessageUnit(type="string", content=content)], + minio_files=None, + ), + user_id=user_id, + tenant_id=tenant_id, + ) + + +def _run_scheduled_task_from_db(task_dict: dict): + """Execute a due scheduled task synchronously in a background thread. + + The function: + 1. Gets the max message index to avoid collisions. + 2. Saves a user message containing the task prompt with scheduling context. + 3. Creates an AgentRunInfo via create_agent_run_info. + 4. Removes ScheduledTaskTool from the agent's tools to prevent recursive scheduling. + 5. Runs the agent and saves the assistant response. + """ + from database.conversation_db import get_max_message_index + from agents.create_agent_info import create_agent_run_info + from nexent.core.agents.run_agent import agent_run + from consts.const import MESSAGE_ROLE + + task_uuid = task_dict.get("task_uuid", "unknown") + task_prompt = task_dict.get("task_prompt", "") + agent_id = task_dict.get("agent_id") + conversation_id = task_dict.get("conversation_id") + tenant_id = task_dict.get("tenant_id") + user_id = task_dict.get("user_id") + + if not all([agent_id, tenant_id, user_id, conversation_id]): + logger.error(f"Task {task_uuid} is missing required fields, skipping execution") + return + + try: + # Get current max message index to avoid collisions + max_idx = get_max_message_index(conversation_id) + + # Save user message with scheduling instruction + user_content = ( + f"[定时任务触发] 以下是一条已到期的定时任务,请直接执行任务内容并回复用户。" + f"不要创建新的定时任务,不要调用 scheduled_task 工具。\n\n任务内容:{task_prompt}" + ) + _save_simple_message( + conversation_id, max_idx + 1, MESSAGE_ROLE["USER"], + user_content, user_id, tenant_id, + ) + + # Create agent run info + agent_run_info = asyncio.run(create_agent_run_info( + agent_id=agent_id, + minio_files=None, + query=user_content, + history=[], + tenant_id=tenant_id, + user_id=user_id, + conversation_id=conversation_id, + )) + + # Remove ScheduledTaskTool from agent_config.tools to prevent recursive scheduling + if hasattr(agent_run_info, "agent_config") and hasattr(agent_run_info.agent_config, "tools"): + agent_run_info.agent_config.tools = [ + t for t in agent_run_info.agent_config.tools + if t.class_name != "ScheduledTaskTool" + ] + + # Run agent and collect response chunks + chunks = [] + async def _run_and_collect(): + async for chunk in agent_run(agent_run_info): + chunks.append(chunk) + + loop = asyncio.new_event_loop() + try: + loop.run_until_complete(_run_and_collect()) + finally: + loop.close() + + # Build assistant response text from chunks + response_parts = [ + c.get("content", "") if isinstance(c, dict) else str(c) + for c in chunks + ] + assistant_content = "".join(p for p in response_parts if p) or "(task completed with no output)" + + # Save assistant message + _save_simple_message( + conversation_id, max_idx + 2, MESSAGE_ROLE["ASSISTANT"], + assistant_content, user_id, tenant_id, + ) + + logger.info(f"Scheduled task {task_uuid} executed successfully") + + except Exception as e: + logger.error(f"Failed to execute scheduled task {task_uuid}: {e}", exc_info=True) + + +class ScheduledTaskScheduler: + """Background scheduler that polls for due tasks and executes them.""" + + def __init__(self, poll_interval: float = 10.0): + self._poll_interval = poll_interval + self._thread: threading.Thread | None = None + self._stop_event = threading.Event() + + def start(self): + if self._thread is not None and self._thread.is_alive(): + logger.warning("ScheduledTaskScheduler is already running") + return + self._stop_event.clear() + self._thread = threading.Thread(target=self._scheduler_loop, daemon=True) + self._thread.start() + logger.info("ScheduledTaskScheduler started") + + def stop(self): + self._stop_event.set() + if self._thread is not None: + self._thread.join(timeout=30) + self._thread = None + logger.info("ScheduledTaskScheduler stopped") + + def _scheduler_loop(self): + while not self._stop_event.is_set(): + try: + self._process_due_tasks() + except Exception as e: + logger.error(f"Error in scheduler loop: {e}", exc_info=True) + self._stop_event.wait(timeout=self._poll_interval) + + def _process_due_tasks(self): + from database.scheduled_task_db import query_pending_tasks_due, update_task_status + + now = datetime.now(timezone.utc).replace(tzinfo=None) + due_tasks = query_pending_tasks_due(now) + + for task_dict in due_tasks: + task_uuid = task_dict.get("task_uuid") + try: + # Mark as fired before execution to prevent re-entrancy + update_task_status(task_uuid, {"status": "fired"}) + + # Execute in a separate thread to avoid blocking the scheduler + t = threading.Thread( + target=_run_scheduled_task_from_db, + args=(task_dict,), + daemon=True, + ) + t.start() + t.join(timeout=300) # 5-minute timeout per task + + # Update fire count and schedule next run for cron tasks + updates = {"fire_count": (task_dict.get("fire_count") or 0) + 1} + task_type = task_dict.get("task_type") + cron_expr = task_dict.get("cron_expression") + max_fires = task_dict.get("max_fires") + + if task_type == "cron" and cron_expr: + fire_count = updates["fire_count"] + if max_fires is not None and fire_count >= max_fires: + updates["status"] = "completed" + else: + # Compute next fire time + from nexent.core.tools.scheduled_task_tool import ScheduledTaskTool + cron_parts = ScheduledTaskTool._parse_cron(cron_expr) + if cron_parts: + next_fire = ScheduledTaskTool._compute_next_fire( + cron_parts, datetime.now(timezone.utc) + ) + updates["next_fire_time"] = next_fire.replace(tzinfo=None) + updates["status"] = "pending" + else: + updates["status"] = "error" + # oneshot tasks stay as "fired" + + update_task_status(task_uuid, updates) + + except Exception as e: + logger.error(f"Failed to process task {task_uuid}: {e}", exc_info=True) + try: + update_task_status(task_uuid, {"status": "error"}) + except Exception: + pass + + +# Module-level singleton +scheduled_task_scheduler = ScheduledTaskScheduler() diff --git a/docker/init.sql b/docker/init.sql index 4952eaea0..79102fc24 100644 --- a/docker/init.sql +++ b/docker/init.sql @@ -1934,3 +1934,42 @@ COMMENT ON TABLE nexent.user_cas_session_t IS 'Server-side session records for C COMMENT ON COLUMN nexent.user_cas_session_t.session_id IS 'JWT sid claim for revocation checks'; COMMENT ON COLUMN nexent.user_cas_session_t.cas_user_id IS 'User identifier returned by CAS'; COMMENT ON COLUMN nexent.user_cas_session_t.cas_session_index IS 'CAS SessionIndex or service ticket'; + +-- Scheduled tasks table +CREATE TABLE IF NOT EXISTS nexent.scheduled_tasks_t ( + task_id SERIAL PRIMARY KEY, + task_uuid VARCHAR(36) NOT NULL UNIQUE, + task_name VARCHAR(200), + task_prompt TEXT NOT NULL, + task_type VARCHAR(10) NOT NULL, + cron_expression VARCHAR(100), + delay_seconds INTEGER, + status VARCHAR(20) DEFAULT 'pending', + next_fire_time TIMESTAMP, + fire_count INTEGER DEFAULT 0, + max_fires INTEGER, + agent_id INTEGER NOT NULL, + conversation_id INTEGER, + tenant_id VARCHAR(100) NOT NULL, + user_id VARCHAR(100) NOT NULL, + delete_flag VARCHAR(1) DEFAULT 'N', + create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + created_by VARCHAR(100), + updated_by VARCHAR(100) +); + +CREATE INDEX IF NOT EXISTS ix_scheduled_task_status_next_fire + ON nexent.scheduled_tasks_t (status, next_fire_time); +CREATE INDEX IF NOT EXISTS ix_scheduled_task_agent_delete + ON nexent.scheduled_tasks_t (agent_id, delete_flag); + +COMMENT ON TABLE nexent.scheduled_tasks_t IS 'Scheduled task records for deferred and recurring agent execution'; +COMMENT ON COLUMN nexent.scheduled_tasks_t.task_uuid IS 'Unique task identifier (UUID)'; +COMMENT ON COLUMN nexent.scheduled_tasks_t.task_type IS 'Task type: oneshot or cron'; +COMMENT ON COLUMN nexent.scheduled_tasks_t.cron_expression IS 'Cron expression for recurring tasks'; +COMMENT ON COLUMN nexent.scheduled_tasks_t.delay_seconds IS 'Delay in seconds for oneshot tasks'; +COMMENT ON COLUMN nexent.scheduled_tasks_t.status IS 'Task status: pending, fired, cancelled, error'; +COMMENT ON COLUMN nexent.scheduled_tasks_t.next_fire_time IS 'Next scheduled execution time'; +COMMENT ON COLUMN nexent.scheduled_tasks_t.fire_count IS 'Number of times this task has fired'; +COMMENT ON COLUMN nexent.scheduled_tasks_t.max_fires IS 'Maximum number of fires (NULL = unlimited)'; diff --git a/frontend/app/[locale]/chat/internal/chatInterface.tsx b/frontend/app/[locale]/chat/internal/chatInterface.tsx index 9dd9bb847..be8c56c80 100644 --- a/frontend/app/[locale]/chat/internal/chatInterface.tsx +++ b/frontend/app/[locale]/chat/internal/chatInterface.tsx @@ -151,6 +151,110 @@ export function ChatInterface() { }; }, [attachments, fileUrls]); + // ---- Scheduled-task polling ---- + // Track the max message index per conversation for polling + const pollingIndexRef = useRef<{[cid: number]: number}>({}); + + // Helper: format raw API messages into ChatMessageType[] for polling refresh + const formatApiMessages = useCallback((cid: number, rawMsgs: any[]): ChatMessageType[] => { + return rawMsgs.map((msg: any, idx: number) => { + if (msg.role === MESSAGE_ROLES.USER) { + return { + id: `msg-${cid}-${idx}`, + role: msg.role, + content: msg.message || "", + files: msg.minio_files || [], + timestamp: new Date().toLocaleTimeString(), + }; + } + // assistant + let assistantContent = ""; + if (typeof msg.message === "string") { + assistantContent = msg.message; + } else if (Array.isArray(msg.message)) { + const finalAnswer = msg.message.find((u: any) => u.type === "final_answer"); + assistantContent = finalAnswer?.content || ""; + } + return { + id: `msg-${cid}-${idx}`, + role: msg.role, + content: assistantContent, + files: msg.minio_files || [], + timestamp: new Date().toLocaleTimeString(), + search: msg.search, + picture: msg.picture, + searchByUnitId: msg.searchByUnitId, + messageId: msg.message_id, + opinionFlag: msg.opinion_flag, + }; + }); + }, []); + + // Helper: refresh conversation messages from API and update sessionMessages + const refreshConversation = useCallback(async (cid: number, scroll: boolean = false) => { + const detail = await conversationService.getDetail(cid); + if (detail?.data?.[0]?.message) { + const formatted = formatApiMessages(cid, detail.data[0].message); + setSessionMessages(prev => ({ ...prev, [cid]: formatted })); + if (scroll) setShouldScrollToBottom(true); + } + }, [formatApiMessages]); + + // Layer 1 (5s): poll the currently active conversation + useEffect(() => { + const cid = conversationManagement.selectedConversationId; + if (!cid) return; + + const timer = setInterval(async () => { + try { + const sinceIdx = pollingIndexRef.current[cid] ?? -1; + const result = await conversationService.checkNewMessages(cid, sinceIdx); + if (result?.has_new) { + pollingIndexRef.current[cid] = result.max_index; + await refreshConversation(cid, true); + } + } catch { + // Silently ignore polling errors + } + }, 5000); + + return () => clearInterval(timer); + }, [conversationManagement.selectedConversationId, refreshConversation]); + + // Layer 2 (10s): batch-poll all other cached conversations for silent cache update + useEffect(() => { + const timer = setInterval(async () => { + const activeCid = conversationManagement.selectedConversationId; + const allCids = Object.keys(sessionMessages) + .map(Number) + .filter(id => id !== activeCid); + if (allCids.length === 0) return; + + const checks = allCids.map(cid => ({ + conversation_id: cid, + since_index: pollingIndexRef.current[cid] ?? -1, + })); + + try { + const result = await conversationService.batchCheckNewMessages(checks); + const results = result?.results || {}; + for (const cidStr of Object.keys(results)) { + const info = results[cidStr]; + if (info?.has_new) { + const cid = Number(cidStr); + pollingIndexRef.current[cid] = info.max_index; + refreshConversation(cid).catch(() => {}); + } + } + } catch { + // Silently ignore batch polling errors + } + }, 10000); + + return () => clearInterval(timer); + }, [conversationManagement.selectedConversationId, sessionMessages, refreshConversation]); + // ---- End scheduled-task polling ---- + // Handle file upload const handleFileUpload = (file: File) => { return preProcessHandleFileUpload(file, setFileUrls, t); diff --git a/frontend/services/api.ts b/frontend/services/api.ts index ef8b97ff4..34bf003b8 100644 --- a/frontend/services/api.ts +++ b/frontend/services/api.ts @@ -47,6 +47,8 @@ export const API_ENDPOINTS = { sources: `${API_BASE_URL}/conversation/sources`, opinion: `${API_BASE_URL}/conversation/message/update_opinion`, messageId: `${API_BASE_URL}/conversation/message/id`, + newMessages: (id: number) => `${API_BASE_URL}/conversation/${id}/new_messages`, + batchNewMessages: `${API_BASE_URL}/conversation/batch_new_messages`, }, agent: { run: `${API_BASE_URL}/agent/run`, diff --git a/frontend/services/conversationService.ts b/frontend/services/conversationService.ts index 746c38f63..2edbc3a7b 100644 --- a/frontend/services/conversationService.ts +++ b/frontend/services/conversationService.ts @@ -855,4 +855,30 @@ export const conversationService = { throw new ApiError(data.code, data.message); }, + + // Check for new messages in a single conversation (lightweight polling) + async checkNewMessages(conversationId: number, sinceIndex: number) { + const response = await fetch( + `${API_ENDPOINTS.conversation.newMessages(conversationId)}?since_index=${sinceIndex}`, + { + method: 'GET', + headers: getAuthHeaders(), + } + ); + + const data = await response.json(); + return data; + }, + + // Batch check for new messages across multiple conversations + async batchCheckNewMessages(checks: Array<{ conversation_id: number; since_index: number }>) { + const response = await fetch(API_ENDPOINTS.conversation.batchNewMessages, { + method: 'POST', + headers: getAuthHeaders(), + body: JSON.stringify({ checks }), + }); + + const data = await response.json(); + return data; + }, }; diff --git a/sdk/nexent/core/agents/nexent_agent.py b/sdk/nexent/core/agents/nexent_agent.py index b3c5b8cd0..9ea9d77af 100644 --- a/sdk/nexent/core/agents/nexent_agent.py +++ b/sdk/nexent/core/agents/nexent_agent.py @@ -264,6 +264,15 @@ def create_local_tool(self, tool_config: ToolConfig): "agent_id", "") if tool_config.metadata else "" tools_obj.memory_user_config = tool_config.metadata.get( "memory_user_config", None) if tool_config.metadata else None + elif class_name == "ScheduledTaskTool": + tools_obj = tool_class() + # Inject DB functions and context from metadata + if tool_config.metadata: + for attr in ("db_create", "db_list", "db_cancel", + "agent_id", "tenant_id", "user_id", + "conversation_id"): + if attr in tool_config.metadata: + setattr(tools_obj, attr, tool_config.metadata[attr]) else: tools_obj = tool_class(**params) if hasattr(tools_obj, 'observer'): diff --git a/sdk/nexent/core/tools/__init__.py b/sdk/nexent/core/tools/__init__.py index c35991f6e..12ed4d113 100644 --- a/sdk/nexent/core/tools/__init__.py +++ b/sdk/nexent/core/tools/__init__.py @@ -26,6 +26,7 @@ from .read_skill_config_tool import read_skill_config from .store_memory_tool import StoreMemoryTool from .search_memory_tool import SearchMemoryTool +from .scheduled_task_tool import ScheduledTaskTool __all__ = [ "MySqlTool", @@ -58,4 +59,5 @@ "read_skill_config", "StoreMemoryTool", "SearchMemoryTool", + "ScheduledTaskTool", ] diff --git a/sdk/nexent/core/tools/scheduled_task_tool.py b/sdk/nexent/core/tools/scheduled_task_tool.py new file mode 100644 index 000000000..fa39cd1ef --- /dev/null +++ b/sdk/nexent/core/tools/scheduled_task_tool.py @@ -0,0 +1,244 @@ +"""Scheduled task tool - thin CRUD wrapper for creating, listing, and cancelling scheduled tasks.""" + +import logging +import uuid +from datetime import datetime, timedelta, timezone +from typing import Any, Callable, Optional + +from smolagents.tools import Tool + +logger = logging.getLogger("scheduled_task_tool") + + +class ScheduledTaskTool(Tool): + name = "scheduled_task" + description = ( + "Create, list, or cancel scheduled tasks that will be executed " + "automatically at a specified time or on a recurring schedule. " + "Use this to set reminders, schedule periodic reports, or defer " + "actions to a future time." + ) + description_zh = ( + "创建、查看或取消定时任务。定时任务会在指定时间自动执行," + "支持一次性延迟任务和周期性 cron 任务。可用于设置提醒、" + "定期报告或将操作推迟到未来执行。" + ) + + inputs = { + "action": { + "type": "string", + "description": "Action to perform: 'create', 'list', or 'cancel'", + "description_zh": "操作类型:'create'(创建)、'list'(查看)或 'cancel'(取消)", + }, + "task_name": { + "type": "string", + "description": "Name for the task (used in create)", + "description_zh": "任务名称(创建时使用)", + "nullable": True, + }, + "task_prompt": { + "type": "string", + "description": "The prompt content to execute when the task fires (used in create)", + "description_zh": "任务触发时要执行的提示内容(创建时使用)", + "nullable": True, + }, + "task_type": { + "type": "string", + "description": "Type: 'oneshot' (run once after delay) or 'cron' (recurring). Default 'oneshot'", + "description_zh": "类型:'oneshot'(一次性延迟)或 'cron'(周期性)。默认 'oneshot'", + "nullable": True, + }, + "cron_expression": { + "type": "string", + "description": "Cron expression for recurring tasks, e.g. '0 9 * * *' (daily at 9am). Required if task_type='cron'", + "description_zh": "周期性任务的 cron 表达式,如 '0 9 * * *'(每天9点)。task_type='cron' 时必填", + "nullable": True, + }, + "delay_seconds": { + "type": "integer", + "description": "Delay in seconds for oneshot tasks. Required if task_type='oneshot'", + "description_zh": "一次性任务的延迟秒数。task_type='oneshot' 时必填", + "nullable": True, + }, + "task_uuid": { + "type": "string", + "description": "UUID of the task to cancel (used in cancel)", + "description_zh": "要取消的任务 UUID(取消时使用)", + "nullable": True, + }, + } + output_type = "string" + + # These attributes are injected via metadata at runtime + db_create: Callable = None + db_list: Callable = None + db_cancel: Callable = None + agent_id: int = None + tenant_id: str = None + user_id: str = None + conversation_id: int = None + + def forward( + self, + action: str, + task_name: Optional[str] = None, + task_prompt: Optional[str] = None, + task_type: Optional[str] = "oneshot", + cron_expression: Optional[str] = None, + delay_seconds: Optional[int] = None, + task_uuid: Optional[str] = None, + ) -> str: + if action == "create": + return self._handle_create(task_name, task_prompt, task_type, cron_expression, delay_seconds) + elif action == "list": + return self._handle_list() + elif action == "cancel": + return self._handle_cancel(task_uuid) + else: + return f"Unknown action '{action}'. Use 'create', 'list', or 'cancel'." + + def _handle_create(self, task_name, task_prompt, task_type, cron_expression, delay_seconds): + if not task_prompt: + return "Error: task_prompt is required for creating a task." + + task_type = task_type or "oneshot" + + # Compute next_fire_time + now = datetime.now(timezone.utc) + if task_type == "cron": + if not cron_expression: + return "Error: cron_expression is required for cron tasks." + cron_parts = self._parse_cron(cron_expression) + if cron_parts is None: + return f"Error: invalid cron expression '{cron_expression}'." + next_fire = self._compute_next_fire(cron_parts, now) + else: + # oneshot + if delay_seconds is None or delay_seconds <= 0: + return "Error: delay_seconds must be a positive integer for oneshot tasks." + next_fire = now + timedelta(seconds=delay_seconds) + + data = { + "task_uuid": str(uuid.uuid4()), + "task_name": task_name or "", + "task_prompt": task_prompt, + "task_type": task_type, + "cron_expression": cron_expression, + "delay_seconds": delay_seconds, + "status": "pending", + "next_fire_time": next_fire.replace(tzinfo=None), + "fire_count": 0, + "agent_id": self.agent_id, + "conversation_id": self.conversation_id, + "tenant_id": self.tenant_id, + "user_id": self.user_id, + "delete_flag": "N", + } + + try: + record = self.db_create(data) + return f"Scheduled task created successfully. task_uuid={record.get('task_uuid')}, next_fire_time={next_fire.isoformat()}" + except Exception as e: + logger.error(f"Failed to create scheduled task: {e}") + return f"Error creating scheduled task: {e}" + + def _handle_list(self): + tasks = self._safe_call("list", lambda: self.db_list(self.agent_id, self.tenant_id, self.user_id)) + if isinstance(tasks, str): # error message + return tasks + if not tasks: + return "No scheduled tasks found." + lines = [ + f"- [{t.get('status', 'unknown')}] {t.get('task_name', 'unnamed')} " + f"({t.get('task_type', '?')}) uuid={t.get('task_uuid')}, next_fire={t.get('next_fire_time', '?')}" + for t in tasks + ] + return "Scheduled tasks:\n" + "\n".join(lines) + + def _handle_cancel(self, task_uuid): + if not task_uuid: + return "Error: task_uuid is required for cancelling a task." + ok = self._safe_call("cancel", lambda: self.db_cancel(task_uuid, self.agent_id, self.tenant_id, self.user_id)) + if isinstance(ok, str): # error message + return ok + return f"Task {task_uuid} cancelled successfully." if ok else f"Task {task_uuid} not found or already cancelled." + + def _safe_call(self, action: str, fn: Callable) -> Any: + """Execute a DB call with unified error handling.""" + try: + return fn() + except Exception as e: + logger.error(f"Failed to {action} scheduled task(s): {e}") + return f"Error {action} scheduled task(s): {e}" + + @staticmethod + def _parse_cron(expression: str): + """Parse a 5-field cron expression into a dict of field values. + + Returns dict with keys: minute, hour, day_of_month, month, day_of_week + or None if the expression is invalid. + """ + parts = expression.strip().split() + if len(parts) != 5: + return None + try: + return { + "minute": int(parts[0]), + "hour": int(parts[1]), + "day_of_month": parts[2], + "month": parts[3], + "day_of_week": parts[4], + } + except (ValueError, IndexError): + return None + + @staticmethod + def _compute_next_fire(cron_parts: dict, from_timestamp: datetime) -> datetime: + """Compute the next fire time from a parsed cron expression. + + This is a simplified implementation that handles common patterns. + Supports numeric values and '*' for day_of_month, month, day_of_week. + """ + minute = cron_parts["minute"] + hour = cron_parts["hour"] + day_of_month = cron_parts["day_of_month"] + month = cron_parts["month"] + + # Start from the next minute after from_timestamp + candidate = from_timestamp.replace(second=0, microsecond=0) + timedelta(minutes=1) + + # Simple approach: search forward minute by minute up to 366 days + max_iterations = 525960 # 366 * 24 * 60 + for _ in range(max_iterations): + month_match = (month == "*" or candidate.month == int(month)) + dom_match = (day_of_month == "*" or candidate.day == int(day_of_month)) + hour_match = candidate.hour == hour + minute_match = candidate.minute == minute + + if month_match and dom_match and hour_match and minute_match: + return candidate + + # Skip ahead if possible + if not month_match: + # Jump to first day of next month + if candidate.month < 12: + candidate = candidate.replace(month=candidate.month + 1, day=1, hour=0, minute=0) + else: + candidate = candidate.replace(year=candidate.year + 1, month=1, day=1, hour=0, minute=0) + continue + + if not dom_match: + candidate = candidate.replace(hour=0, minute=0) + timedelta(days=1) + continue + + if not hour_match and candidate.hour < hour: + candidate = candidate.replace(hour=hour, minute=minute) + if candidate.minute == minute: + return candidate + continue + + # Move to next day + candidate = candidate.replace(hour=0, minute=0) + timedelta(days=1) + + # Fallback: return from_timestamp + 1 hour + return from_timestamp + timedelta(hours=1)