diff --git a/astrbot/core/agent/agent.py b/astrbot/core/agent/agent.py index d6e2e7cb41..676776a358 100644 --- a/astrbot/core/agent/agent.py +++ b/astrbot/core/agent/agent.py @@ -11,5 +11,6 @@ class Agent(Generic[TContext]): name: str instructions: str | None = None tools: list[str | FunctionTool] | None = None + skills: list[str] | None = None run_hooks: BaseAgentRunHooks[TContext] | None = None begin_dialogs: list[Any] | None = None diff --git a/astrbot/core/astr_agent_tool_exec.py b/astrbot/core/astr_agent_tool_exec.py index de5caad554..0cfbc40cb5 100644 --- a/astrbot/core/astr_agent_tool_exec.py +++ b/astrbot/core/astr_agent_tool_exec.py @@ -30,6 +30,7 @@ from astrbot.core.platform.message_session import MessageSession from astrbot.core.provider.entites import ProviderRequest from astrbot.core.provider.register import llm_tools +from astrbot.core.skills.skill_manager import SkillManager, build_skills_prompt from astrbot.core.tools.computer_tools import ( CuaKeyboardTypeTool, CuaMouseClickTool, @@ -292,6 +293,39 @@ def _build_handoff_toolset( toolset.add_tool(tool_name_or_obj) return None if toolset.empty() else toolset + @classmethod + def _build_handoff_system_prompt( + cls, + instructions: str | None, + skill_names: list[str] | None, + runtime: str, + ) -> str: + skills_prompt = cls._build_handoff_skills_prompt(skill_names, runtime) + parts = [ + part.strip() + for part in (instructions, skills_prompt) + if isinstance(part, str) and part.strip() + ] + return "\n\n".join(parts) + + @classmethod + def _build_handoff_skills_prompt( + cls, + skill_names: list[str] | None, + runtime: str, + ) -> str: + if skill_names == []: + return "" + + skills = SkillManager().list_skills(active_only=True, runtime=runtime) + if skill_names is not None: + allowed = set(skill_names) + skills = [skill for skill in skills if skill.name in allowed] + + if not skills: + return "" + return build_skills_prompt(skills) + @classmethod async def _execute_handoff( cls, @@ -348,7 +382,14 @@ async def _execute_handoff( except Exception: continue - prov_settings: dict = ctx.get_config(umo=umo).get("provider_settings", {}) + cfg = ctx.get_config(umo=umo) + prov_settings: dict = cfg.get("provider_settings", {}) + runtime = str(prov_settings.get("computer_use_runtime", "local")) + system_prompt = cls._build_handoff_system_prompt( + tool.agent.instructions, + getattr(tool.agent, "skills", []), + runtime, + ) agent_max_step = int(prov_settings.get("max_agent_step", 30)) stream = prov_settings.get("streaming_response", False) llm_resp = await ctx.tool_loop_agent( @@ -356,7 +397,7 @@ async def _execute_handoff( chat_provider_id=prov_id, prompt=input_, image_urls=image_urls, - system_prompt=tool.agent.instructions, + system_prompt=system_prompt, tools=toolset, contexts=contexts, max_steps=agent_max_step, diff --git a/astrbot/core/subagent_orchestrator.py b/astrbot/core/subagent_orchestrator.py index c6c595dfc9..b72d823a21 100644 --- a/astrbot/core/subagent_orchestrator.py +++ b/astrbot/core/subagent_orchestrator.py @@ -61,6 +61,7 @@ async def reload_from_config(self, cfg: dict[str, Any]) -> None: if provider_id is not None: provider_id = str(provider_id).strip() or None tools = item.get("tools", []) + skills = item.get("skills", []) begin_dialogs = None if persona_data: @@ -71,6 +72,7 @@ async def reload_from_config(self, cfg: dict[str, Any]) -> None: persona_data.get("_begin_dialogs_processed") ) tools = persona_data.get("tools") + skills = persona_data.get("skills", []) if public_description == "" and prompt: public_description = prompt[:120] if tools is None: @@ -79,11 +81,18 @@ async def reload_from_config(self, cfg: dict[str, Any]) -> None: tools = [] else: tools = [str(t).strip() for t in tools if str(t).strip()] + if skills is None: + skills = None + elif not isinstance(skills, list): + skills = [] + else: + skills = [str(s).strip() for s in skills if str(s).strip()] agent = Agent[AstrAgentContext]( name=name, instructions=instructions, tools=tools, # type: ignore + skills=skills, ) agent.begin_dialogs = begin_dialogs # The tool description should be a short description for the main LLM, diff --git a/tests/unit/test_astr_agent_tool_exec.py b/tests/unit/test_astr_agent_tool_exec.py index 5fab9fe0a2..93a677a129 100644 --- a/tests/unit/test_astr_agent_tool_exec.py +++ b/tests/unit/test_astr_agent_tool_exec.py @@ -6,6 +6,7 @@ from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor from astrbot.core.message.components import Image +from astrbot.core.skills.skill_manager import SkillInfo class _DummyEvent: @@ -321,6 +322,110 @@ async def _fake_tool_loop_agent(**kwargs): assert captured["tool_call_timeout"] == 120 +def test_build_handoff_skills_prompt_filters_selected_skills( + monkeypatch: pytest.MonkeyPatch, +): + manager = SimpleNamespace( + list_skills=lambda **_kwargs: [ + SkillInfo( + name="web-search-skill", + description="Search the web", + path="/skills/web-search-skill/SKILL.md", + active=True, + ), + SkillInfo( + name="other-skill", + description="Other work", + path="/skills/other-skill/SKILL.md", + active=True, + ), + ], + ) + monkeypatch.setattr( + "astrbot.core.astr_agent_tool_exec.SkillManager", + lambda: manager, + ) + + prompt = FunctionToolExecutor._build_handoff_skills_prompt( + ["web-search-skill"], + "local", + ) + + assert "web-search-skill" in prompt + assert "Search the web" in prompt + assert "other-skill" not in prompt + + +@pytest.mark.asyncio +async def test_execute_handoff_appends_agent_skills_prompt( + monkeypatch: pytest.MonkeyPatch, +): + captured: dict = {} + + async def _fake_get_current_chat_provider_id(_umo): + return "provider-id" + + async def _fake_tool_loop_agent(**kwargs): + captured.update(kwargs) + return SimpleNamespace(completion_text="ok") + + context = SimpleNamespace( + get_current_chat_provider_id=_fake_get_current_chat_provider_id, + tool_loop_agent=_fake_tool_loop_agent, + get_config=lambda **_kwargs: {"provider_settings": {}}, + ) + event = _DummyEvent([]) + run_context = ContextWrapper(context=SimpleNamespace(event=event, context=context)) + tool = SimpleNamespace( + name="transfer_to_subagent", + provider_id=None, + agent=SimpleNamespace( + name="subagent", + tools=[], + skills=["web-search-skill"], + instructions="subagent-instructions", + begin_dialogs=[], + run_hooks=None, + ), + ) + monkeypatch.setattr( + FunctionToolExecutor, + "_build_handoff_skills_prompt", + classmethod(lambda cls, skill_names, runtime: "SKILL PROMPT"), + ) + + results = [] + async for result in FunctionToolExecutor._execute_handoff( + tool, + run_context, + image_urls_prepared=True, + input="hello", + image_urls=[], + ): + results.append(result) + + assert len(results) == 1 + assert captured["system_prompt"] == "subagent-instructions\n\nSKILL PROMPT" + + +def test_build_handoff_system_prompt_omits_empty_parts( + monkeypatch: pytest.MonkeyPatch, +): + monkeypatch.setattr( + FunctionToolExecutor, + "_build_handoff_skills_prompt", + classmethod(lambda cls, skill_names, runtime: "SKILL PROMPT\n"), + ) + + prompt = FunctionToolExecutor._build_handoff_system_prompt( + " ", + ["web-search-skill"], + "local", + ) + + assert prompt == "SKILL PROMPT" + + @pytest.mark.asyncio async def test_collect_handoff_image_urls_filters_extensionless_file_outside_temp_root( monkeypatch: pytest.MonkeyPatch, diff --git a/tests/unit/test_subagent_orchestrator.py b/tests/unit/test_subagent_orchestrator.py index 9befac8872..07c475f006 100644 --- a/tests/unit/test_subagent_orchestrator.py +++ b/tests/unit/test_subagent_orchestrator.py @@ -38,6 +38,7 @@ async def test_reload_from_config_default_persona_is_resolved(): handoff = orchestrator.handoffs[0] assert handoff.agent.instructions == default_persona["prompt"] assert handoff.agent.tools is None + assert handoff.agent.skills == [] assert handoff.agent.begin_dialogs == default_persona["_begin_dialogs_processed"] @@ -55,6 +56,7 @@ async def test_reload_from_config_missing_persona_falls_back_to_inline_and_warns handoff = orchestrator.handoffs[0] assert handoff.agent.instructions == "inline prompt" assert handoff.agent.tools == ["tool_a", "tool_b"] + assert handoff.agent.skills == [] assert handoff.agent.begin_dialogs is None mock_logger.warning.assert_called_once_with( "SubAgent persona %s not found, fallback to inline prompt.", @@ -71,6 +73,7 @@ async def test_reload_from_config_uses_processed_begin_dialogs_and_deepcopy(): "name": "custom", "prompt": "persona prompt", "tools": ["tool_from_persona"], + "skills": ["web-search-skill"], "_begin_dialogs_processed": processed_dialogs, } orchestrator = SubAgentOrchestrator(tool_mgr=tool_mgr, persona_mgr=persona_mgr) @@ -81,6 +84,7 @@ async def test_reload_from_config_uses_processed_begin_dialogs_and_deepcopy(): handoff = orchestrator.handoffs[0] assert handoff.agent.instructions == "persona prompt" assert handoff.agent.tools == ["tool_from_persona"] + assert handoff.agent.skills == ["web-search-skill"] assert handoff.agent.begin_dialogs[0]["content"] == "hello" @@ -108,3 +112,30 @@ async def test_reload_from_config_tool_normalization(raw_tools, expected_tools): handoff = orchestrator.handoffs[0] assert handoff.agent.tools == expected_tools + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("raw_skills", "expected_skills"), + [ + (None, None), + ([" web-search-skill ", "", "weather"], ["web-search-skill", "weather"]), + ("not-a-list", []), + ], +) +async def test_reload_from_config_skill_normalization(raw_skills, expected_skills): + tool_mgr = MagicMock() + persona_mgr = MagicMock() + persona_mgr.get_persona_v3_by_id.return_value = { + "name": "custom", + "prompt": "persona prompt", + "tools": [], + "skills": raw_skills, + "_begin_dialogs_processed": [], + } + orchestrator = SubAgentOrchestrator(tool_mgr=tool_mgr, persona_mgr=persona_mgr) + + await orchestrator.reload_from_config(_build_cfg({"persona_id": "custom"})) + + handoff = orchestrator.handoffs[0] + assert handoff.agent.skills == expected_skills