Skip to content
Open
738 changes: 673 additions & 65 deletions astrbot/builtin_stars/astrbot/long_term_memory.py

Large diffs are not rendered by default.

54 changes: 34 additions & 20 deletions astrbot/builtin_stars/astrbot/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class Main(star.Star):
def __init__(self, context: star.Context) -> None:
self.context = context
self.ltm = None
self._ltm_was_enabled: dict[str, bool] = {}
try:
self.ltm = LongTermMemory(self.context.astrbot_config_mgr, self.context)
except BaseException as e:
Expand Down Expand Up @@ -170,20 +171,27 @@ async def on_message(self, event: AstrMessageEvent):
return
try:
conv = None
session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id(
event.unified_msg_origin,
)

if not session_curr_cid:
logger.error(
"当前未处于对话状态,无法主动回复,请确保 平台设置->会话隔离(unique_session) 未开启,并使用 /switch 序号 切换或者 /new 创建一个会话。",
if not group_icl_enable:
# 仅在走 Conversation 模式时才需要查询会话
session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id(
event.unified_msg_origin,
)
return

conv = await self.context.conversation_manager.get_conversation(
event.unified_msg_origin,
session_curr_cid,
)
if not session_curr_cid:
logger.error(
"当前未处于对话状态,无法主动回复,请确保 平台设置->会话隔离(unique_session) 未开启,并使用 /switch 序号 切换或者 /new 创建一个会话。",
)
return

conv = await self.context.conversation_manager.get_conversation(
event.unified_msg_origin,
session_curr_cid,
)

if not conv:
logger.error("未找到对话,无法主动回复")
return

prompt = event.message_str
image_urls = []
Expand All @@ -194,10 +202,6 @@ async def on_message(self, event: AstrMessageEvent):
except Exception:
logger.exception("主动回复处理图片失败")

if not conv:
logger.error("未找到对话,无法主动回复")
return

yield event.request_llm(
prompt=prompt,
session_id=event.session_id,
Expand All @@ -214,19 +218,29 @@ async def decorate_llm_req(
) -> None:
"""在请求 LLM 前注入人格信息、Identifier、时间、回复内容等 System Prompt"""
if self.ltm and self.ltm_enabled(event):
umo = event.unified_msg_origin

# 惰性切换检测:false → true 时清理残留旧数据
now_enabled = self.ltm_enabled(event)
was_enabled = self._ltm_was_enabled.get(umo, False)
if now_enabled and not was_enabled:
await self.ltm.remove_session(event)
logger.info(f"LTM: group_icl_enable 开启,已重置 {umo} 上下文")
self._ltm_was_enabled[umo] = now_enabled

try:
await self.ltm.on_req_llm(event, req)
except BaseException as e:
logger.error(f"ltm: {e}")

@filter.on_llm_response()
async def record_llm_resp_to_ltm(
self, event: AstrMessageEvent, resp: LLMResponse
@filter.on_agent_done()
async def record_agent_result_to_ltm(
self, event: AstrMessageEvent, run_context, resp: LLMResponse
) -> None:
"""在 LLM 响应后记录对话"""
"""Agent 完成后记录对话(含工具链)"""
if self.ltm and self.ltm_enabled(event):
try:
await self.ltm.after_req_llm(event, resp)
await self.ltm.on_agent_done(event, run_context, resp)
except Exception as e:
logger.error(f"ltm: {e}")

Expand Down
28 changes: 28 additions & 0 deletions astrbot/core/agent/context/guard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from ..message import Message
from .config import ContextConfig
from .manager import ContextManager


class RequestContextGuard:
    """Guard applied to in-flight messages just before a provider request.

    The scope of this guard is deliberately limited to one provider request.
    It may truncate or compress the messages of the current request so that
    they stay within model/provider limits. It does not own persistent
    history and should not be used as the memory-layer compactor.
    """

    def __init__(self, config: ContextConfig) -> None:
        """Build the wrapped ``ContextManager`` from ``config`` and keep both.

        Args:
            config: Context configuration handed unchanged to the manager.
        """
        self._manager = ContextManager(config)
        self.config = config

    async def process(
        self,
        messages: list[Message],
        trusted_token_usage: int = 0,
    ) -> list[Message]:
        """Apply request-time context guarding to ``messages``.

        Args:
            messages: Messages about to be sent to the provider.
            trusted_token_usage: Token usage figure forwarded to the manager
                as a keyword argument; defaults to ``0``.

        Returns:
            Whatever the wrapped manager's ``process`` call returns.
        """
        guarded_messages = await self._manager.process(
            messages,
            trusted_token_usage=trusted_token_usage,
        )
        return guarded_messages
32 changes: 20 additions & 12 deletions astrbot/core/agent/runners/tool_loop_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@

from ..context.compressor import ContextCompressor
from ..context.config import ContextConfig
from ..context.manager import ContextManager
from ..context.guard import RequestContextGuard
from ..context.token_counter import EstimateTokenCounter, TokenCounter
from ..hooks import BaseAgentRunHooks
from ..message import (
Expand Down Expand Up @@ -241,13 +241,10 @@ async def reset(
self.tool_result_overflow_dir = tool_result_overflow_dir
self.read_tool = read_tool
self._tool_result_token_counter = EstimateTokenCounter()
# we will do compress when:
# 1. before requesting LLM
# TODO: 2. after LLM output a tool call
self.context_config = ContextConfig(
# <=0 will never do compress
self.request_context_guard_config = ContextConfig(
# <=0 disables token-based guarding.
max_context_tokens=provider.provider_config.get("max_context_tokens", 0),
# enforce max turns before compression
# Enforce max turns before token-based guarding.
enforce_max_turns=self.enforce_max_turns,
truncate_turns=self.truncate_turns,
llm_compress_instruction=self.llm_compress_instruction,
Expand All @@ -256,7 +253,9 @@ async def reset(
custom_token_counter=self.custom_token_counter,
custom_compressor=self.custom_compressor,
)
self.context_manager = ContextManager(self.context_config)
self.request_context_guard = RequestContextGuard(
self.request_context_guard_config
)

self.provider = provider
self.fallback_providers: list[Provider] = []
Expand Down Expand Up @@ -459,8 +458,11 @@ async def _iter_llm_responses(
self, *, include_model: bool = True
) -> T.AsyncGenerator[LLMResponse, None]:
"""Yields chunks *and* a final LLMResponse."""
messages_for_provider = getattr(
self, "_provider_messages", self.run_context.messages
)
payload = {
"contexts": self._sanitize_contexts_for_provider(self.run_context.messages),
"contexts": self._sanitize_contexts_for_provider(messages_for_provider),
"func_tool": self._func_tool_for_provider(),
"session_id": self.req.session_id,
"extra_user_content_parts": self.req.extra_user_content_parts, # list[ContentPart]
Expand Down Expand Up @@ -704,10 +706,13 @@ async def step(self):
self._transition_state(AgentState.RUNNING)
llm_resp_result = None

# do truncate and compress
# Apply request-time context guard *on a copy* so the runner's canonical
# messages are never mutated by the guard. The guard result is only used
# for this provider call. Persistent compaction is owned by the
# conversation / memory layer.
token_usage = self.req.conversation.token_usage if self.req.conversation else 0
self._simple_print_message_role("[BefCompact]")
self.run_context.messages = await self.context_manager.process(
self._provider_messages = await self.request_context_guard.process(
self.run_context.messages, trusted_token_usage=token_usage
)
self._simple_print_message_role("[AftCompact]")
Expand Down Expand Up @@ -1398,14 +1403,17 @@ async def _iter_tool_executor_results(
self,
executor: AsyncIterator[ToolExecutorResultT],
) -> T.AsyncGenerator[ToolExecutorResultT, None]:
async def _next_executor_result() -> ToolExecutorResultT:
return await anext(executor)

while True:
if self._is_stop_requested():
await self._close_executor(executor)
raise _ToolExecutionInterrupted(
"Tool execution interrupted before reading the next tool result."
)

next_result_task = asyncio.create_task(anext(executor))
next_result_task = asyncio.create_task(_next_executor_result())
abort_task = asyncio.create_task(self._abort_signal.wait())
try:
done, _ = await asyncio.wait(
Comment thread
RC-CHN marked this conversation as resolved.
Expand Down
26 changes: 10 additions & 16 deletions astrbot/core/astr_main_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,26 +1119,21 @@ async def _apply_web_search_tools(


def _get_compress_provider(
config: MainAgentBuildConfig, plugin_context: Context
config: MainAgentBuildConfig,
plugin_context: Context,
event: AstrMessageEvent | None = None,
) -> Provider | None:
if not config.llm_compress_provider_id:
return None
if config.context_limit_reached_strategy != "llm_compress":
return None
provider = plugin_context.get_provider_by_id(config.llm_compress_provider_id)
if provider is None:
if config.llm_compress_provider_id:
provider = plugin_context.get_provider_by_id(config.llm_compress_provider_id)
if provider and isinstance(provider, Provider):
return provider
logger.warning(
"未找到指定的上下文压缩模型 %s,将跳过压缩。",
"指定的上下文压缩模型 %s 不可用",
config.llm_compress_provider_id,
)
return None
if not isinstance(provider, Provider):
logger.warning(
"指定的上下文压缩模型 %s 不是对话模型,将跳过压缩。",
config.llm_compress_provider_id,
)
return None
return provider
return None
Comment thread
RC-CHN marked this conversation as resolved.
Comment thread
RC-CHN marked this conversation as resolved.


def _get_fallback_chat_providers(
Expand Down Expand Up @@ -1449,9 +1444,8 @@ async def build_main_agent(
streaming=config.streaming_response,
llm_compress_instruction=config.llm_compress_instruction,
llm_compress_keep_recent=config.llm_compress_keep_recent,
llm_compress_provider=_get_compress_provider(config, plugin_context),
llm_compress_provider=_get_compress_provider(config, plugin_context, event),
truncate_turns=config.dequeue_context_length,
enforce_max_turns=config.max_context_length,
tool_schema_mode=config.tool_schema_mode,
fallback_providers=_get_fallback_chat_providers(
provider, plugin_context, config.provider_settings
Expand Down
Loading
Loading