Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion src/agentscope/app/_service/_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,9 @@ async def _run_impl(
# (any process) wakes an idle session — no in-process retrigger
# plumbing is needed here.
# ----------------------------------------------------------------
inbox_middleware = InboxMiddleware(self._message_bus)
middlewares: list = [
InboxMiddleware(self._message_bus),
inbox_middleware,
StateChangeMiddleware(
message_bus=self._message_bus,
session_id=session_id,
Expand Down Expand Up @@ -383,6 +384,20 @@ async def _run_impl(
session_id,
msg,
)
else:
inbox_events = await inbox_middleware.drain(agent)
if not inbox_events:
logger.info(
"Skipping wake-up for session %s: inbox is "
"empty.",
session_id,
)
return
for event in inbox_events:
await self._message_bus.session_publish_event(
session_id,
event.model_dump(mode="json"),
)

async for event in agent.reply_stream(inputs=input_msg):
await self._message_bus.session_publish_event(
Expand Down
97 changes: 56 additions & 41 deletions src/agentscope/app/middleware/_inbox_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,41 +78,49 @@ async def on_reasoning( # type: ignore[override]
One ``HintBlockEvent`` per drained inbox entry,
followed by events from downstream.
"""
for event in await self.drain(agent):
yield event

async for evt in next_handler(**input_kwargs):
yield evt

async def drain(self, agent: Agent) -> list[HintBlockEvent]:
"""Drain pending inbox entries into the agent context.

Args:
agent (`Agent`):
The executing agent. ``agent.state.session_id`` selects
the inbox to drain.

Returns:
`list[HintBlockEvent]`:
One event per injected :class:`HintBlock`. Empty when
the inbox had no pending entries.
"""
entries = await self._bus.inbox_drain(
agent.state.session_id,
max_count=self._max_count,
)

if entries:
hint_blocks = [
HintBlock.model_validate(payload)
for _entry_id, payload in entries
]

logger.info(
"InboxMiddleware: injecting %d HintBlock(s) into context "
"for session %s",
len(hint_blocks),
agent.state.session_id,
)
if not entries:
return []

hint_blocks = [
HintBlock.model_validate(payload) for _entry_id, payload in entries
]

logger.info(
"InboxMiddleware: injecting %d HintBlock(s) into context "
"for session %s",
len(hint_blocks),
agent.state.session_id,
)

# Inject into agent context (same pattern as
# ToolOffloadMiddleware).
if len(agent.state.context) > 0:
last_msg = agent.state.context[-1]
if (
last_msg.role == "assistant"
and last_msg.name == agent.name
):
last_msg.content.extend(hint_blocks)
else:
agent.state.context.append(
AssistantMsg(
id=agent.state.reply_id,
name=agent.name,
content=list(hint_blocks),
),
)
# Inject into agent context (same pattern as ToolOffloadMiddleware).
if len(agent.state.context) > 0:
last_msg = agent.state.context[-1]
if last_msg.role == "assistant" and last_msg.name == agent.name:
last_msg.content.extend(hint_blocks)
else:
agent.state.context.append(
AssistantMsg(
Expand All @@ -121,16 +129,23 @@ async def on_reasoning( # type: ignore[override]
content=list(hint_blocks),
),
)
else:
agent.state.context.append(
AssistantMsg(
id=agent.state.reply_id,
name=agent.name,
content=list(hint_blocks),
),
)

# Yield one-shot events so the front-end SSE stream sees
# each HintBlock.
for hint in hint_blocks:
yield HintBlockEvent(
reply_id=agent.state.reply_id,
block_id=hint.id,
source=hint.source,
hint=hint.hint,
)

async for evt in next_handler(**input_kwargs):
yield evt
# Return one-shot events so the front-end SSE stream sees each
# HintBlock.
return [
HintBlockEvent(
reply_id=agent.state.reply_id,
block_id=hint.id,
source=hint.source,
hint=hint.hint,
)
for hint in hint_blocks
]
233 changes: 233 additions & 0 deletions tests/service_chat_wakeup_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
# -*- coding: utf-8 -*-
# pylint: disable=protected-access
"""Regression tests for wake-up driven ChatService runs."""
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, patch

from agentscope.agent import ContextConfig, ReActConfig
from agentscope.app._service import ChatService
from agentscope.app.storage import (
AgentData,
AgentRecord,
ChatModelConfig,
SessionConfig,
SessionRecord,
)
from agentscope.message import HintBlock


class _FakeStorage:
"""Minimal storage surface used by :class:`ChatService`."""

def __init__(self) -> None:
self.agent = AgentRecord(
user_id="u",
data=AgentData(
name="A",
system_prompt="You are A.",
context_config=ContextConfig(),
react_config=ReActConfig(),
),
)
self.session = SessionRecord(
user_id="u",
agent_id=self.agent.data.id,
config=SessionConfig(
workspace_id="ws",
chat_model_config=ChatModelConfig(
type="test",
credential_id="c",
model="m",
parameters={},
),
),
)
self.updated_states: list[Any] = []

async def get_agent(self, user_id: str, agent_id: str) -> AgentRecord:
"""Return the single test agent."""
assert user_id == "u"
assert agent_id == self.agent.data.id
return self.agent

async def get_session(
self,
user_id: str,
agent_id: str,
session_id: str,
) -> SessionRecord:
"""Return the single test session."""
assert user_id == "u"
assert agent_id == self.agent.data.id
assert session_id == "s"
return self.session

async def upsert_message(self, *_args: Any, **_kwargs: Any) -> None:
"""No-op message persistence for this focused regression."""

async def get_message(self, *_args: Any, **_kwargs: Any) -> None:
"""No continuation message is needed in these tests."""
return None

async def update_session_state(
self,
*,
state: Any,
**_kwargs: Any,
) -> None:
"""Capture persisted states."""
self.updated_states.append(state)


class _FakeWorkspace:
"""Workspace object with just the workdir ChatService records."""

workdir = "/tmp/agentscope-test-workspace"


class _FakeWorkspaceManager:
"""Workspace manager returning one fake workspace."""

async def get_workspace(
self,
*_args: Any,
**_kwargs: Any,
) -> _FakeWorkspace:
"""Return the fake workspace."""
return _FakeWorkspace()


class _FakeBus:
"""In-memory bus covering session locks, events, and inbox drains."""

def __init__(self) -> None:
self.inbox: list[tuple[str, dict]] = []
self.events: list[dict] = []

@asynccontextmanager
async def session_run(
self,
_session_id: str,
) -> AsyncGenerator[None, None]:
"""No-op lock context."""
yield

async def inbox_drain(
self,
_session_id: str,
max_count: int = 100,
) -> list[tuple[str, dict]]:
"""Drain queued test inbox entries."""
entries = self.inbox[:max_count]
self.inbox = self.inbox[max_count:]
return entries

async def session_publish_event(
self,
_session_id: str,
event: dict,
) -> str:
"""Capture published events."""
self.events.append(event)
return f"evt-{len(self.events)}"


class _FakeAgent:
"""Agent double that records whether the model/reasoning path ran."""

calls: list[Any] = []

def __init__(
self,
*,
name: str,
state: Any,
middlewares: list[Any],
**_kwargs: Any,
) -> None:
self.name = name
self.state = state
self.middlewares = middlewares

async def reply_stream(
self,
inputs: Any = None,
) -> AsyncGenerator[Any, None]:
"""Record invocation and yield no model events."""
self.calls.append(inputs)
return
yield # pragma: no cover # pylint: disable=unreachable


def _make_service(storage: _FakeStorage, bus: _FakeBus) -> ChatService:
"""Build a ChatService wired to local fakes."""
return ChatService(
storage=storage,
workspace_manager=_FakeWorkspaceManager(),
scheduler_manager=object(),
background_task_manager=object(),
message_bus=bus,
custom_agent_cls=_FakeAgent,
)


class TestChatServiceWakeupRuns(IsolatedAsyncioTestCase):
"""Wake-up runs should only invoke the agent when work was delivered."""

async def asyncSetUp(self) -> None:
"""Reset class-level call tracking."""
_FakeAgent.calls = []

async def test_empty_wakeup_does_not_call_agent(self) -> None:
"""Duplicate wake-ups with an empty inbox are treated as no-ops."""
storage = _FakeStorage()
bus = _FakeBus()
service = _make_service(storage, bus)

with (
patch(
"agentscope.app._service._chat.get_model",
AsyncMock(return_value=object()),
),
patch(
"agentscope.app._service._chat.get_toolkit",
AsyncMock(return_value=object()),
),
):
await service._run_impl("u", "s", storage.agent.data.id, None)

self.assertEqual(_FakeAgent.calls, [])
self.assertEqual(bus.events, [])
self.assertEqual(storage.updated_states, [])

async def test_wakeup_with_inbox_runs_agent_after_publishing_hints(
self,
) -> None:
"""Pending inbox content is still delivered before agent reasoning."""
storage = _FakeStorage()
bus = _FakeBus()
hint = HintBlock(hint="background result", source="tool")
bus.inbox.append(("id-1", hint.model_dump(mode="json")))
service = _make_service(storage, bus)

with (
patch(
"agentscope.app._service._chat.get_model",
AsyncMock(return_value=object()),
),
patch(
"agentscope.app._service._chat.get_toolkit",
AsyncMock(return_value=object()),
),
):
await service._run_impl("u", "s", storage.agent.data.id, None)

self.assertEqual(_FakeAgent.calls, [None])
self.assertEqual(len(bus.events), 1)
self.assertEqual(bus.events[0]["type"], "HINT_BLOCK")
self.assertEqual(bus.events[0]["block_id"], hint.id)
self.assertEqual(len(storage.session.state.context), 1)
self.assertEqual(storage.session.state.context[0].content[0], hint)
self.assertEqual(storage.updated_states, [storage.session.state])
Loading