diff --git a/astrbot/builtin_stars/astrbot/long_term_memory.py b/astrbot/builtin_stars/astrbot/long_term_memory.py index e08cdc5157..b5c48c4709 100644 --- a/astrbot/builtin_stars/astrbot/long_term_memory.py +++ b/astrbot/builtin_stars/astrbot/long_term_memory.py @@ -1,7 +1,10 @@ +import asyncio import datetime +import json import random import uuid -from collections import defaultdict +from collections import defaultdict, deque +from typing import Any from astrbot import logger from astrbot.api import star @@ -12,56 +15,123 @@ from astrbot.core.astrbot_config_mgr import AstrBotConfigManager """ -聊天记忆增强 +聊天记忆增强 (LTM v2) """ +# === 常量 === + +CHATROOM_SYSTEM_NOTE = ( + "You are now in a chatroom. " + "Chat history messages below use the format '[username/time]: content'. " + "Your own messages are presented via the standard assistant role.\n" +) + +MAX_MSGS_PER_USER_SEGMENT = 50 +MAX_CHARS_PER_USER_SEGMENT = 3000 +MAX_RAW_BYTES = 500_000 # 500KB / 群 +DEFAULT_HISTORY_TOOL_RESULT_MAX_CHARS = 8192 +SUMMARY_RETRY_COOLDOWN = 5 # 轮数:LLM 摘要失败后等待多少轮再重试 + +TOOL_CALL_PREFIX = "" +TOOL_RES_PREFIX = " None: self.acm = acm self.context = context - self.session_chats = defaultdict(list) - """记录群成员的群聊记录""" + + self._locks: dict[str, asyncio.Lock] = {} + + self.raw_records: dict[str, deque[str]] = defaultdict(deque) + """群聊原始记录。deque 支持 O(1) popleft。""" + + self._raw_cursor: dict[str, int] = defaultdict(int) + """raw_records 中已消费到 contexts 的位置(指向下一条未消费的索引)。""" + + self.contexts: dict[str, list[dict]] = defaultdict(list) + """累积累积态 LLM 上下文。由 ContextManager 修改后保留。""" + + self._persisted_tool_call_ids: dict[str, set[str]] = defaultdict(set) + """已持久化到 raw_records 的 的 tool_call_id。用于防重复注入。""" + self._persisted_tool_result_ids: dict[str, set[str]] = defaultdict(set) + """已持久化到 raw_records 的 的 tool_call_id。用于防重复注入。""" + + self.summaries: dict[str, str] = defaultdict(str) + """LLM summary 策略下每个群聊的长期摘要文本。""" + + self._summary_next_retry: dict[str, int] = defaultdict(int) + """LLM 摘要失败后,下次允许重试的 rounds 数下限(冷却期内跳过)。""" + + def _get_lock(self, umo: str) -> asyncio.Lock: + """Return the per-session lock for a unified message origin.""" + lock = self._locks.get(umo) + if lock is None: + lock = asyncio.Lock() + self._locks[umo] = lock + return lock + + # ========================================================================= + # 配置 + # ========================================================================= def cfg(self, event: AstrMessageEvent): cfg = self.context.get_config(umo=event.unified_msg_origin) - try: - max_cnt = int(cfg["provider_ltm_settings"]["group_message_max_cnt"]) - except BaseException as e: - logger.error(e) - max_cnt = 300 + ltm_cfg = cfg["provider_ltm_settings"] image_caption_prompt = cfg["provider_settings"]["image_caption_prompt"] - image_caption_provider_id = cfg["provider_ltm_settings"].get( - "image_caption_provider_id" - ) - image_caption = cfg["provider_ltm_settings"]["image_caption"] and bool( - image_caption_provider_id + image_caption_provider_id = ltm_cfg.get("image_caption_provider_id") + image_caption = ltm_cfg["image_caption"] and bool(image_caption_provider_id) + history_tool_result_truncate = ltm_cfg.get("history_tool_result_truncate", True) + history_tool_result_max_chars = int( + ltm_cfg.get( + "history_tool_result_max_chars", + DEFAULT_HISTORY_TOOL_RESULT_MAX_CHARS, + ) + or DEFAULT_HISTORY_TOOL_RESULT_MAX_CHARS ) - active_reply = cfg["provider_ltm_settings"]["active_reply"] + active_reply = ltm_cfg["active_reply"] enable_active_reply = active_reply.get("enable", False) ar_method = active_reply["method"] ar_possibility = active_reply["possibility_reply"] ar_prompt = active_reply.get("prompt", "") ar_whitelist = active_reply.get("whitelist", []) - ret = { - "max_cnt": max_cnt, + # LTM compaction + ltm_compaction_strategy = ltm_cfg.get("ltm_compaction_strategy", "truncate") + ltm_max_rounds = ltm_cfg.get("ltm_max_rounds", 80) + ltm_truncate_drop_rounds = ltm_cfg.get("ltm_truncate_drop_rounds", 50) + ltm_summary_trigger_rounds = ltm_cfg.get("ltm_summary_trigger_rounds", 80) + ltm_summary_keep_recent_rounds = ltm_cfg.get( + "ltm_summary_keep_recent_rounds", 30 + ) + ltm_summary_provider_id = ltm_cfg.get("ltm_summary_provider_id", "") + ltm_summary_prompt = ltm_cfg.get("ltm_summary_prompt", "") + return { "image_caption": image_caption, "image_caption_prompt": image_caption_prompt, "image_caption_provider_id": image_caption_provider_id, + "history_tool_result_truncate": history_tool_result_truncate, + "history_tool_result_max_chars": max(1, history_tool_result_max_chars), "enable_active_reply": enable_active_reply, "ar_method": ar_method, "ar_possibility": ar_possibility, "ar_prompt": ar_prompt, "ar_whitelist": ar_whitelist, + "ltm_compaction_strategy": ltm_compaction_strategy, + "ltm_max_rounds": max(1, ltm_max_rounds), + "ltm_truncate_drop_rounds": max(1, ltm_truncate_drop_rounds), + "ltm_summary_trigger_rounds": max(1, ltm_summary_trigger_rounds), + "ltm_summary_keep_recent_rounds": max(1, ltm_summary_keep_recent_rounds), + "ltm_summary_provider_id": ltm_summary_provider_id, + "ltm_summary_prompt": ltm_summary_prompt, + "ltm_raw_records_max_bytes": ltm_cfg.get( + "ltm_raw_records_max_bytes", MAX_RAW_BYTES + ), } - return ret - async def remove_session(self, event: AstrMessageEvent) -> int: - cnt = 0 - if event.unified_msg_origin in self.session_chats: - cnt = len(self.session_chats[event.unified_msg_origin]) - del self.session_chats[event.unified_msg_origin] - return cnt + # ========================================================================= + # 图片描述 + # ========================================================================= async def get_image_caption( self, @@ -85,17 +155,18 @@ async def get_image_caption( ) return response.completion_text + # ========================================================================= + # 主动回复判断 + # ========================================================================= + async def need_active_reply(self, event: AstrMessageEvent) -> bool: cfg = self.cfg(event) if not cfg["enable_active_reply"]: return False if event.get_message_type() != MessageType.GROUP_MESSAGE: return False - if event.is_at_or_wake_command: - # if the message is a command, let it pass return False - if cfg["ar_whitelist"] and ( event.unified_msg_origin not in cfg["ar_whitelist"] and ( @@ -103,21 +174,47 @@ async def need_active_reply(self, event: AstrMessageEvent) -> bool: ) ): return False - match cfg["ar_method"]: case "possibility_reply": trig = random.random() < cfg["ar_possibility"] return trig - return False + # ========================================================================= + # 会话清理 + # ========================================================================= + + async def remove_session(self, event: AstrMessageEvent) -> int: + """清理指定群的全部 LTM 状态。返回被清理的 raw_records 条数。""" + umo = event.unified_msg_origin + async with self._get_lock(umo): + cnt = len(self.raw_records.get(umo, deque())) + self.raw_records.pop(umo, None) + self.contexts.pop(umo, None) + self._raw_cursor.pop(umo, None) + self._persisted_tool_call_ids.pop(umo, None) + self._persisted_tool_result_ids.pop(umo, None) + self._summary_next_retry.pop(umo, None) + self.summaries.pop(umo, None) + return cnt + + # ========================================================================= + # 消息记录 (on_message 调用) + # ========================================================================= + async def handle_message(self, event: AstrMessageEvent) -> None: - """仅支持群聊""" - if event.get_message_type() == MessageType.GROUP_MESSAGE: - datetime_str = datetime.datetime.now().strftime("%H:%M:%S") + """仅记录原始消息到 raw_records,不构建 contexts。""" + if event.get_message_type() != MessageType.GROUP_MESSAGE: + return - parts = [f"[{event.message_obj.sender.nickname}/{datetime_str}]: "] + umo = event.unified_msg_origin + async with self._get_lock(umo): + # 记录写入前索引 → on_req_llm 精确排除 + raw_idx = len(self.raw_records[umo]) + event.set_extra("_ltm_raw_idx", raw_idx) + datetime_str = datetime.datetime.now().strftime("%H:%M:%S") + parts = [f"[{event.message_obj.sender.nickname}/{datetime_str}]: "] cfg = self.cfg(event) for comp in event.get_messages(): @@ -143,46 +240,446 @@ async def handle_message(self, event: AstrMessageEvent) -> None: parts.append(f" [At: {comp.name}]") final_message = "".join(parts) - logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}") - self.session_chats[event.unified_msg_origin].append(final_message) - if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]: - self.session_chats[event.unified_msg_origin].pop(0) + logger.debug(f"ltm | {umo} | {final_message}") + self.raw_records[umo].append(final_message) + + # ========================================================================= + # LLM 请求前(on_llm_request 钩子 → decorate_llm_req 调用) + # ========================================================================= async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest) -> None: - """当触发 LLM 请求前,调用此方法修改 req""" - if event.unified_msg_origin not in self.session_chats: + """增量构建 contexts 并注入到 req。ContextManager 由 agent runner 自动调用。""" + umo = event.unified_msg_origin + prompt_idx = event.get_extra("_ltm_raw_idx", -1) + if prompt_idx < 0: return - chats_str = "\n---\n".join(self.session_chats[event.unified_msg_origin]) + async with self._get_lock(umo): + if umo not in self.raw_records: + return + + raw_list = list(self.raw_records[umo]) + cursor = self._raw_cursor[umo] + new_raw = raw_list[cursor:prompt_idx] if prompt_idx > cursor else [] + + if new_raw: + new_segs = _build_segments(new_raw) + self.contexts[umo].extend(new_segs) + self._raw_cursor[umo] = prompt_idx + + # 前置保留 Persona 已注入的 begin_dialogs + existing_contexts = req.contexts or [] + ctxs: list[dict] = list(existing_contexts) + + # Inject LTM summary if available (LLM summary compaction strategy). + summary = self.summaries.get(umo, "") + if summary: + ctxs.append( + { + "role": "system", + "content": ("Long-term group memory summary:\n" + summary), + } + ) + + ctxs.extend(self.contexts[umo]) + req.contexts = ctxs + req.conversation = None + req.system_prompt += CHATROOM_SYSTEM_NOTE + + # ========================================================================= + # Agent 完成后(on_agent_done 钩子 → main.py 调用) + # ========================================================================= + async def on_agent_done( + self, + event: AstrMessageEvent, + run_context, # ContextWrapper + resp: LLMResponse, + ) -> None: + """记录工具链 + bot 回复到 raw_records,闭合段,裁剪。""" + umo = event.unified_msg_origin cfg = self.cfg(event) - if cfg["enable_active_reply"]: - prompt = req.prompt - req.prompt = ( - f"You are now in a chatroom. The chat history is as follows:\n{chats_str}" - f"\nNow, a new message is coming: `{prompt}`. " - "Please react to it. Only output your response and do not output any other information. " - "You MUST use the SAME language as the chatroom is using." - ) - req.contexts = [] # 清空上下文,当使用了主动回复,所有聊天记录都在一个prompt中。 - else: - req.system_prompt += ( - "You are now in a chatroom. The chat history is as follows: \n" + + async with self._get_lock(umo): + if umo not in self.raw_records: + return + + time_str = datetime.datetime.now().strftime("%H:%M:%S") + + # 1. 提取工具链 → raw_records(按 tool_call_id 去重,避免历史重复注入) + for msg in run_context.messages: + if msg.role == "assistant" and msg.tool_calls: + for tc in msg.tool_calls: + tc_dict = tc if isinstance(tc, dict) else tc.model_dump() + tc_id = tc_dict["id"] + if tc_id in self._persisted_tool_call_ids[umo]: + continue + self._persisted_tool_call_ids[umo].add(tc_id) + call_entry = { + "id": tc_id, + "name": tc_dict["function"]["name"], + "args": ( + json.loads(tc_dict["function"]["arguments"]) + if isinstance(tc_dict["function"]["arguments"], str) + else tc_dict["function"]["arguments"] + ), + } + self.raw_records[umo].append( + f"{json.dumps(call_entry, ensure_ascii=False)}" + ) + elif msg.role == "tool": + if msg.tool_call_id in self._persisted_tool_result_ids[umo]: + continue + self._persisted_tool_result_ids[umo].add(msg.tool_call_id) + content = ( + msg.content + if isinstance(msg.content, str) + else str(msg.content) + ) + if cfg["history_tool_result_truncate"]: + content = _truncate_tool_result_for_history( + content, cfg["history_tool_result_max_chars"] + ) + self.raw_records[umo].append( + f"{content}" + ) + + # 最终文本回复 + if resp and resp.completion_text: + self.raw_records[umo].append( + f": {resp.completion_text}" + ) + + # 2. 构建本轮全部未消费 raw 为 contexts 段(含 @bot prompt) + raw_list = list(self.raw_records[umo]) + cursor = self._raw_cursor[umo] + remaining = raw_list[cursor:] # 从 prompt_idx 开始,含 @bot 行 + if remaining: + new_segs = _build_segments(remaining) + self.contexts[umo].extend(new_segs) + self._raw_cursor[umo] = len(raw_list) + + # 2b. LTM persistent compaction — either turn-based truncation OR + # LLM summary, mutually exclusive. Both use high-water / low-water + # to avoid per-round churn. + strategy = cfg.get("ltm_compaction_strategy", "truncate") + rounds = _split_into_rounds(self.contexts[umo]) + + if strategy == "llm_summary": + provider_id = cfg.get("ltm_summary_provider_id", "") + trigger = cfg.get("ltm_summary_trigger_rounds", 80) + keep_recent = cfg.get("ltm_summary_keep_recent_rounds", 30) + if len(rounds) > trigger: + # Resolve provider: explicit ID first, fallback to + # current chat model for this group (umo). + if provider_id: + provider = self.context.get_provider_by_id(provider_id) + else: + provider = self.context.get_using_provider(umo) + if provider is None or not isinstance(provider, Provider): + logger.warning( + "LTM summary 没有可用的 provider (umo=%s, configured=%s)", + umo, + provider_id or "(auto)", + ) + else: + next_retry = self._summary_next_retry.get(umo, 0) + if len(rounds) < next_retry: + logger.debug( + "LTM summary 冷却中 (umo=%s, rounds=%d, 允许=%d)", + umo, + len(rounds), + next_retry, + ) + else: + await self._compact_with_llm_summary( + event, + provider, + keep_recent, + cfg.get("ltm_summary_prompt", ""), + rounds, + ) + else: + max_rounds = cfg.get("ltm_max_rounds", 80) + drop_rounds = cfg.get("ltm_truncate_drop_rounds", 50) + if len(rounds) > max_rounds: + safe_drop = min(drop_rounds, len(rounds) - 1) + kept = rounds[safe_drop:] + self.contexts[umo] = [seg for rnd in kept for seg in rnd] + + # 3. 裁剪 raw_records + self._trim_raw_records( + umo, max_bytes=cfg.get("ltm_raw_records_max_bytes", MAX_RAW_BYTES) ) - req.system_prompt += chats_str - async def after_req_llm( - self, event: AstrMessageEvent, llm_resp: LLMResponse + # ========================================================================= + # LTM compaction + # ========================================================================= + + async def _compact_with_llm_summary( + self, + event: AstrMessageEvent, + provider: Provider, + keep_recent: int, + prompt: str, + rounds: list[list[dict]], ) -> None: - if event.unified_msg_origin not in self.session_chats: + """Compress old rounds into a persistent summary using an LLM.""" + umo = event.unified_msg_origin + old_rounds = rounds[:-keep_recent] + recent_rounds = rounds[-keep_recent:] + if not old_rounds: return - if llm_resp.completion_text: - final_message = f"[You/{datetime.datetime.now().strftime('%H:%M:%S')}]: {llm_resp.completion_text}" - logger.debug( - f"Recorded AI response: {event.unified_msg_origin} | {final_message}" + old_text = _rounds_to_text(old_rounds) + existing_summary = self.summaries.get(umo, "") + instruction = prompt or ( + "Merge the older conversation rounds below into the existing " + "group-chat memory summary. Preserve stable facts about users, " + "preferences, decisions, recurring topics, and unresolved tasks. " + "Drop transient chatter, greetings, and irrelevant details. " + "Output only the updated summary, no preamble." + ) + + summary_prompt = ( + f"{instruction}\n\n" + f"Existing memory summary:\n{existing_summary or '(none)'}\n\n" + f"Older conversation rounds to merge:\n{old_text}" + ) + + try: + resp = await provider.text_chat( + prompt=summary_prompt, + session_id=uuid.uuid4().hex, + persist=False, ) - self.session_chats[event.unified_msg_origin].append(final_message) - cfg = self.cfg(event) - if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]: - self.session_chats[event.unified_msg_origin].pop(0) + summary_text = resp.completion_text.strip() + if not summary_text: + logger.warning( + "LTM LLM summary 返回空文本,跳过本次压缩 (umo=%s, provider=%s)", + umo, + provider, + ) + self._summary_next_retry[umo] = len(rounds) + SUMMARY_RETRY_COOLDOWN + return + self.summaries[umo] = summary_text + self.contexts[umo] = [seg for rnd in recent_rounds for seg in rnd] + # 成功后清除冷却,下次按正常 trigger 走 + self._summary_next_retry.pop(umo, None) + except Exception: + logger.warning("LTM LLM summary 失败,保留原始 contexts", exc_info=True) + self._summary_next_retry[umo] = len(rounds) + SUMMARY_RETRY_COOLDOWN + + # ========================================================================= + # 裁剪 + # ========================================================================= + + def _trim_raw_records(self, umo: str, max_bytes: int = MAX_RAW_BYTES) -> None: + """仅淘汰 cursor 之前的条目。cursor 之后的绝不碰。""" + dq = self.raw_records[umo] + cursor = self._raw_cursor[umo] + + # 1. 无条件清除 cursor 之前的条目(已消费) + while dq and cursor > 0: + dq.popleft() + cursor -= 1 + self._raw_cursor[umo] = cursor + + # 2. 按大小继续从前面淘汰(限制极端情况的总内存) + total = sum(len(s.encode()) for s in dq) + while total > max_bytes and dq: + removed = dq.popleft() + total -= len(removed.encode()) + if cursor > 0: + cursor -= 1 + self._raw_cursor[umo] = max(0, cursor) + + +# ============================================================================= +# _build_segments — 从 raw lines 构建 OpenAI 格式 contexts 段 +# ============================================================================= + + +def _build_segments(raw_lines: list[str]) -> list[dict]: + """从 raw strings 构建 OpenAI 格式 contexts 段。 + + 规则: + 1. json → 连续多条合并为一个 assistant(tool_calls) + 2. content → tool 消息,tool_call_id 配对 + 3. : content → assistant(纯文本) + 4. 其它行 → user(合并为段,段内裁剪 MAX_MSGS/MAX_CHARS) + """ + if not raw_lines: + return [] + + segments: list[dict] = [] + user_buf: list[str] = [] + tool_calls_buf: list[dict] = [] + + def flush_user(): + if not user_buf: + return + truncated = _truncate_user_segment(user_buf) + segments.append({"role": "user", "content": "\n".join(truncated)}) + user_buf.clear() + + def flush_tool_calls(): + if not tool_calls_buf: + return + segments.append( + { + "role": "assistant", + "content": None, + "tool_calls": tool_calls_buf.copy(), + } + ) + tool_calls_buf.clear() + + for line in raw_lines: + if line.startswith(TOOL_CALL_PREFIX): + flush_user() + tc_data = _parse_tool_call(line) + if tc_data: + tool_calls_buf.append(tc_data) + elif line.startswith(TOOL_RES_PREFIX): + flush_user() + flush_tool_calls() + tool_msg = _parse_tool_result(line) + if tool_msg: + segments.append(tool_msg) + elif line.startswith(BOT_MARKER): + flush_user() + flush_tool_calls() + content = _extract_bot_content(line) + if content: + segments.append({"role": "assistant", "content": content}) + else: + user_buf.append(line) + + flush_user() + flush_tool_calls() + return segments + + +# ============================================================================= +# 解析 helper +# ============================================================================= + + +def _parse_tool_call(line: str) -> dict | None: + """{"id":"x","name":"f","args":{...}} → tool_call dict""" + inner = _extract_tag_content(line, TOOL_CALL_PREFIX, "") + if not inner: + return None + try: + tc = json.loads(inner) + tc_id = tc["id"] + tc_name = tc["name"] + tc_args = tc["args"] + except (json.JSONDecodeError, TypeError, KeyError): + return None + return { + "id": tc_id, + "type": "function", + "function": { + "name": tc_name, + "arguments": json.dumps(tc_args, ensure_ascii=False), + }, + } + + +def _parse_tool_result(line: str) -> dict | None: + """content → {"role":"tool", ...}""" + rest = line[len(TOOL_RES_PREFIX) :].strip() + gt = rest.find(">") + if gt == -1: + return None + id_part = rest[:gt] + content = rest[gt + 1 :] + if content.endswith(""): + content = content[: -len("")] + if not id_part.startswith("id="): + return None + tc_id = id_part[3:] + return {"role": "tool", "tool_call_id": tc_id, "content": content} + + +def _truncate_tool_result_for_history(content: str, max_chars: int) -> str: + """Truncate a single tool result before persisting into LTM history.""" + if max_chars <= 0 or len(content) <= max_chars: + return content + + omitted = len(content) - max_chars + marker = f"\n...[TRUNCATED {omitted} chars]..." + if len(marker) >= max_chars: + return content[:max_chars] + + head_len = max_chars - len(marker) + return content[:head_len] + marker + + +def _extract_bot_content(line: str) -> str | None: + """: content → content""" + idx = line.find(">: ") + if idx == -1: + return None + return line[idx + 3 :].strip() + + +def _extract_tag_content(line: str, start_tag: str, end_tag: str) -> str | None: + """content → content""" + if not line.endswith(end_tag): + return None + return line[len(start_tag) : -len(end_tag)].strip() + + +def _truncate_user_segment(lines: list[str]) -> list[str]: + """段内裁剪:保留最近 N 条,不超字符上限。从段内最早的消息开始丢弃。""" + result: list[str] = [] + total = 0 + for line in reversed(lines): + if len(result) >= MAX_MSGS_PER_USER_SEGMENT: + break + if total + len(line) > MAX_CHARS_PER_USER_SEGMENT and result: + break + result.append(line) + total += len(line) + 1 # +1 for \n + result.reverse() + return result + + +# ============================================================================= +# _split_into_rounds — LTM compaction helper +# ============================================================================= + + +def _split_into_rounds(contexts: list[dict[str, Any]]) -> list[list[dict[str, Any]]]: + """Split a flat contexts list into logical rounds. + + A round begins at a ``user`` segment and includes all subsequent + ``assistant`` / ``tool`` segments until the next ``user`` segment. + """ + rounds: list[list[dict[str, Any]]] = [] + current: list[dict[str, Any]] = [] + for seg in contexts: + if seg.get("role") == "user" and current: + rounds.append(current) + current = [] + current.append(seg) + if current: + rounds.append(current) + return rounds + + +def _rounds_to_text(rounds: list[list[dict[str, Any]]]) -> str: + """Render rounds into a plain-text string for LLM summarisation.""" + lines: list[str] = [] + for i, rnd in enumerate(rounds, 1): + lines.append(f"--- Round {i} ---") + for seg in rnd: + role = seg.get("role", "?") + content = seg.get("content") or seg.get("tool_calls") or "" + if isinstance(content, list): + content = json.dumps(content, ensure_ascii=False) + lines.append(f"[{role}] {content}") + return "\n".join(lines) diff --git a/astrbot/builtin_stars/astrbot/main.py b/astrbot/builtin_stars/astrbot/main.py index 828e49552f..95ff769f11 100644 --- a/astrbot/builtin_stars/astrbot/main.py +++ b/astrbot/builtin_stars/astrbot/main.py @@ -23,6 +23,7 @@ class Main(star.Star): def __init__(self, context: star.Context) -> None: self.context = context self.ltm = None + self._ltm_was_enabled: dict[str, bool] = {} try: self.ltm = LongTermMemory(self.context.astrbot_config_mgr, self.context) except BaseException as e: @@ -161,26 +162,29 @@ async def on_message(self, event: AstrMessageEvent): return try: conv = None - session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id( - event.unified_msg_origin, - ) - if not session_curr_cid: - logger.error( - "当前未处于对话状态,无法主动回复,请确保 平台设置->会话隔离(unique_session) 未开启,并使用 /switch 序号 切换或者 /new 创建一个会话。", + if not group_icl_enable: + # 仅在走 Conversation 模式时才需要查询会话 + session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id( + event.unified_msg_origin, ) - return - conv = await self.context.conversation_manager.get_conversation( - event.unified_msg_origin, - session_curr_cid, - ) + if not session_curr_cid: + logger.error( + "当前未处于对话状态,无法主动回复,请确保 平台设置->会话隔离(unique_session) 未开启,并使用 /switch 序号 切换或者 /new 创建一个会话。", + ) + return - prompt = event.message_str + conv = await self.context.conversation_manager.get_conversation( + event.unified_msg_origin, + session_curr_cid, + ) + + if not conv: + logger.error("未找到对话,无法主动回复") + return - if not conv: - logger.error("未找到对话,无法主动回复") - return + prompt = event.message_str yield event.request_llm( prompt=prompt, @@ -197,19 +201,29 @@ async def decorate_llm_req( ) -> None: """在请求 LLM 前注入人格信息、Identifier、时间、回复内容等 System Prompt""" if self.ltm and self.ltm_enabled(event): + umo = event.unified_msg_origin + + # 惰性切换检测:false → true 时清理残留旧数据 + now_enabled = self.ltm_enabled(event) + was_enabled = self._ltm_was_enabled.get(umo, False) + if now_enabled and not was_enabled: + await self.ltm.remove_session(event) + logger.info(f"LTM: group_icl_enable 开启,已重置 {umo} 上下文") + self._ltm_was_enabled[umo] = now_enabled + try: await self.ltm.on_req_llm(event, req) except BaseException as e: logger.error(f"ltm: {e}") - @filter.on_llm_response() - async def record_llm_resp_to_ltm( - self, event: AstrMessageEvent, resp: LLMResponse + @filter.on_agent_done() + async def record_agent_result_to_ltm( + self, event: AstrMessageEvent, run_context, resp: LLMResponse ) -> None: - """在 LLM 响应后记录对话""" + """Agent 完成后记录对话(含工具链)""" if self.ltm and self.ltm_enabled(event): try: - await self.ltm.after_req_llm(event, resp) + await self.ltm.on_agent_done(event, run_context, resp) except Exception as e: logger.error(f"ltm: {e}") diff --git a/astrbot/core/agent/context/guard.py b/astrbot/core/agent/context/guard.py new file mode 100644 index 0000000000..eedff5adc8 --- /dev/null +++ b/astrbot/core/agent/context/guard.py @@ -0,0 +1,28 @@ +from ..message import Message +from .config import ContextConfig +from .manager import ContextManager + + +class RequestContextGuard: + """Request-time context guard before sending messages to a provider. + + This guard is intentionally scoped to a single provider request. It may + truncate or compress the in-flight messages to keep the current request + within model/provider limits, but it does not own persistent history and + should not be treated as the memory-layer compactor. + """ + + def __init__(self, config: ContextConfig) -> None: + self.config = config + self._manager = ContextManager(config) + + async def process( + self, + messages: list[Message], + trusted_token_usage: int = 0, + ) -> list[Message]: + """Apply request-time context guarding to messages.""" + return await self._manager.process( + messages, + trusted_token_usage=trusted_token_usage, + ) diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 968426b8b4..f421a05c34 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -50,7 +50,7 @@ from ..context.compressor import ContextCompressor from ..context.config import ContextConfig -from ..context.manager import ContextManager +from ..context.guard import RequestContextGuard from ..context.token_counter import EstimateTokenCounter, TokenCounter from ..hooks import BaseAgentRunHooks from ..message import ( @@ -241,13 +241,10 @@ async def reset( self.tool_result_overflow_dir = tool_result_overflow_dir self.read_tool = read_tool self._tool_result_token_counter = EstimateTokenCounter() - # we will do compress when: - # 1. before requesting LLM - # TODO: 2. after LLM output a tool call - self.context_config = ContextConfig( - # <=0 will never do compress + self.request_context_guard_config = ContextConfig( + # <=0 disables token-based guarding. max_context_tokens=provider.provider_config.get("max_context_tokens", 0), - # enforce max turns before compression + # Enforce max turns before token-based guarding. enforce_max_turns=self.enforce_max_turns, truncate_turns=self.truncate_turns, llm_compress_instruction=self.llm_compress_instruction, @@ -256,7 +253,9 @@ async def reset( custom_token_counter=self.custom_token_counter, custom_compressor=self.custom_compressor, ) - self.context_manager = ContextManager(self.context_config) + self.request_context_guard = RequestContextGuard( + self.request_context_guard_config + ) self.provider = provider self.fallback_providers: list[Provider] = [] @@ -459,8 +458,11 @@ async def _iter_llm_responses( self, *, include_model: bool = True ) -> T.AsyncGenerator[LLMResponse, None]: """Yields chunks *and* a final LLMResponse.""" + messages_for_provider = getattr( + self, "_provider_messages", self.run_context.messages + ) payload = { - "contexts": self._sanitize_contexts_for_provider(self.run_context.messages), + "contexts": self._sanitize_contexts_for_provider(messages_for_provider), "func_tool": self._func_tool_for_provider(), "session_id": self.req.session_id, "extra_user_content_parts": self.req.extra_user_content_parts, # list[ContentPart] @@ -704,10 +706,13 @@ async def step(self): self._transition_state(AgentState.RUNNING) llm_resp_result = None - # do truncate and compress + # Apply request-time context guard *on a copy* so the runner's canonical + # messages are never mutated by the guard. The guard result is only used + # for this provider call. Persistent compaction is owned by the + # conversation / memory layer. token_usage = self.req.conversation.token_usage if self.req.conversation else 0 self._simple_print_message_role("[BefCompact]") - self.run_context.messages = await self.context_manager.process( + self._provider_messages = await self.request_context_guard.process( self.run_context.messages, trusted_token_usage=token_usage ) self._simple_print_message_role("[AftCompact]") @@ -1395,6 +1400,9 @@ async def _iter_tool_executor_results( self, executor: AsyncIterator[ToolExecutorResultT], ) -> T.AsyncGenerator[ToolExecutorResultT, None]: + async def _next_executor_result() -> ToolExecutorResultT: + return await anext(executor) + while True: if self._is_stop_requested(): await self._close_executor(executor) @@ -1402,7 +1410,7 @@ async def _iter_tool_executor_results( "Tool execution interrupted before reading the next tool result." ) - next_result_task = asyncio.create_task(anext(executor)) + next_result_task = asyncio.create_task(_next_executor_result()) abort_task = asyncio.create_task(self._abort_signal.wait()) try: done, _ = await asyncio.wait( diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index fd1a9aeb8c..04cd572f79 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -1119,26 +1119,21 @@ async def _apply_web_search_tools( def _get_compress_provider( - config: MainAgentBuildConfig, plugin_context: Context + config: MainAgentBuildConfig, + plugin_context: Context, + event: AstrMessageEvent | None = None, ) -> Provider | None: - if not config.llm_compress_provider_id: - return None if config.context_limit_reached_strategy != "llm_compress": return None - provider = plugin_context.get_provider_by_id(config.llm_compress_provider_id) - if provider is None: + if config.llm_compress_provider_id: + provider = plugin_context.get_provider_by_id(config.llm_compress_provider_id) + if provider and isinstance(provider, Provider): + return provider logger.warning( - "未找到指定的上下文压缩模型 %s,将跳过压缩。", + "指定的上下文压缩模型 %s 不可用", config.llm_compress_provider_id, ) - return None - if not isinstance(provider, Provider): - logger.warning( - "指定的上下文压缩模型 %s 不是对话模型,将跳过压缩。", - config.llm_compress_provider_id, - ) - return None - return provider + return None def _get_fallback_chat_providers( @@ -1449,9 +1444,8 @@ async def build_main_agent( streaming=config.streaming_response, llm_compress_instruction=config.llm_compress_instruction, llm_compress_keep_recent=config.llm_compress_keep_recent, - llm_compress_provider=_get_compress_provider(config, plugin_context), + llm_compress_provider=_get_compress_provider(config, plugin_context, event), truncate_turns=config.dequeue_context_length, - enforce_max_turns=config.max_context_length, tool_schema_mode=config.tool_schema_mode, fallback_providers=_get_fallback_chat_providers( provider, plugin_context, config.provider_settings diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 6a689784ae..cc9b1d3a0e 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -120,7 +120,7 @@ "default_personality": "default", "persona_pool": ["*"], "prompt_prefix": "{{prompt}}", - "context_limit_reached_strategy": "truncate_by_turns", # or llm_compress + "context_limit_reached_strategy": "llm_compress", # or truncate_by_turns "llm_compress_instruction": ( "Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n" "1. Systematically cover all core topics discussed and the final conclusion/outcome for each; clearly highlight the latest primary focus.\n" @@ -130,8 +130,8 @@ ), "llm_compress_keep_recent": 6, "llm_compress_provider_id": "", - "max_context_length": -1, - "dequeue_context_length": 1, + "max_context_length": 25, + "dequeue_context_length": 10, "streaming_response": False, "show_tool_use_status": False, "show_tool_call_result": False, @@ -217,9 +217,18 @@ }, "provider_ltm_settings": { "group_icl_enable": False, - "group_message_max_cnt": 300, "image_caption": False, "image_caption_provider_id": "", + "history_tool_result_truncate": True, + "history_tool_result_max_chars": 8192, + "ltm_compaction_strategy": "truncate", + "ltm_max_rounds": 80, + "ltm_truncate_drop_rounds": 50, + "ltm_summary_trigger_rounds": 80, + "ltm_summary_keep_recent_rounds": 30, + "ltm_summary_provider_id": "", + "ltm_summary_prompt": "", + "ltm_raw_records_max_bytes": 500000, "active_reply": { "enable": False, "method": "possibility_reply", @@ -2862,9 +2871,6 @@ "group_icl_enable": { "type": "bool", }, - "group_message_max_cnt": { - "type": "int", - }, "image_caption": { "type": "bool", }, @@ -2874,6 +2880,12 @@ "image_caption_prompt": { "type": "string", }, + "history_tool_result_truncate": { + "type": "bool", + }, + "history_tool_result_max_chars": { + "type": "int", + }, "active_reply": { "type": "object", "items": { @@ -3483,30 +3495,30 @@ "type": "object", "items": { "provider_settings.max_context_length": { - "description": "最多携带对话轮数", + "description": "压缩前最多保留对话轮数", "type": "int", - "hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制", + "hint": "普通会话历史超过该轮数后,才会按下方策略进行持久化截断或 LLM 压缩;请求发送前也会先按该值约束上下文。-1 表示不按轮数限制。", "condition": { "provider_settings.agent_runner_type": "local", }, }, "provider_settings.dequeue_context_length": { - "description": "丢弃对话轮数", + "description": "轮次超限时一次丢弃轮数", "type": "int", - "hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数", + "hint": "当超过“压缩前最多保留对话轮数”且无法使用 LLM 压缩时,一次丢弃多少轮旧对话;请求期截断也会复用该值。", "condition": { "provider_settings.agent_runner_type": "local", }, }, "provider_settings.context_limit_reached_strategy": { - "description": "超出模型上下文窗口时的处理方式", + "description": "历史超限或上下文接近上限时的处理方式", "type": "string", "options": ["truncate_by_turns", "llm_compress"], "labels": ["按对话轮数截断", "由 LLM 压缩上下文"], "condition": { "provider_settings.agent_runner_type": "local", }, - "hint": "", + "hint": "普通会话历史仅在超过“压缩前最多保留对话轮数”后执行该策略;请求发送前也会在上下文 token 接近模型窗口时使用同一策略保护本次请求。", }, "provider_settings.llm_compress_instruction": { "description": "上下文压缩提示词", @@ -3530,7 +3542,7 @@ "description": "用于上下文压缩的模型提供商 ID", "type": "string", "_special": "select_provider", - "hint": "留空时将降级为“按对话轮数截断”的策略。", + "hint": "留空时使用当前聊天模型进行压缩;如果模型不可用或压缩失败,将回退为“按对话轮数截断”的策略。", "condition": { "provider_settings.context_limit_reached_strategy": "llm_compress", "provider_settings.agent_runner_type": "local", @@ -4078,10 +4090,6 @@ "description": "启用群聊上下文感知", "type": "bool", }, - "provider_ltm_settings.group_message_max_cnt": { - "description": "最大消息数量", - "type": "int", - }, "provider_ltm_settings.image_caption": { "description": "自动理解图片", "type": "bool", @@ -4096,6 +4104,79 @@ "provider_ltm_settings.image_caption": True, }, }, + "provider_ltm_settings.history_tool_result_truncate": { + "description": "截断历史工具输出", + "type": "bool", + "hint": "仅影响群聊 LTM 历史轮,不影响当前工具调用轮的完整推理。", + }, + "provider_ltm_settings.history_tool_result_max_chars": { + "description": "历史工具输出截断上限", + "type": "int", + "hint": "单条工具输出写入群聊历史时的最大字符数,默认 8192。", + "condition": { + "provider_ltm_settings.history_tool_result_truncate": True, + }, + }, + "provider_ltm_settings.ltm_compaction_strategy": { + "description": "LTM 上下文压缩策略", + "type": "string", + "options": ["truncate", "llm_summary"], + "hint": "truncate: 按轮截断; llm_summary: 调用 LLM 做长期摘要。", + }, + "provider_ltm_settings.ltm_max_rounds": { + "description": "LTM 最大保留轮数", + "type": "int", + "hint": "truncate 策略生效时的截断上限,默认 80。", + "condition": { + "provider_ltm_settings.ltm_compaction_strategy": "truncate", + }, + }, + "provider_ltm_settings.ltm_truncate_drop_rounds": { + "description": "截断时丢弃轮数", + "type": "int", + "hint": "truncate 策略触发截断时,从前面丢弃多少轮。默认 50。", + "condition": { + "provider_ltm_settings.ltm_compaction_strategy": "truncate", + }, + }, + "provider_ltm_settings.ltm_summary_trigger_rounds": { + "description": "摘要触发轮数", + "type": "int", + "hint": "超过多少轮时触发 LLM 摘要压缩,默认 80。", + "condition": { + "provider_ltm_settings.ltm_compaction_strategy": "llm_summary", + }, + }, + "provider_ltm_settings.ltm_summary_keep_recent_rounds": { + "description": "摘要时保留最近轮数", + "type": "int", + "hint": "llm_summary 策略下保留最近 N 轮精确上下文,默认 30。", + "condition": { + "provider_ltm_settings.ltm_compaction_strategy": "llm_summary", + }, + }, + "provider_ltm_settings.ltm_summary_provider_id": { + "description": "LTM 摘要模型", + "type": "string", + "_special": "select_provider", + "hint": "llm_summary 策略使用的模型,留空使用当前聊天模型。", + "condition": { + "provider_ltm_settings.ltm_compaction_strategy": "llm_summary", + }, + }, + "provider_ltm_settings.ltm_summary_prompt": { + "description": "LTM 摘要提示词", + "type": "string", + "hint": "llm_summary 策略的自定义摘要 prompt,留空使用内置默认。", + "condition": { + "provider_ltm_settings.ltm_compaction_strategy": "llm_summary", + }, + }, + "provider_ltm_settings.ltm_raw_records_max_bytes": { + "description": "Raw Records 最大内存字节", + "type": "int", + "hint": "每个群聊允许 raw_records 占用的最大字节数,默认 500000 (500KB)。", + }, "provider_ltm_settings.active_reply.enable": { "description": "主动回复", "type": "bool", diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index fee641c192..1a707689cb 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -49,6 +49,38 @@ ) +def _count_conversation_turns(messages: list[Message]) -> int: + """Count persisted conversation turns by user messages. + + A turn starts with a user message and may include assistant tool calls, + tool results, and the final assistant answer. Counting user messages avoids + treating tool call/result pairs as additional conversation turns. + """ + return sum(1 for message in messages if message.role == "user") + + +def _history_exceeds_turn_limit(messages: list[Message], max_turns: int) -> bool: + """Return whether persisted history exceeds the configured turn limit.""" + if max_turns == -1: + return False + if max_turns <= 0: + return False + return _count_conversation_turns(messages) > max_turns + + +def _has_valid_summary_message(messages: list[Message]) -> bool: + """Return whether LLM compression produced a non-empty summary block.""" + summary_prefix = "Our previous history conversation summary:" + for message in messages: + if message.role != "user" or not isinstance(message.content, str): + continue + if not message.content.startswith(summary_prefix): + continue + summary_text = message.content.removeprefix(summary_prefix).strip() + return bool(summary_text) + return False + + class InternalAgentSubStage(Stage): async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx @@ -463,6 +495,74 @@ async def _save_to_history( continue messages_to_save.append(message) + # Persistent conversation compaction — either turn-based truncation OR + # LLM summary, mutually exclusive. Only compact persisted history when + # the configured turn limit is exceeded; request-time token guarding is + # handled separately by the agent runner. + if _history_exceeds_turn_limit(messages_to_save, self.max_context_length): + from astrbot.core.agent.context.truncator import ContextTruncator + + def fallback_truncate() -> list[Message]: + truncator = ContextTruncator() + return truncator.truncate_by_turns( + messages_to_save, + keep_most_recent_turns=self.max_context_length, + drop_turns=self.dequeue_context_length, + ) + + compress_provider = None + if self.context_limit_reached_strategy == "llm_compress": + from astrbot.api.provider import Provider as ApiProvider + + provider_source = ( + self.llm_compress_provider_id or "(current chat model)" + ) + if self.llm_compress_provider_id: + raw_provider = self.ctx.plugin_manager.context.get_provider_by_id( + self.llm_compress_provider_id + ) + else: + raw_provider = self.ctx.plugin_manager.context.get_using_provider( + umo=event.unified_msg_origin + ) + + if raw_provider is not None and isinstance(raw_provider, ApiProvider): + compress_provider = raw_provider + if not self.llm_compress_provider_id: + logger.info("llm_compress 使用当前聊天模型进行持久化历史压缩") + else: + logger.warning( + "上下文压缩模型 %s 不可用,将回退为按对话轮数截断", + provider_source, + ) + + if compress_provider is not None: + # LLM summary strategy: compress old turns into a summary. + from astrbot.core.agent.context.compressor import ( + LLMSummaryCompressor, + ) + + original_messages = messages_to_save + compressor = LLMSummaryCompressor( + provider=compress_provider, + keep_recent=self.llm_compress_keep_recent, + instruction_text=self.llm_compress_instruction, + ) + compressed_messages = await compressor(original_messages) + if ( + compressed_messages is original_messages + or not _has_valid_summary_message(compressed_messages) + ): + logger.warning( + "LLM 上下文压缩未产生有效摘要,将回退为按对话轮数截断" + ) + messages_to_save = fallback_truncate() + else: + messages_to_save = compressed_messages + else: + # Fallback: turn-based truncation only. + messages_to_save = fallback_truncate() + checkpoint_id = event.get_extra("llm_checkpoint_id") message_to_save = dump_messages_with_checkpoints(messages_to_save) if isinstance(checkpoint_id, str) and checkpoint_id: diff --git a/dashboard/src/i18n/locales/en-US/features/config-metadata.json b/dashboard/src/i18n/locales/en-US/features/config-metadata.json index 534358439b..f0970dc5d2 100644 --- a/dashboard/src/i18n/locales/en-US/features/config-metadata.json +++ b/dashboard/src/i18n/locales/en-US/features/config-metadata.json @@ -247,19 +247,19 @@ "provider_settings": { "max_context_length": { "description": "Max Turns Before Compression", - "hint": "Limits history turns before any compression strategy is applied; -1 means no turn-based limit" + "hint": "Persistent conversation history is truncated or LLM-compressed by the strategy below only after it exceeds this many turns. Request-time contexts are also constrained by this value before sending. -1 means no turn-based limit." }, "dequeue_context_length": { "description": "Turns to Discard When Limit Exceeded", - "hint": "Number of old conversation turns to discard at once when the turn limit is exceeded; also used as fallback when compression is unavailable" + "hint": "When history exceeds 'Max Turns Before Compression' and LLM compression is unavailable, discard this many oldest turns at once. Request-time truncation also reuses this value." }, "context_limit_reached_strategy": { - "description": "Handling When Context Approaches Model Limit", + "description": "Handling for History Limits or Context Window Pressure", "labels": [ "Truncate by Turns", "Compress by LLM" ], - "hint": "This strategy only triggers after turn-based limiting, when context tokens approach the model's window limit. When 'Truncate by Turns' is selected, the oldest N conversation turns will be discarded based on the 'Turns to Discard When Limit Exceeded' setting above. When 'Compress by LLM' is selected, the specified model will be used for context compression." + "hint": "Persistent conversation history uses this strategy only after exceeding 'Max Turns Before Compression'. Before each request, the same strategy may also protect the in-flight context when tokens approach the model window." }, "llm_compress_instruction": { "description": "Context Compression Instruction", @@ -271,7 +271,7 @@ }, "llm_compress_provider_id": { "description": "Model Provider ID for Context Compression", - "hint": "When left empty, will fall back to the 'Truncate by Turns' strategy." + "hint": "When left empty, the current chat model will be used for compression. If the model is unavailable or compression fails, AstrBot falls back to the 'Truncate by Turns' strategy." }, "fallback_max_context_tokens": { "description": "Fallback context window size", @@ -989,9 +989,6 @@ "group_icl_enable": { "description": "Enable Group Chat Context Awareness" }, - "group_message_max_cnt": { - "description": "Maximum Message Count" - }, "image_caption": { "description": "Auto-understand Images", "hint": "Requires setting a group chat image caption model." @@ -1000,6 +997,46 @@ "description": "Group Chat Image Caption Model", "hint": "Used for image understanding in group chat context awareness, configured separately from the default image caption model." }, + "history_tool_result_truncate": { + "description": "Truncate Historical Tool Output", + "hint": "Only affects group chat LTM history, not the current tool call round's full reasoning." + }, + "history_tool_result_max_chars": { + "description": "Historical Tool Output Truncation Limit", + "hint": "Maximum characters per tool output written to group chat history, default 8192." + }, + "ltm_compaction_strategy": { + "description": "LTM Context Compaction Strategy", + "hint": "truncate: keep only recent rounds; llm_summary: use LLM to produce a long-term memory summary." + }, + "ltm_max_rounds": { + "description": "LTM Max Retained Rounds", + "hint": "When truncate strategy is active, drop rounds beyond this limit. Default 80." + }, + "ltm_truncate_drop_rounds": { + "description": "Truncate Drop Rounds", + "hint": "When truncate strategy triggers, drop this many rounds from the front. Default 50." + }, + "ltm_summary_trigger_rounds": { + "description": "Summary Trigger Rounds", + "hint": "Trigger LLM summary compaction when rounds exceed this. Default 80." + }, + "ltm_summary_keep_recent_rounds": { + "description": "Recent Rounds to Keep (Summary)", + "hint": "Under llm_summary strategy, keep this many recent rounds precise. Default 30." + }, + "ltm_summary_provider_id": { + "description": "LTM Summary Model", + "hint": "Model used for llm_summary compaction. Leave empty to use the current chat model." + }, + "ltm_summary_prompt": { + "description": "LTM Summary Prompt", + "hint": "Custom summary prompt for llm_summary strategy. Leave empty for built-in default." + }, + "ltm_raw_records_max_bytes": { + "description": "Raw Message Buffer Memory Limit", + "hint": "Maximum bytes for the unprocessed message buffer per group. Prevents memory overflow in groups where the bot hasn't been @-mentioned for a long time. Default 500000 (500KB)." + }, "active_reply": { "enable": { "description": "Active Reply" diff --git a/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json b/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json index d196cb37a3..761af076bf 100644 --- a/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json +++ b/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json @@ -246,20 +246,20 @@ "description": "Стратегия управления контекстом", "provider_settings": { "max_context_length": { - "description": "Макс. количество раундов диалога", - "hint": "При превышении удаляются старые сообщения. 1 раунд = 1 пара запрос-ответ. -1 означает без ограничений." + "description": "Макс. раундов перед сжатием", + "hint": "Постоянная история диалога обрезается или сжимается LLM по стратегии ниже только после превышения этого числа раундов. Контекст перед запросом также ограничивается этим значением. -1 означает без ограничений по раундам." }, "dequeue_context_length": { - "description": "Кол-во удаляемых раундов", - "hint": "Сколько раундов удалять за один раз при достижении лимита." + "description": "Раундов для удаления при превышении лимита", + "hint": "Когда история превышает лимит раундов и LLM-сжатие недоступно, за один раз удаляется это число самых старых раундов. Обрезка перед запросом также использует это значение." }, "context_limit_reached_strategy": { - "description": "Действие при переполнении окна контекста", + "description": "Действие при лимите истории или давлении окна контекста", "labels": [ "Обрезать по раундам", "Сжать с помощью LLM" ], - "hint": "При выборе 'Обрезать' удаляются старые сообщения. При выборе 'Сжать' используется модель для суммаризации контекста." + "hint": "Постоянная история диалога использует эту стратегию только после превышения лимита раундов. Перед каждым запросом та же стратегия может защищать текущий контекст, когда токены приближаются к окну модели." }, "llm_compress_instruction": { "description": "Инструкция для сжатия контекста", @@ -271,7 +271,7 @@ }, "llm_compress_provider_id": { "description": "Модель для сжатия контекста", - "hint": "Если не выбрано, произойдет откат к стратегии удаления сообщений." + "hint": "Если не выбрано, для сжатия используется текущая модель чата. Если модель недоступна или сжатие завершается ошибкой, AstrBot откатывается к обрезке по раундам." }, "fallback_max_context_tokens": { "description": "Запасной размер окна контекста", @@ -990,9 +990,6 @@ "group_icl_enable": { "description": "Включить осведомленность о контексте группы" }, - "group_message_max_cnt": { - "description": "Максимальное количество сообщений" - }, "image_caption": { "description": "Автоматическое понимание изображений", "hint": "Требуется настройка модели описания изображений для группового чата." @@ -1001,6 +998,46 @@ "description": "Модель описания изображений для групп", "hint": "Используется для понимания изображений в контексте группового чата, настраивается отдельно от основной модели." }, + "history_tool_result_truncate": { + "description": "Обрезать вывод истории инструментов", + "hint": "Влияет только на историю LTM в группах, не затрагивая полный вывод текущего раунда." + }, + "history_tool_result_max_chars": { + "description": "Лимит обрезки истории инструментов", + "hint": "Макс. символов для вывода одного инструмента в истории группы, по умолчанию 8192." + }, + "ltm_compaction_strategy": { + "description": "Стратегия сжатия LTM", + "hint": "truncate: только последние раунды; llm_summary: LLM резюмирует старые раунды." + }, + "ltm_max_rounds": { + "description": "Макс. раундов LTM", + "hint": "При стратегии truncate — лимит раундов, по умолчанию 80." + }, + "ltm_truncate_drop_rounds": { + "description": "Сброс раундов при truncate", + "hint": "При срабатывании truncate — сколько раундов отбросить спереди. По умолчанию 50." + }, + "ltm_summary_trigger_rounds": { + "description": "Порог для LLM резюме", + "hint": "При превышении этого количества раундов запускается LLM резюме. По умолчанию 80." + }, + "ltm_summary_keep_recent_rounds": { + "description": "Сохранять последних раундов (LLM summary)", + "hint": "При llm_summary — сколько последних раундов хранить точно, по умолчанию 30." + }, + "ltm_summary_provider_id": { + "description": "Модель для LTM резюме", + "hint": "Модель для llm_summary. Оставьте пустым, чтобы использовать текущую модель чата." + }, + "ltm_summary_prompt": { + "description": "Промпт для LTM резюме", + "hint": "Свой промпт для llm_summary. Оставьте пустым для встроенного по умолчанию." + }, + "ltm_raw_records_max_bytes": { + "description": "Лимит буфера сообщений", + "hint": "Максимальный размер необработанного буфера сообщений на группу. Предотвращает переполнение памяти в группах, где бот долго не упоминался. По умолчанию 500000 (500KB)." + }, "active_reply": { "enable": { "description": "Активный ответ" diff --git a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json index c8d9d572af..42e6f30e33 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json +++ b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json @@ -249,19 +249,19 @@ "provider_settings": { "max_context_length": { "description": "压缩前最多保留对话轮数", - "hint": "无论选择截断还是 LLM 压缩,都会先按该值限制历史轮数;-1 表示不按轮数限制" + "hint": "普通会话历史超过该轮数后,才会按下方策略进行持久化截断或 LLM 压缩;请求发送前也会先按该值约束上下文。-1 表示不按轮数限制。" }, "dequeue_context_length": { "description": "轮次超限时一次丢弃轮数", - "hint": "当超过\"压缩前最多保留对话轮数\"时,一次丢弃多少轮旧对话;同时也可能作为压缩不可用时的回退截断参数" + "hint": "当超过\"压缩前最多保留对话轮数\"且无法使用 LLM 压缩时,一次丢弃多少轮旧对话;请求期截断也会复用该值。" }, "context_limit_reached_strategy": { - "description": "模型上下文接近上限后的处理方式", + "description": "历史超限或上下文接近上限时的处理方式", "labels": [ "按对话轮数截断", "由 LLM 压缩上下文" ], - "hint": "该策略只会在完成轮次限制后,且上下文 token 接近模型窗口上限时触发。当按对话轮数截断时,会根据上面\"轮次超限时一次丢弃轮数\"的配置丢弃最旧的 N 轮对话。当由 LLM 压缩上下文时,会使用指定的模型进行上下文压缩。" + "hint": "普通会话历史仅在超过\"压缩前最多保留对话轮数\"后执行该策略;请求发送前也会在上下文 token 接近模型窗口时使用同一策略保护本次请求。" }, "llm_compress_instruction": { "description": "上下文压缩提示词", @@ -273,7 +273,7 @@ }, "llm_compress_provider_id": { "description": "用于上下文压缩的模型提供商 ID", - "hint": "留空时将降级为\"按对话轮数截断\"的策略。" + "hint": "留空时使用当前聊天模型进行压缩;如果模型不可用或压缩失败,将回退为\"按对话轮数截断\"的策略。" }, "fallback_max_context_tokens": { "description": "上下文窗口兜底值", @@ -991,9 +991,6 @@ "group_icl_enable": { "description": "启用群聊上下文感知" }, - "group_message_max_cnt": { - "description": "最大消息数量" - }, "image_caption": { "description": "自动理解图片", "hint": "需要设置群聊图片转述模型。" @@ -1002,6 +999,46 @@ "description": "群聊图片转述模型", "hint": "用于群聊上下文感知的图片理解,与默认图片转述模型分开配置。" }, + "history_tool_result_truncate": { + "description": "截断历史工具输出", + "hint": "仅影响群聊 LTM 历史轮,不影响当前工具调用轮的完整推理。" + }, + "history_tool_result_max_chars": { + "description": "历史工具输出截断上限", + "hint": "单条工具输出写入群聊历史时的最大字符数,默认 8192。" + }, + "ltm_compaction_strategy": { + "description": "LTM 上下文压缩策略", + "hint": "truncate: 按轮截断; llm_summary: 调用 LLM 做长期摘要。" + }, + "ltm_max_rounds": { + "description": "LTM 最大保留轮数", + "hint": "truncate 策略生效时的截断上限,默认 80。" + }, + "ltm_truncate_drop_rounds": { + "description": "截断时丢弃轮数", + "hint": "truncate 策略触发截断时,从前面丢弃的轮数。默认 50。" + }, + "ltm_summary_trigger_rounds": { + "description": "摘要触发轮数", + "hint": "超过多少轮时触发 LLM 摘要压缩,默认 80。" + }, + "ltm_summary_keep_recent_rounds": { + "description": "摘要时保留最近轮数", + "hint": "llm_summary 策略下保留最近 N 轮精确上下文,默认 30。" + }, + "ltm_summary_provider_id": { + "description": "LTM 摘要模型", + "hint": "llm_summary 策略使用的模型,留空使用当前聊天模型。" + }, + "ltm_summary_prompt": { + "description": "LTM 摘要提示词", + "hint": "llm_summary 策略的自定义摘要 prompt,留空使用内置默认。" + }, + "ltm_raw_records_max_bytes": { + "description": "原始消息缓冲区内存上限", + "hint": "每个群聊的未消费消息缓冲区的最大字节数。用于防止长期未 @bot 的群内存溢出,默认 500000 (500KB)。" + }, "active_reply": { "enable": { "description": "主动回复" diff --git a/tests/test_context_compaction.py b/tests/test_context_compaction.py new file mode 100644 index 0000000000..e36bfe8802 --- /dev/null +++ b/tests/test_context_compaction.py @@ -0,0 +1,280 @@ +"""Test that persistent context compaction is actually effective — +the compressed result is saved to the conversation layer, not discarded +on a temporary copy. +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from astrbot.core.agent.message import Message +from astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal import ( + _count_conversation_turns, + _has_valid_summary_message, + _history_exceeds_turn_limit, +) + + +# --------------------------------------------------------------------------- +# helpers +# --------------------------------------------------------------------------- + + +def _make_message(role: str, content: str = "test") -> Message: + return Message(role=role, content=content) + + +def _build_turns(n: int) -> list[Message]: + """Build *n* user+assistant pairs.""" + msgs: list[Message] = [] + for i in range(n): + msgs.append(_make_message("user", f"user_msg_{i}")) + msgs.append(_make_message("assistant", f"assistant_msg_{i}")) + return msgs + + +# --------------------------------------------------------------------------- +# _count_conversation_turns +# --------------------------------------------------------------------------- + + +def test_count_turns_empty(): + assert _count_conversation_turns([]) == 0 + + +def test_count_turns_only_user(): + msgs = [_make_message("user", "hi"), _make_message("user", "hey")] + assert _count_conversation_turns(msgs) == 2 + + +def test_count_turns_mixed(): + msgs = [ + _make_message("system", "sys"), + _make_message("user", "u1"), + _make_message("assistant", "a1"), + _make_message("user", "u2"), + _make_message("assistant", "a2"), + _make_message("tool", "t1"), + ] + assert _count_conversation_turns(msgs) == 2 # only user=u1,u2 + + +# --------------------------------------------------------------------------- +# _history_exceeds_turn_limit +# --------------------------------------------------------------------------- + + +def test_exceeds_disabled(): + msgs = _build_turns(100) + assert _history_exceeds_turn_limit(msgs, -1) is False + + +def test_exceeds_zero_or_negative(): + msgs = _build_turns(10) + assert _history_exceeds_turn_limit(msgs, 0) is False + assert _history_exceeds_turn_limit(msgs, -5) is False + + +def test_exceeds_under_limit(): + msgs = _build_turns(25) + assert _history_exceeds_turn_limit(msgs, 25) is False # 25 is not > 25 + + +def test_exceeds_over_limit(): + msgs = _build_turns(26) + assert _history_exceeds_turn_limit(msgs, 25) is True + + +# --------------------------------------------------------------------------- +# _has_valid_summary_message +# --------------------------------------------------------------------------- + + +def test_valid_summary_detected(): + msgs = [ + Message( + role="user", + content="Our previous history conversation summary: the user asked about weather.", + ), + Message(role="assistant", content="Acknowledged."), + ] + assert _has_valid_summary_message(msgs) is True + + +def test_empty_summary_not_detected(): + msgs = [ + Message( + role="user", + content="Our previous history conversation summary: ", + ), + ] + assert _has_valid_summary_message(msgs) is False + + +def test_no_summary_prefix(): + msgs = [Message(role="user", content="regular message")] + assert _has_valid_summary_message(msgs) is False + + +# --------------------------------------------------------------------------- +# end-to-end: _save_to_history truncation fallback +# --------------------------------------------------------------------------- + + +class TestSaveToHistoryCompaction: + """Simulate the _save_to_history compaction path without a real LLM. + + We mock the LLM compression provider away so the fallback_truncate + path is exercised. This proves that messages_to_save is *replaced* + with the truncated version before calling update_conversation. + """ + + @pytest.mark.asyncio + async def test_truncation_replaces_messages_to_save(self): + from astrbot.core.agent.context.truncator import ContextTruncator + + # Build 30 turns (60 messages) + 1 extra → 31 user messages + messages = _build_turns(31) + # The _save_to_history skips the very first system message, so we + # don't add one here — just the history + current user. + truncator = ContextTruncator() + truncated = truncator.truncate_by_turns( + messages, + keep_most_recent_turns=25, + drop_turns=10, + ) + # After truncation, should have ~16 turns (high-water 25, low-water 16) + user_count = _count_conversation_turns(truncated) + assert user_count <= 25 + # With 25-10+1=16 turns → at most 16 user messages + assert user_count >= 1 + # Verify the oldest messages are gone + first_user_content = next( + m.content for m in truncated if m.role == "user" + ) + assert isinstance(first_user_content, str) + # Should NOT be "user_msg_0" (that was in the first dropped turn) + assert "user_msg_0" not in str(first_user_content) + + @pytest.mark.asyncio + async def test_truncation_persists_to_conversation(self): + """Verify that after truncation, the conversation is updated with + the *truncated* messages, not the originals. + """ + from unittest.mock import AsyncMock, MagicMock + + from astrbot.core.agent.response import AgentStats + from astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal import ( + InternalAgentSubStage, + ) + from astrbot.core.provider.entities import LLMResponse + + # Build a mock InternalAgentSubStage + stage = InternalAgentSubStage() + # Inject config values + stage.max_context_length = 25 + stage.dequeue_context_length = 10 + stage.context_limit_reached_strategy = "truncate_by_turns" # force truncation path + stage.llm_compress_provider_id = "" + stage.llm_compress_keep_recent = 6 + stage.llm_compress_instruction = "" + + # Mock conversation manager + mock_conv = MagicMock() + mock_conv.cid = "test-cid-123" + stage.conv_manager = AsyncMock() + + # Build a ProviderRequest with conversation + from astrbot.core.provider.entities import ProviderRequest + + req = ProviderRequest(conversation=mock_conv, prompt="hello") + + # Build 30 turns of history → 30 user messages, > 25 + all_messages = _build_turns(30) + # Add the current user message and assistant response + all_messages.append(_make_message("user", "current_user_msg")) + all_messages.append(_make_message("assistant", "current_assistant_msg")) + # total: 31 user messages → exceeds 25 + + llm_resp = LLMResponse( + role="assistant", completion_text="I'm here to help!" + ) + + # Mock event + mock_event = MagicMock() + + await stage._save_to_history( + event=mock_event, + req=req, + llm_response=llm_resp, + all_messages=all_messages, + runner_stats=AgentStats(), + ) + + # Verify update_conversation was called + stage.conv_manager.update_conversation.assert_called_once() + + call_args = stage.conv_manager.update_conversation.call_args + saved_history = call_args.kwargs["history"] + + # The saved history should have FEWER messages than the original + original_msg_count = len(all_messages) - 1 # minus first system skip + assert len(saved_history) < original_msg_count, ( + f"Expected truncated history ({len(saved_history)}) " + f"to be smaller than original ({original_msg_count})" + ) + + # The saved history should NOT contain the oldest messages + saved_contents = [ + m.get("content", "") if isinstance(m, dict) else str(m) + for m in saved_history + ] + joined = " ".join(saved_contents) + assert "user_msg_0" not in joined, ( + "Oldest user message should have been truncated" + ) + + @pytest.mark.asyncio + async def test_next_round_loads_compressed_history(self): + """Simulate a full cycle: compress → save → load → verify compressed + version is what the next round sees. + """ + from astrbot.core.agent.context.truncator import ContextTruncator + + # Round N: 30 turns + original = _build_turns(30) + original.append(_make_message("user", "round_N_user")) + original.append(_make_message("assistant", "round_N_assistant")) + + # Compress + truncator = ContextTruncator() + compressed = truncator.truncate_by_turns( + original, + keep_most_recent_turns=25, + drop_turns=10, + ) + + # "Save" to a mock DB + saved = [m.model_dump() for m in compressed] + + # Round N+1: "load" from DB + loaded = [Message.model_validate(m) for m in saved] + loaded.append(_make_message("user", "round_N+1_user")) + + # The loaded context should start from the compressed version + user_msgs = [m for m in loaded if m.role == "user"] + contents = [m.content for m in user_msgs if isinstance(m.content, str)] + + # "round_N_user" should be present (it was recent enough) + assert any("round_N_user" in c for c in contents), ( + "Recent user message should survive compression" + ) + # "round_N+1_user" is the new message + assert any("round_N+1_user" in c for c in contents) + + # Verify the round count stays under threshold + turn_count = _count_conversation_turns(loaded) + # After compression (16 turns) + 1 new = 17 turns, well under 25 + assert turn_count < 25, ( + f"Next round should stay under limit, got {turn_count}" + ) diff --git a/tests/test_ltm_compaction.py b/tests/test_ltm_compaction.py new file mode 100644 index 0000000000..b8ac69f812 --- /dev/null +++ b/tests/test_ltm_compaction.py @@ -0,0 +1,408 @@ +"""Test LTM (LongTermMemory) context management for group chats. + +LTM has its OWN compaction strategy, independent config keys, and +different persistence model from private chat. These tests verify +that LTM compaction actually modifies self.contexts / self.summaries, +and that the request-time guard does NOT double-compress LTM-managed +contexts. +""" + +import asyncio +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from astrbot.builtin_stars.astrbot.long_term_memory import ( + LongTermMemory, + _build_segments, + _extract_bot_content, + _parse_tool_call, + _parse_tool_result, + _rounds_to_text, + _split_into_rounds, + _truncate_user_segment, +) + + +# --------------------------------------------------------------------------- +# _build_segments helpers +# --------------------------------------------------------------------------- + + +def _raw_bot(time_str: str, content: str) -> str: + return f": {content}" + + +def _raw_user(nick: str, time_str: str, content: str) -> str: + return f"[{nick}/{time_str}]: {content}" + + +# --------------------------------------------------------------------------- +# _parse_tool_call / _parse_tool_result / _extract_bot_content +# --------------------------------------------------------------------------- + + +def test_parse_tool_call_valid(): + result = _parse_tool_call( + '{"id":"abc","name":"search","args":{"q":"x"}}' + ) + assert result is not None + assert result["id"] == "abc" + assert result["type"] == "function" + assert result["function"]["name"] == "search" + + +def test_parse_tool_call_invalid_json(): + assert _parse_tool_call("not json") is None + + +def test_parse_tool_result_valid(): + result = _parse_tool_result("result text") + assert result is not None + assert result["role"] == "tool" + assert result["tool_call_id"] == "abc" + assert result["content"] == "result text" + + +def test_parse_tool_result_no_id(): + assert _parse_tool_result("stuff") is None + + +def test_extract_bot_content(): + assert _extract_bot_content(": hello world") == "hello world" + + +def test_extract_bot_content_no_separator(): + assert _extract_bot_content(" missing colon") is None + + +# --------------------------------------------------------------------------- +# _build_segments +# --------------------------------------------------------------------------- + + +def test_build_segments_empty(): + assert _build_segments([]) == [] + + +def test_build_segments_user_only(): + lines = [ + _raw_user("Alice", "10:00:00", "hello"), + _raw_user("Bob", "10:01:00", "hi"), + ] + segs = _build_segments(lines) + assert len(segs) == 1 + assert segs[0]["role"] == "user" + assert "Alice" in segs[0]["content"] + assert "Bob" in segs[0]["content"] + + +def test_build_segments_tool_chain(): + lines = [ + '{"id":"1","name":"search","args":{"q":"test"}}', + '{"id":"2","name":"calc","args":{"expr":"1+1"}}', + "search results", + "2", + ] + segs = _build_segments(lines) + # Should produce: 1 assistant(tool_calls with 2 tools), 2 tool results + assert len(segs) == 3 + assert segs[0]["role"] == "assistant" + assert segs[0]["tool_calls"] is not None + assert len(segs[0]["tool_calls"]) == 2 # merged consecutive calls + assert segs[1]["role"] == "tool" + assert segs[2]["role"] == "tool" + + +def test_build_segments_bot_reply(): + lines = [ + _raw_user("Alice", "10:00:00", "help"), + _raw_bot("10:00:05", "Sure, what do you need?"), + ] + segs = _build_segments(lines) + assert len(segs) == 2 + assert segs[0]["role"] == "user" + assert segs[1]["role"] == "assistant" + assert "Sure" in segs[1]["content"] + + +def test_build_segments_mixed(): + lines = [ + _raw_user("Alice", "10:00:00", "search python"), + '{"id":"3","name":"search","args":{"q":"python"}}', + "Python is a programming language.", + _raw_bot("10:00:10", "Python is a programming language."), + ] + segs = _build_segments(lines) + assert len(segs) == 4 + roles = [s["role"] for s in segs] + assert roles == ["user", "assistant", "tool", "assistant"] + + +# --------------------------------------------------------------------------- +# _split_into_rounds +# --------------------------------------------------------------------------- + + +def test_split_rounds_empty(): + assert _split_into_rounds([]) == [] + + +def test_split_rounds_single(): + ctxs = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + rounds = _split_into_rounds(ctxs) + assert len(rounds) == 1 + assert len(rounds[0]) == 2 + + +def test_split_rounds_multi(): + ctxs = [ + {"role": "user", "content": "u1"}, + {"role": "assistant", "content": "a1"}, + {"role": "user", "content": "u2"}, + {"role": "assistant", "content": "a2"}, + {"role": "tool", "tool_call_id": "x", "content": "t1"}, + ] + rounds = _split_into_rounds(ctxs) + assert len(rounds) == 2 + assert len(rounds[0]) == 2 # u1+a1 + assert len(rounds[1]) == 3 # u2+a2+t1 + + +def test_split_rounds_no_user_start(): + """If the first segment is not a user, it forms its own round, + and the next user starts a new round.""" + ctxs = [ + {"role": "assistant", "content": "orphan"}, + {"role": "user", "content": "u1"}, + ] + rounds = _split_into_rounds(ctxs) + # orphan assistant alone → round 0, user starts round 1 + assert len(rounds) == 2 + assert rounds[0][0]["role"] == "assistant" + assert rounds[1][0]["role"] == "user" + + +# --------------------------------------------------------------------------- +# LTM compaction: truncate strategy +# --------------------------------------------------------------------------- + + +class TestLTMTruncateCompaction: + """Verify LTM's truncate-based compaction modifies self.contexts.""" + + def test_truncation_drops_oldest_rounds(self): + """When rounds > ltm_max_rounds, oldest rounds are dropped.""" + # Build 90 rounds + rounds = [] + for i in range(90): + rounds.append([ + {"role": "user", "content": f"u{i}"}, + {"role": "assistant", "content": f"a{i}"}, + ]) + # Flatten to simulate self.contexts + contexts = [seg for rnd in rounds for seg in rnd] + + # Simulate compaction with defaults: max=80, drop=50 + max_rounds = 80 + drop_rounds = 50 + + assert len(rounds) > max_rounds + safe_drop = min(drop_rounds, len(rounds) - 1) + kept = rounds[safe_drop:] + new_contexts = [seg for rnd in kept for seg in rnd] + + # Should have kept 90 - 50 = 40 rounds + assert len(kept) == 40 + # First kept round should be round 50 + assert kept[0][0]["content"] == "u50" + # Verify self.contexts would be replaced + assert new_contexts != contexts + assert len(new_contexts) < len(contexts) + + def test_truncation_not_triggered_under_limit(self): + """Rounds under the limit should not be truncated.""" + rounds = [] + for i in range(70): # < 80 + rounds.append([ + {"role": "user", "content": f"u{i}"}, + {"role": "assistant", "content": f"a{i}"}, + ]) + contexts = [seg for rnd in rounds for seg in rnd] + + max_rounds = 80 + assert len(rounds) <= max_rounds + # No truncation should happen + assert len(rounds) == 70 + + def test_drop_clamped_to_len_minus_one(self): + """If drop_rounds > len(rounds)-1, clamp to len-1 (keep at least 1).""" + rounds = [] + for i in range(5): + rounds.append([ + {"role": "user", "content": f"u{i}"}, + {"role": "assistant", "content": f"a{i}"}, + ]) + + max_rounds = 3 # 5 > 3 → trigger + drop_rounds = 50 # would drop all if not clamped + + safe_drop = min(drop_rounds, len(rounds) - 1) # min(50, 4) = 4 + kept = rounds[safe_drop:] # rounds[4:] = last 1 round + assert len(kept) == 1 + assert kept[0][0]["content"] == "u4" + + +# --------------------------------------------------------------------------- +# LTM compaction: llm_summary strategy +# --------------------------------------------------------------------------- + + +class TestLTMSummaryCompaction: + @pytest.mark.asyncio + async def test_summary_triggers_and_updates_summaries(self): + """LLM summary updates self.summaries and truncates self.contexts.""" + ltm = LongTermMemory( + acm=MagicMock(), + context=MagicMock(), + ) + umo = "test_umo" + # Pre-populate contexts with 85 rounds (> trigger=80) + rounds = [] + for i in range(85): + rounds.append([ + {"role": "user", "content": f"u{i}"}, + {"role": "assistant", "content": f"a{i}"}, + ]) + ltm.contexts[umo] = [seg for rnd in rounds for seg in rnd] + + # Mock provider for summary + mock_provider = MagicMock() + mock_provider.text_chat = AsyncMock() + mock_resp = MagicMock() + mock_resp.completion_text = "Summary: users discussed various topics." + mock_provider.text_chat.return_value = mock_resp + + mock_event = MagicMock() + mock_event.unified_msg_origin = umo + + await ltm._compact_with_llm_summary( + event=mock_event, + provider=mock_provider, + keep_recent=30, + prompt="", + rounds=rounds, + ) + + # Summary should be stored + assert ltm.summaries[umo] == "Summary: users discussed various topics." + + # Contexts should be replaced with only the recent rounds + recent_rounds = rounds[-30:] + expected_contexts = [seg for rnd in recent_rounds for seg in rnd] + assert ltm.contexts[umo] == expected_contexts + assert len(ltm.contexts[umo]) < 85 * 2 + + @pytest.mark.asyncio + async def test_summary_failure_keeps_original_and_sets_cooldown(self): + """Failed summary should NOT replace contexts and should set retry cooldown.""" + ltm = LongTermMemory( + acm=MagicMock(), + context=MagicMock(), + ) + umo = "test_umo" + rounds = [] + for i in range(85): + rounds.append([ + {"role": "user", "content": f"u{i}"}, + {"role": "assistant", "content": f"a{i}"}, + ]) + original_contexts = [seg for rnd in rounds for seg in rnd] + ltm.contexts[umo] = list(original_contexts) + + # Mock provider that raises + mock_provider = MagicMock() + mock_provider.text_chat = AsyncMock(side_effect=Exception("API error")) + + mock_event = MagicMock() + mock_event.unified_msg_origin = umo + + await ltm._compact_with_llm_summary( + event=mock_event, + provider=mock_provider, + keep_recent=30, + prompt="", + rounds=rounds, + ) + + # Contexts should NOT have been modified + assert ltm.contexts[umo] == original_contexts + # Cooldown should be set + assert ltm._summary_next_retry[umo] == 85 + 5 # rounds + SUMMARY_RETRY_COOLDOWN + + @pytest.mark.asyncio + async def test_summary_not_triggered_when_old_rounds_empty(self): + """When rounds <= keep_recent, old_rounds is empty → no provider call.""" + ltm = LongTermMemory( + acm=MagicMock(), + context=MagicMock(), + ) + umo = "test_umo" + # 30 rounds = keep_recent → old_rounds = rounds[:0] = [] + rounds = [] + for i in range(30): + rounds.append([ + {"role": "user", "content": f"u{i}"}, + {"role": "assistant", "content": f"a{i}"}, + ]) + ltm.contexts[umo] = [seg for rnd in rounds for seg in rnd] + + mock_provider = MagicMock() + mock_provider.text_chat = AsyncMock() + + mock_event = MagicMock() + mock_event.unified_msg_origin = umo + + await ltm._compact_with_llm_summary( + event=mock_event, + provider=mock_provider, + keep_recent=30, + prompt="", + rounds=rounds, + ) + + # old_rounds empty → early return before text_chat + mock_provider.text_chat.assert_not_called() + + +# --------------------------------------------------------------------------- +# Guard + LTM interaction: guard should NOT do LLM compression on LTM contexts +# --------------------------------------------------------------------------- + + +class TestGuardDoesNotDoubleCompressLTM: + """With our changes, the request-time guard uses only TruncateByTurns, + never LLMSummaryCompressor. This is correct regardless of whether + LTM is active, but especially important for LTM-managed group chats. + """ + + def test_context_config_no_llm_provider_falls_back_to_truncate(self): + """When llm_compress_provider is None, ContextManager must select + TruncateByTurnsCompressor.""" + from astrbot.core.agent.context.config import ContextConfig + from astrbot.core.agent.context.manager import ContextManager + from astrbot.core.agent.context.compressor import TruncateByTurnsCompressor + + config = ContextConfig( + max_context_tokens=10000, + enforce_max_turns=-1, # disabled + truncate_turns=10, + llm_compress_provider=None, # ← our change + ) + mgr = ContextManager(config) + assert isinstance(mgr.compressor, TruncateByTurnsCompressor) + assert mgr.compressor.truncate_turns == 10 diff --git a/tests/test_tool_loop_agent_runner.py b/tests/test_tool_loop_agent_runner.py index 74d0691085..e095ce01db 100644 --- a/tests/test_tool_loop_agent_runner.py +++ b/tests/test_tool_loop_agent_runner.py @@ -260,6 +260,33 @@ async def text_chat(self, **kwargs) -> LLMResponse: ) +class CapturingToolLoopProvider(MockProvider): + def __init__(self, tool_name: str): + super().__init__() + self.tool_name = tool_name + self.received_contexts = [] + + async def text_chat(self, **kwargs) -> LLMResponse: + self.call_count += 1 + self.received_contexts.append(list(kwargs.get("contexts") or [])) + func_tool = kwargs.get("func_tool") + if func_tool is None or self.call_count > 1: + return LLMResponse( + role="assistant", + completion_text="最终回复", + usage=TokenUsage(input_other=10, output=5), + ) + + return LLMResponse( + role="assistant", + completion_text="", + tools_call_name=[self.tool_name], + tools_call_args=[{"query": "test"}], + tools_call_ids=["call_context_refresh"], + usage=TokenUsage(input_other=10, output=5), + ) + + class SequentialToolProvider(MockProvider): def __init__(self, tool_sequence: list[str]): super().__init__() @@ -450,6 +477,68 @@ async def test_max_step_limit_functionality( assert last_message.role == "assistant", "最后一条消息应该是assistant的最终回答" +@pytest.mark.asyncio +async def test_max_step_final_request_includes_limit_prompt( + runner, provider_request, mock_tool_executor, mock_hooks +): + """The forced final step must use contexts recomputed after max-step prompt.""" + provider = CapturingToolLoopProvider("test_tool") + + await runner.reset( + provider=provider, + request=provider_request, + run_context=ContextWrapper(context=None), + tool_executor=mock_tool_executor, + agent_hooks=mock_hooks, + streaming=False, + ) + + async def snapshot_guard(messages, trusted_token_usage=0): + return list(messages) + + runner.request_context_guard.process = snapshot_guard + + async for _ in runner.step_until_done(1): + pass + + assert provider.call_count == 2 + final_contexts = provider.received_contexts[-1] + assert final_contexts[-1].role == "user" + assert final_contexts[-1].content == runner.MAX_STEPS_REACHED_PROMPT + + +@pytest.mark.asyncio +async def test_tool_loop_next_request_includes_tool_result( + runner, provider_request, mock_tool_executor, mock_hooks +): + """Tool-loop provider contexts must be recomputed after tool results append.""" + provider = CapturingToolLoopProvider("test_tool") + + await runner.reset( + provider=provider, + request=provider_request, + run_context=ContextWrapper(context=None), + tool_executor=mock_tool_executor, + agent_hooks=mock_hooks, + streaming=False, + ) + + async def snapshot_guard(messages, trusted_token_usage=0): + return list(messages) + + runner.request_context_guard.process = snapshot_guard + + async for _ in runner.step_until_done(3): + pass + + assert provider.call_count == 2 + second_contexts = provider.received_contexts[1] + tool_messages = [msg for msg in second_contexts if msg.role == "tool"] + assert len(tool_messages) == 1 + assert tool_messages[0].tool_call_id == "call_context_refresh" + assert "工具执行结果" in tool_messages[0].content + + @pytest.mark.asyncio async def test_normal_completion_without_max_step( runner, mock_provider, provider_request, mock_tool_executor, mock_hooks diff --git a/tests/unit/test_internal_agent_history_compaction.py b/tests/unit/test_internal_agent_history_compaction.py new file mode 100644 index 0000000000..6fd94055eb --- /dev/null +++ b/tests/unit/test_internal_agent_history_compaction.py @@ -0,0 +1,272 @@ +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from astrbot.core.agent.message import Message, ToolCallMessageSegment +from astrbot.core.db.po import Conversation +from astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal import ( + InternalAgentSubStage, + _count_conversation_turns, + _history_exceeds_turn_limit, +) +from astrbot.core.provider.entities import LLMResponse, ProviderRequest +from astrbot.core.provider.provider import Provider + + +class FakeSummaryProvider(Provider): + def __init__(self) -> None: + super().__init__({"id": "summary", "type": "test"}, {}) + self.text_chat_mock = AsyncMock( + return_value=LLMResponse(role="assistant", completion_text="old summary") + ) + + def get_current_key(self) -> str: + return "test-key" + + def set_key(self, key: str) -> None: + pass + + async def get_models(self) -> list[str]: + return ["test-model"] + + async def text_chat(self, **kwargs) -> LLMResponse: + return await self.text_chat_mock(**kwargs) + + +def make_stage( + *, + provider: Provider | None = None, + current_provider: Provider | None = None, + max_turns: int = 3, + provider_id: str | None = "summary", +): + stage = InternalAgentSubStage() + stage.max_context_length = max_turns + stage.dequeue_context_length = 1 + stage.context_limit_reached_strategy = "llm_compress" + stage.llm_compress_provider_id = provider_id or "" + stage.llm_compress_keep_recent = 2 + stage.llm_compress_instruction = "Summarize history" + stage.conv_manager = SimpleNamespace(update_conversation=AsyncMock()) + plugin_context = SimpleNamespace( + get_provider_by_id=MagicMock(return_value=provider), + get_using_provider=MagicMock(return_value=current_provider), + ) + stage.ctx = SimpleNamespace(plugin_manager=SimpleNamespace(context=plugin_context)) + return stage + + +def make_event(): + event = MagicMock() + event.unified_msg_origin = "umo-1" + event.get_extra.return_value = None + return event + + +def make_request() -> ProviderRequest: + return ProviderRequest( + conversation=Conversation(platform_id="test", user_id="user", cid="cid-1") + ) + + +def make_plain_turns(count: int) -> list[Message]: + messages: list[Message] = [] + for index in range(count): + messages.append(Message(role="user", content=f"question {index}")) + messages.append(Message(role="assistant", content=f"answer {index}")) + return messages + + +def make_tool_turn() -> list[Message]: + return [ + Message(role="user", content="use tool"), + Message( + role="assistant", + content=None, + tool_calls=[ + { + "id": "call-1", + "type": "function", + "function": {"name": "search", "arguments": "{}"}, + } + ], + ), + ToolCallMessageSegment( + role="tool", + tool_call_id="call-1", + content="tool result", + ), + Message(role="assistant", content="final answer"), + ] + + +@pytest.mark.parametrize("max_turns", [-1, 0]) +def test_history_turn_limit_disabled(max_turns: int): + assert not _history_exceeds_turn_limit(make_plain_turns(10), max_turns) + + +def test_conversation_turn_count_treats_tool_chain_as_one_turn(): + messages = make_plain_turns(2) + make_tool_turn() + + assert _count_conversation_turns(messages) == 3 + assert not _history_exceeds_turn_limit(messages, 3) + assert _history_exceeds_turn_limit(messages, 2) + + +@pytest.mark.asyncio +async def test_save_history_does_not_summarize_below_turn_limit(): + provider = FakeSummaryProvider() + stage = make_stage(provider=provider, max_turns=3) + + await stage._save_to_history( + make_event(), + make_request(), + LLMResponse(role="assistant", completion_text="latest answer"), + make_plain_turns(3), + runner_stats=None, + ) + + provider.text_chat_mock.assert_not_called() + saved_history = stage.conv_manager.update_conversation.await_args.kwargs["history"] + assert len(saved_history) == 6 + assert saved_history[0]["content"] == "question 0" + + +@pytest.mark.asyncio +async def test_save_history_summarizes_only_after_turn_limit_exceeded(): + provider = FakeSummaryProvider() + stage = make_stage(provider=provider, max_turns=3) + + await stage._save_to_history( + make_event(), + make_request(), + LLMResponse(role="assistant", completion_text="latest answer"), + make_plain_turns(4), + runner_stats=None, + ) + + provider.text_chat_mock.assert_awaited_once() + saved_history = stage.conv_manager.update_conversation.await_args.kwargs["history"] + assert saved_history[0]["role"] == "user" + assert saved_history[0]["content"].startswith( + "Our previous history conversation summary:" + ) + assert saved_history[-2]["content"] == "question 3" + assert saved_history[-1]["content"] == "answer 3" + + +@pytest.mark.asyncio +async def test_save_history_uses_current_provider_when_compress_provider_id_empty(): + provider = FakeSummaryProvider() + stage = make_stage( + current_provider=provider, + max_turns=3, + provider_id="", + ) + + await stage._save_to_history( + make_event(), + make_request(), + LLMResponse(role="assistant", completion_text="latest answer"), + make_plain_turns(4), + runner_stats=None, + ) + + stage.ctx.plugin_manager.context.get_provider_by_id.assert_not_called() + stage.ctx.plugin_manager.context.get_using_provider.assert_called_once_with( + umo="umo-1" + ) + provider.text_chat_mock.assert_awaited_once() + saved_history = stage.conv_manager.update_conversation.await_args.kwargs["history"] + assert saved_history[0]["content"].startswith( + "Our previous history conversation summary:" + ) + + +@pytest.mark.asyncio +async def test_save_history_falls_back_when_summary_returns_empty_text(): + provider = FakeSummaryProvider() + provider.text_chat_mock.return_value = LLMResponse( + role="assistant", completion_text="" + ) + stage = make_stage(provider=provider, max_turns=3) + + await stage._save_to_history( + make_event(), + make_request(), + LLMResponse(role="assistant", completion_text="latest answer"), + make_plain_turns(4), + runner_stats=None, + ) + + provider.text_chat_mock.assert_awaited_once() + saved_history = stage.conv_manager.update_conversation.await_args.kwargs["history"] + assert saved_history[0]["content"] == "question 1" + assert saved_history[-1]["content"] == "answer 3" + + +@pytest.mark.asyncio +async def test_save_history_falls_back_when_summary_provider_raises(): + provider = FakeSummaryProvider() + provider.text_chat_mock.side_effect = RuntimeError("boom") + stage = make_stage(provider=provider, max_turns=3) + + await stage._save_to_history( + make_event(), + make_request(), + LLMResponse(role="assistant", completion_text="latest answer"), + make_plain_turns(4), + runner_stats=None, + ) + + provider.text_chat_mock.assert_awaited_once() + saved_history = stage.conv_manager.update_conversation.await_args.kwargs["history"] + assert saved_history[0]["content"] == "question 1" + assert saved_history[-1]["content"] == "answer 3" + + +@pytest.mark.asyncio +async def test_save_history_tool_chain_does_not_trigger_early_summary(): + provider = FakeSummaryProvider() + stage = make_stage(provider=provider, max_turns=3) + messages = make_plain_turns(2) + make_tool_turn() + + await stage._save_to_history( + make_event(), + make_request(), + LLMResponse(role="assistant", completion_text="latest answer"), + messages, + runner_stats=None, + ) + + provider.text_chat_mock.assert_not_called() + saved_history = stage.conv_manager.update_conversation.await_args.kwargs["history"] + roles = [item["role"] for item in saved_history] + assert roles == [ + "user", + "assistant", + "user", + "assistant", + "user", + "assistant", + "tool", + "assistant", + ] + + +@pytest.mark.asyncio +async def test_save_history_falls_back_to_turn_truncation_after_limit_exceeded(): + stage = make_stage(provider=None, max_turns=3) + + await stage._save_to_history( + make_event(), + make_request(), + LLMResponse(role="assistant", completion_text="latest answer"), + make_plain_turns(4), + runner_stats=None, + ) + + saved_history = stage.conv_manager.update_conversation.await_args.kwargs["history"] + assert saved_history[0]["content"] == "question 1" + assert saved_history[-1]["content"] == "answer 3" diff --git a/tests/unit/test_long_term_memory.py b/tests/unit/test_long_term_memory.py new file mode 100644 index 0000000000..f5bb4046da --- /dev/null +++ b/tests/unit/test_long_term_memory.py @@ -0,0 +1,1834 @@ +"""Tests for LTM v2 — long_term_memory.py""" + +import json +from collections import deque + +import pytest + +from astrbot.builtin_stars.astrbot.long_term_memory import ( + _build_segments, + _extract_bot_content, + _extract_tag_content, + _parse_tool_call, + _parse_tool_result, + _rounds_to_text, + _split_into_rounds, + _truncate_tool_result_for_history, + _truncate_user_segment, +) + + +# ============================================================================= +# _extract_tag_content +# ============================================================================= + +class TestExtractTagContent: + def test_normal(self): + assert _extract_tag_content("hello", "", "") == "hello" + + def test_json(self): + line = '{"id":"x","name":"f"}' + assert _extract_tag_content(line, "", "") == '{"id":"x","name":"f"}' + + def test_no_end_tag(self): + assert _extract_tag_content("hello", "", "") is None + + def test_wrong_end_tag(self): + assert _extract_tag_content("hello", "", "") is None + + +# ============================================================================= +# _extract_bot_content +# ============================================================================= + +class TestExtractBotContent: + def test_normal(self): + assert _extract_bot_content(": 你好呀~") == "你好呀~" + + def test_multiline_inline(self): + assert _extract_bot_content(": reply text") == "reply text" + + def test_no_separator(self): + assert _extract_bot_content("no colon space") is None + + def test_empty_content(self): + assert _extract_bot_content(": ") == "" + + +# ============================================================================= +# _parse_tool_call +# ============================================================================= + +class TestParseToolCall: + def test_single(self): + line = '{"id":"call_001","name":"get_weather","args":{"location":"北京"}}' + result = _parse_tool_call(line) + assert result is not None + assert result["id"] == "call_001" + assert result["type"] == "function" + assert result["function"]["name"] == "get_weather" + assert json.loads(result["function"]["arguments"]) == {"location": "北京"} + + def test_no_args(self): + line = '{"id":"c2","name":"ping","args":{}}' + result = _parse_tool_call(line) + assert result["function"]["arguments"] == "{}" + + def test_bad_json(self): + assert _parse_tool_call("not json") is None + + def test_missing_end_tag(self): + assert _parse_tool_call('{"id":"x"}{"id":"x","args":{}}') is None + # 缺 id + assert _parse_tool_call('{"name":"f","args":{}}') is None + # 全缺 + assert _parse_tool_call('{"foo":1}') is None + # 非 dict JSON(json.loads 返回 int) + assert _parse_tool_call("123") is None + + +# ============================================================================= +# _parse_tool_result +# ============================================================================= + +class TestParseToolResult: + def test_normal(self): + line = "晴天 25°C" + result = _parse_tool_result(line) + assert result == { + "role": "tool", + "tool_call_id": "call_001", + "content": "晴天 25°C", + } + + def test_multiline_content(self): + line = "line1\nline2" + result = _parse_tool_result(line) + assert result["content"] == "line1\nline2" + + def test_no_id(self): + assert _parse_tool_result("content") is None + + def test_bad_prefix(self): + assert _parse_tool_result("garbage") is None + + +# ============================================================================= +# _truncate_tool_result_for_history +# ============================================================================= + +class TestTruncateToolResultForHistory: + def test_under_limit(self): + text = "short result" + assert _truncate_tool_result_for_history(text, 100) == text + + def test_over_limit(self): + text = "x" * 50 + result = _truncate_tool_result_for_history(text, 40) + assert len(result) <= 40 + assert "TRUNCATED" in result + + def test_non_positive_limit_keeps_original(self): + text = "x" * 50 + assert _truncate_tool_result_for_history(text, 0) == text + + +# ============================================================================= +# _truncate_user_segment +# ============================================================================= + +class TestTruncateUserSegment: + def test_under_limit(self): + lines = ["[小明/14:30]: hi", "[小红/14:31]: hello"] + result = _truncate_user_segment(lines) + assert result == lines + + def test_exceeds_msg_count(self): + """最多保留 50 条""" + lines = [f"[user{i}/14:00]: msg{i}" for i in range(60)] + result = _truncate_user_segment(lines) + assert len(result) == 50 + # 保留最近的 + assert result[0] == "[user10/14:00]: msg10" + assert result[-1] == "[user59/14:00]: msg59" + + def test_exceeds_char_limit(self): + """超过 3000 字符时从最早开始丢弃""" + lines = [f"[user{i}/14:00]: {'x' * 80}" for i in range(50)] # 50 × 100 ≈ 5000 chars + result = _truncate_user_segment(lines) + total = sum(len(l) + 1 for l in result) + assert total <= 3000 + # 保留最近的 + assert result[-1] == lines[-1] + + def test_empty(self): + assert _truncate_user_segment([]) == [] + + def test_single_long_line(self): + """单条超长行也被保留""" + line = "x" * 5000 + result = _truncate_user_segment([line]) + assert result == [line] # 至少保留一条 + + +# ============================================================================= +# _build_segments +# ============================================================================= + +class TestBuildSegments: + def test_empty(self): + assert _build_segments([]) == [] + + def test_simple_user_only(self): + lines = [ + "[小明/14:30]: hi", + "[小红/14:31]: hello", + ] + result = _build_segments(lines) + assert len(result) == 1 + assert result[0]["role"] == "user" + assert "[小明/14:30]: hi" in result[0]["content"] + assert "[小红/14:31]: hello" in result[0]["content"] + + def test_user_bot_user(self): + lines = [ + "[小明/14:30]: hi", + ": 你好呀~", + "[小红/14:31]: 哈哈", + ] + result = _build_segments(lines) + assert len(result) == 3 + assert result[0]["role"] == "user" + assert "[小明/14:30]: hi" in result[0]["content"] + assert result[1] == {"role": "assistant", "content": "你好呀~"} + assert result[2]["role"] == "user" + assert "[小红/14:31]: 哈哈" in result[2]["content"] + + def test_multiple_bot_replies(self): + """多次 @bot 交互""" + lines = [ + "[小明/14:00]: @bot 1+1", + ": 等于2", + "[小红/14:01]: 哈哈", + "[小明/14:02]: @bot 2+2呢", + ": 等于4", + "[小红/14:03]: 不错", + ] + result = _build_segments(lines) + assert len(result) == 5 + assert result[0]["role"] == "user" + assert result[1] == {"role": "assistant", "content": "等于2"} + assert result[2]["role"] == "user" + assert result[3] == {"role": "assistant", "content": "等于4"} + assert result[4]["role"] == "user" + + def test_bot_first(self): + """首行即 """ + lines = [ + ": 你们好", + "[小明/14:01]: hello", + ] + result = _build_segments(lines) + assert len(result) == 2 + assert result[0] == {"role": "assistant", "content": "你们好"} + assert result[1]["role"] == "user" + + def test_tool_call_then_result_then_bot(self): + """工具调用链:T:CALL → T:RES → BOT""" + lines = [ + "[小明/14:30]: @bot 查天气", + '{"id":"call_001","name":"get_weather","args":{"location":"北京"}}', + "晴天 25°C", + ": 北京今天晴天,25°C", + ] + result = _build_segments(lines) + assert len(result) == 4 + + assert result[0]["role"] == "user" + assert "[小明/14:30]: @bot 查天气" in result[0]["content"] + + assert result[1]["role"] == "assistant" + assert result[1]["content"] is None + assert len(result[1]["tool_calls"]) == 1 + assert result[1]["tool_calls"][0]["id"] == "call_001" + + assert result[2] == { + "role": "tool", + "tool_call_id": "call_001", + "content": "晴天 25°C", + } + + assert result[3] == { + "role": "assistant", + "content": "北京今天晴天,25°C", + } + + def test_parallel_tool_calls(self): + """并行工具调用合并为一条 assistant(tool_calls)""" + lines = [ + "[小明/14:30]: @bot 查天气和空气", + '{"id":"call_001","name":"get_weather","args":{"location":"北京"}}', + '{"id":"call_002","name":"get_air_quality","args":{"location":"北京"}}', + "晴天 25°C", + "AQI 50 优", + ": 北京晴天,AQI 50 优", + ] + result = _build_segments(lines) + assert result[1]["role"] == "assistant" + assert result[1]["content"] is None + assert len(result[1]["tool_calls"]) == 2 + assert result[1]["tool_calls"][0]["id"] == "call_001" + assert result[1]["tool_calls"][1]["id"] == "call_002" + assert result[2]["role"] == "tool" + assert result[2]["tool_call_id"] == "call_001" + assert result[3]["role"] == "tool" + assert result[3]["tool_call_id"] == "call_002" + + def test_multi_step_tool(self): + """多步工具调用""" + lines = [ + "[小明/14:30]: @bot 帮我查", + '{"id":"c1","name":"search","args":{"q":"x"}}', + "found A", + '{"id":"c2","name":"get_detail","args":{"id":"A"}}', + "detail of A", + ": A 的详情是...", + ] + result = _build_segments(lines) + # tool round 1 + assert result[1]["role"] == "assistant" + assert len(result[1]["tool_calls"]) == 1 + assert result[1]["tool_calls"][0]["id"] == "c1" + assert result[2]["role"] == "tool" + assert result[2]["tool_call_id"] == "c1" + # tool round 2 + assert result[3]["role"] == "assistant" + assert len(result[3]["tool_calls"]) == 1 + assert result[3]["tool_calls"][0]["id"] == "c2" + assert result[4]["role"] == "tool" + assert result[4]["tool_call_id"] == "c2" + + def test_extreme_dense_group_chat(self): + """100 条纯群聊,无人 @bot → 全部合并为单条 user""" + lines = [f"[user{i}/14:{i:02d}]: 消息内容{i}" for i in range(100)] + result = _build_segments(lines) + assert len(result) == 1 + assert result[0]["role"] == "user" + # 段内已裁剪 + lines_in_content = result[0]["content"].split("\n") + assert len(lines_in_content) <= 50 + + def test_extreme_bot_between_group_chat(self): + """100条群聊中间夹了一次 @bot → user + asst + user""" + lines = ( + [f"[user{i}/14:{i:02d}]: 消息{i}" for i in range(40)] + + [": 我来了"] + + [f"[user{i}/14:{i:02d}]: 消息{i}" for i in range(40, 100)] + ) + result = _build_segments(lines) + assert len(result) == 3 + assert result[0]["role"] == "user" + assert result[1] == {"role": "assistant", "content": "我来了"} + assert result[2]["role"] == "user" + # 段内都做了裁剪 + assert len(result[0]["content"].split("\n")) <= 50 + assert len(result[2]["content"].split("\n")) <= 50 + + def test_image_and_at_messages(self): + """图片和 @ 消息作为普通行参与合并""" + lines = [ + "[小明/14:30]: hi [Image: 一只猫]", + "[小红/14:31]: [At: 小明] 好看", + ] + result = _build_segments(lines) + assert len(result) == 1 + assert result[0]["role"] == "user" + assert "[Image: 一只猫]" in result[0]["content"] + assert "[At: 小明]" in result[0]["content"] + + +# ============================================================================= +# LongTermMemory integration (mocked) +# ============================================================================= + +class TestLongTermMemoryIntegration: + """轻量集成测试 — 模拟 handle_message → on_req_llm → on_agent_done 流程""" + + @pytest.fixture + def mock_event(self): + from unittest.mock import MagicMock + from astrbot.api.event import AstrMessageEvent + from astrbot.api.platform import MessageType + event = MagicMock(spec=AstrMessageEvent) + event.unified_msg_origin = "group_123" + event.get_message_type.return_value = MessageType.GROUP_MESSAGE + event.get_messages.return_value = [] + event.message_obj = MagicMock() + event.message_obj.sender.nickname = "小明" + event.get_extra.return_value = -1 + return event + + @pytest.fixture + def mock_context(self): + from unittest.mock import MagicMock + from astrbot.api import star + ctx = MagicMock(spec=star.Context) + cfg = { + "provider_ltm_settings": { + "image_caption": False, + "image_caption_provider_id": "", + "image_caption_prompt": "", + "active_reply": { + "enable": False, + "method": "possibility_reply", + "possibility_reply": 0.0, + "prompt": "", + "whitelist": [], + }, + }, + "provider_settings": { + "image_caption_prompt": "", + }, + } + ctx.get_config.return_value = cfg + return ctx + + @pytest.fixture + def ltm(self, mock_context): + from unittest.mock import MagicMock + from astrbot.builtin_stars.astrbot.long_term_memory import LongTermMemory + + acm = MagicMock() + ltm = LongTermMemory(acm, mock_context) + return ltm + + @pytest.mark.asyncio + async def test_empty_flow(self, ltm, mock_event): + """空 raw_records 时 on_req_llm 应直接返回""" + from astrbot.api.provider import ProviderRequest + req = ProviderRequest() + + mock_event.get_extra.return_value = -1 # no raw idx + await ltm.on_req_llm(mock_event, req) + # No exception, no modification + assert req.contexts == [] + + @pytest.mark.asyncio + async def test_handle_then_on_req_no_bot_yet(self, ltm, mock_event): + """handle message → on_req_llm(bot 还没回)""" + from unittest.mock import MagicMock + from astrbot.api.provider import ProviderRequest + from astrbot.api.message_components import Plain + + comp = Plain(text="你好") + mock_event.get_messages.return_value = [comp] + + recorded_idx = [0] + + def _get_extra(key, default=None): + if key == "_ltm_raw_idx": + return recorded_idx[0] + return default + + mock_event.get_extra = _get_extra + + await ltm.handle_message(mock_event) + req = ProviderRequest() + await ltm.on_req_llm(mock_event, req) + + assert len(req.contexts) <= 0 or all( + isinstance(c, dict) for c in req.contexts + ) + + @pytest.mark.asyncio + async def test_full_roundtrip(self, ltm, mock_event): + """完整一轮:handle → on_req → on_agent_done → on_req""" + from unittest.mock import MagicMock + from astrbot.api.provider import ProviderRequest, LLMResponse + from astrbot.api.message_components import Plain + + comp = Plain(text="@bot hi") + mock_event.get_messages.return_value = [comp] + + raw_idx = 0 + + def _get_extra(key, default=None): + if key == "_ltm_raw_idx": + return raw_idx + return default + + mock_event.get_extra = _get_extra + + await ltm.handle_message(mock_event) + + req = ProviderRequest() + await ltm.on_req_llm(mock_event, req) + assert isinstance(req.contexts, list) + assert "chatroom" in req.system_prompt.lower() + assert req.conversation is None + + raw_idx += 1 + mock_run_ctx = MagicMock() + mock_run_ctx.messages = [] + resp = LLMResponse(role="assistant", completion_text="你好呀~") + + await ltm.on_agent_done(mock_event, mock_run_ctx, resp) + + # BOT 回复被构建进 contexts,cursor 推进后 trim 归零,raw_records 被裁剪清空 + ctx_list = ltm.contexts["group_123"] + bot_ctx = [c for c in ctx_list if c.get("role") == "assistant"] + assert len(bot_ctx) == 1 + assert "你好呀~" in bot_ctx[0]["content"] + # cursor 在 trim 后归零(所有已消费条目被清除) + assert ltm._raw_cursor["group_123"] == 0 + + @pytest.mark.asyncio + async def test_toggle_cleanup(self, ltm, mock_event): + """测试开关切换时的惰性清理""" + ltm.raw_records["group_123"].append("[小明/14:30]: hello") + ltm._raw_cursor["group_123"] = 1 + + await ltm.remove_session(mock_event) + + assert "group_123" not in ltm.raw_records + assert "group_123" not in ltm.contexts + assert "group_123" not in ltm._raw_cursor + + def test_trim_raw_records_preserves_unconsumed(self, ltm): + """_trim_raw_records 只淘汰 cursor 之前的条目""" + umo = "group_123" + for i in range(100): + ltm.raw_records[umo].append(f"[user{i}/14:00]: msg{i}") + ltm._raw_cursor[umo] = 50 # 50 条已消费 + + ltm._trim_raw_records(umo) + + # cursor 之前的 50 条被清掉 + remaining = list(ltm.raw_records[umo]) + assert len(remaining) == 50 + # 保留的是 cursor 之后的 + assert remaining[0] == "[user50/14:00]: msg50" + assert remaining[-1] == "[user99/14:00]: msg99" + assert ltm._raw_cursor[umo] == 0 # 所有剩余条目在 cursor 之前被清除后归零 + + @pytest.mark.asyncio + async def test_slow_summary_does_not_block_other_umo(self): + """A slow LTM summary in one session must not block another session.""" + import asyncio + from unittest.mock import AsyncMock, MagicMock + + from astrbot.api.event import AstrMessageEvent + from astrbot.api.message_components import Plain + from astrbot.api.platform import MessageType + from astrbot.api.provider import LLMResponse, Provider + from astrbot.builtin_stars.astrbot.long_term_memory import LongTermMemory + + slow_umo = "group_slow" + fast_umo = "group_fast" + summary_started = asyncio.Event() + release_summary = asyncio.Event() + + async def slow_text_chat(*args, **kwargs): + summary_started.set() + await release_summary.wait() + return LLMResponse(role="assistant", completion_text="summary") + + provider = MagicMock(spec=Provider) + provider.text_chat = AsyncMock(side_effect=slow_text_chat) + + cfg = { + "provider_ltm_settings": { + "image_caption": False, + "image_caption_provider_id": "", + "active_reply": { + "enable": False, + "method": "possibility_reply", + "possibility_reply": 0.0, + "prompt": "", + "whitelist": [], + }, + "ltm_compaction_strategy": "llm_summary", + "ltm_summary_trigger_rounds": 1, + "ltm_summary_keep_recent_rounds": 1, + "ltm_summary_provider_id": "summary-provider", + }, + "provider_settings": {"image_caption_prompt": ""}, + } + ctx = MagicMock() + ctx.get_config.return_value = cfg + ctx.get_provider_by_id.return_value = provider + ltm = LongTermMemory(MagicMock(), ctx) + + slow_event = MagicMock(spec=AstrMessageEvent) + slow_event.unified_msg_origin = slow_umo + slow_event.get_message_type.return_value = MessageType.GROUP_MESSAGE + slow_event.message_obj = MagicMock() + slow_event.message_obj.sender.nickname = "slow" + + fast_event = MagicMock(spec=AstrMessageEvent) + fast_event.unified_msg_origin = fast_umo + fast_event.get_message_type.return_value = MessageType.GROUP_MESSAGE + fast_event.get_messages.return_value = [Plain(text="hello from fast")] + fast_event.message_obj = MagicMock() + fast_event.message_obj.sender.nickname = "fast" + + ltm.raw_records[slow_umo].append("[slow/00:00:00]: already consumed") + ltm._raw_cursor[slow_umo] = 1 + ltm.contexts[slow_umo] = [ + {"role": "user", "content": "old question"}, + {"role": "assistant", "content": "old answer"}, + {"role": "user", "content": "recent question"}, + {"role": "assistant", "content": "recent answer"}, + ] + + slow_task = asyncio.create_task( + ltm.on_agent_done( + slow_event, + MagicMock(messages=[]), + LLMResponse(role="assistant", completion_text="slow reply"), + ) + ) + await asyncio.wait_for(summary_started.wait(), timeout=1) + + await asyncio.wait_for(ltm.handle_message(fast_event), timeout=0.2) + assert len(ltm.raw_records[fast_umo]) == 1 + assert fast_event.set_extra.call_args.args == ("_ltm_raw_idx", 0) + + release_summary.set() + await asyncio.wait_for(slow_task, timeout=1) + + +# ============================================================================= +# 多轮累积 +# ============================================================================= + +class TestMultiRoundAccumulation: + """验证多轮对话中 contexts 累积行为。""" + + @pytest.fixture + def mock_event(self): + from unittest.mock import MagicMock + from astrbot.api.event import AstrMessageEvent + from astrbot.api.platform import MessageType + event = MagicMock(spec=AstrMessageEvent) + event.unified_msg_origin = "group_123" + event.get_message_type.return_value = MessageType.GROUP_MESSAGE + event.message_obj = MagicMock() + event.message_obj.sender.nickname = "小明" + event.get_extra.return_value = -1 + return event + + @pytest.fixture + def ltm(self): + from unittest.mock import MagicMock + from astrbot.builtin_stars.astrbot.long_term_memory import LongTermMemory + ctx = MagicMock() + ctx.get_config.return_value = { + "provider_ltm_settings": { + "image_caption": False, "image_caption_provider_id": "", + "image_caption_prompt": "", + "active_reply": {"enable": False, "method": "possibility_reply", + "possibility_reply": 0.0, "prompt": "", "whitelist": []}, + }, + "provider_settings": {"image_caption_prompt": ""}, + } + return LongTermMemory(MagicMock(), ctx) + + @pytest.mark.asyncio + async def test_three_rounds_contexts_grow(self, ltm, mock_event): + """三轮 @bot 对话后 contexts 累积 3 条 user + 3 条 assistant(含 prompt)。""" + from unittest.mock import MagicMock + from astrbot.api.provider import ProviderRequest, LLMResponse + from astrbot.api.message_components import Plain + + for user_text, bot_text in [ + ("@bot 你好", "你好呀~"), + ("@bot 1+1", "等于2"), + ("@bot 再见", "拜拜~"), + ]: + def _get_extra(key, default=None): + if key == "_ltm_raw_idx": + return len(ltm.raw_records["group_123"]) + return default + mock_event.get_extra = _get_extra + comp = Plain(text=user_text) + mock_event.get_messages.return_value = [comp] + + await ltm.handle_message(mock_event) + + req = ProviderRequest() + await ltm.on_req_llm(mock_event, req) + assert req.conversation is None + + mock_run_ctx = MagicMock(messages=[]) + resp = LLMResponse(role="assistant", completion_text=bot_text) + await ltm.on_agent_done(mock_event, mock_run_ctx, resp) + + ctxs = ltm.contexts["group_123"] + roles = [c["role"] for c in ctxs] + # on_agent_done 从 cursor 起构建(含 @bot prompt) → user 段也进 contexts + assert roles == ["user", "assistant", "user", "assistant", "user", "assistant"] + assert "你好呀~" in ctxs[1]["content"] + assert "等于2" in ctxs[3]["content"] + assert "拜拜~" in ctxs[5]["content"] + + @pytest.mark.asyncio + async def test_contexts_only_appended_never_rebuilt(self, ltm, mock_event): + """验证 contexts 是追加式,不会被重建。""" + from unittest.mock import MagicMock + from astrbot.api.provider import ProviderRequest, LLMResponse + from astrbot.api.message_components import Plain + + # Round 1 + def _get_extra_r1(key, default=None): + return 0 if key == "_ltm_raw_idx" else default + mock_event.get_extra = _get_extra_r1 + mock_event.get_messages.return_value = [Plain(text="@bot hi")] + await ltm.handle_message(mock_event) + req = ProviderRequest() + await ltm.on_req_llm(mock_event, req) + await ltm.on_agent_done( + mock_event, + MagicMock(messages=[]), + LLMResponse(role="assistant", completion_text="hello"), + ) + ctx_after_r1 = list(ltm.contexts["group_123"]) + assert len(ctx_after_r1) >= 1 # 至少有 assistant + + # Round 2 + def _get_extra_r2(key, default=None): + return 2 if key == "_ltm_raw_idx" else default + mock_event.get_extra = _get_extra_r2 + mock_event.get_messages.return_value = [Plain(text="@bot again")] + await ltm.handle_message(mock_event) + req = ProviderRequest() + await ltm.on_req_llm(mock_event, req) + await ltm.on_agent_done( + mock_event, + MagicMock(messages=[]), + LLMResponse(role="assistant", completion_text="world"), + ) + ctx_after_r2 = list(ltm.contexts["group_123"]) + # 旧条目保留,新条目追加 + assert len(ctx_after_r2) >= len(ctx_after_r1) + assert ctx_after_r2[:len(ctx_after_r1)] == ctx_after_r1 + + +# ============================================================================= +# on_agent_done 工具链 +# ============================================================================= + +class TestAgentDoneToolChains: + """验证 on_agent_done 正确记录工具调用链。""" + + @pytest.fixture + def mock_event(self): + from unittest.mock import MagicMock + from astrbot.api.event import AstrMessageEvent + from astrbot.api.platform import MessageType + event = MagicMock(spec=AstrMessageEvent) + event.unified_msg_origin = "group_123" + event.get_message_type.return_value = MessageType.GROUP_MESSAGE + event.message_obj = MagicMock() + event.message_obj.sender.nickname = "小明" + event.get_extra.return_value = 0 + event.get_messages.return_value = [MagicMock()] + return event + + @pytest.fixture + def ltm(self): + from unittest.mock import MagicMock + from astrbot.builtin_stars.astrbot.long_term_memory import LongTermMemory + ctx = MagicMock() + ctx.get_config.return_value = { + "provider_ltm_settings": { + "image_caption": False, "image_caption_provider_id": "", + "image_caption_prompt": "", + "active_reply": {"enable": False, "method": "possibility_reply", + "possibility_reply": 0.0, "prompt": "", "whitelist": []}, + }, + "provider_settings": {"image_caption_prompt": ""}, + } + return LongTermMemory(MagicMock(), ctx) + + @pytest.mark.asyncio + async def test_tool_call_recorded_in_raw(self, ltm, mock_event): + """工具调用被记录到 contexts 中的 assistant(tool_calls) + tool 消息。""" + from astrbot.api.provider import LLMResponse + from astrbot.api.message_components import Plain + + mock_event.get_messages.return_value = [Plain(text="@bot weather")] + await ltm.handle_message(mock_event) + + from unittest.mock import MagicMock + tc_msg = MagicMock() + tc_msg.role = "assistant" + tc_msg.tool_calls = [{"id": "c1", "function": {"name": "weather", "arguments": '{"city":"bj"}'}}] + tool_msg = MagicMock() + tool_msg.role = "tool" + tool_msg.tool_call_id = "c1" + tool_msg.content = "sunny" + + run_ctx = MagicMock() + run_ctx.messages = [tc_msg, tool_msg] + resp = LLMResponse(role="assistant", completion_text="今天晴天") + + await ltm.on_agent_done(mock_event, run_ctx, resp) + + # on_agent_done 构建 contexts 后 trim raw_records,检查 contexts + ctxs = ltm.contexts["group_123"] + roles = [c["role"] for c in ctxs] + has_tool_call = any(c.get("tool_calls") for c in ctxs if c["role"] == "assistant") + has_tool_result = any(c["role"] == "tool" for c in ctxs) + has_final_asst = any( + c["role"] == "assistant" and c.get("content") == "今天晴天" + for c in ctxs + ) + assert has_tool_call, f"expected assistant with tool_calls in contexts" + assert has_tool_result, f"expected tool message in contexts" + assert has_final_asst, f"expected final assistant text in contexts" + + @pytest.mark.asyncio + async def test_tool_call_no_final_text(self, ltm, mock_event): + """工具调用后没有最终文本回复时 contexts 有 tool_calls 但没有最终 assistant 文本。""" + from astrbot.api.provider import LLMResponse + from astrbot.api.message_components import Plain + + mock_event.get_messages.return_value = [Plain(text="@bot task")] + await ltm.handle_message(mock_event) + + from unittest.mock import MagicMock + tc_msg = MagicMock() + tc_msg.role = "assistant" + tc_msg.tool_calls = [{"id": "c2", "function": {"name": "calc", "arguments": "{}"}}] + + run_ctx = MagicMock() + run_ctx.messages = [tc_msg] + resp = LLMResponse(role="assistant") # 无 completion_text + + await ltm.on_agent_done(mock_event, run_ctx, resp) + + ctxs = ltm.contexts["group_123"] + has_tool_call = any(c.get("tool_calls") for c in ctxs if c["role"] == "assistant") + has_final_text = any( + c["role"] == "assistant" and c.get("content") + for c in ctxs + ) + assert has_tool_call + assert not has_final_text + + @pytest.mark.asyncio + async def test_tool_call_dedup_across_rounds(self, ltm, mock_event): + """历史工具调用不应被重复持久化——防重复注入的核心回归测试。""" + from astrbot.api.provider import LLMResponse + from unittest.mock import MagicMock + + umo = "group_123" + + # 确保 raw_records 存在(on_agent_done 前置条件) + ltm.raw_records[umo].append("[dummy/00:00]: hi") + ltm._raw_cursor[umo] = 1 + + tc_msg = MagicMock() + tc_msg.role = "assistant" + tc_msg.tool_calls = [ + {"id": "c1", "function": {"name": "tool_a", "arguments": "{}"}} + ] + tool_msg = MagicMock() + tool_msg.role = "tool" + tool_msg.tool_call_id = "c1" + tool_msg.content = "result_a" + + run_ctx = MagicMock(messages=[tc_msg, tool_msg]) + resp = LLMResponse(role="assistant", completion_text="ok") + + # Round 1 — 首次记录 + await ltm.on_agent_done(mock_event, run_ctx, resp) + assert ltm._persisted_tool_call_ids[umo] == {"c1"} + assert ltm._persisted_tool_result_ids[umo] == {"c1"} + + # Round 2 — 模拟历史注入:run_context.messages 仍包含 c1 + ltm.raw_records[umo].append("[dummy2/00:01]: hi2") + resp2 = LLMResponse(role="assistant", completion_text="ok2") + await ltm.on_agent_done(mock_event, run_ctx, resp2) + + # 去重集不变 — c1 未被重复持久化 + assert ltm._persisted_tool_call_ids[umo] == {"c1"} + assert ltm._persisted_tool_result_ids[umo] == {"c1"} + + # Round 3 — 历史 c1 + 新调用 c2 混合 + ltm.raw_records[umo].append("[dummy3/00:02]: hi3") + new_tc = MagicMock() + new_tc.role = "assistant" + new_tc.tool_calls = [ + {"id": "c2", "function": {"name": "tool_b", "arguments": "{}"}} + ] + new_tool = MagicMock() + new_tool.role = "tool" + new_tool.tool_call_id = "c2" + new_tool.content = "result_b" + + mixed_ctx = MagicMock(messages=[tc_msg, tool_msg, new_tc, new_tool]) + resp3 = LLMResponse(role="assistant", completion_text="ok3") + await ltm.on_agent_done(mock_event, mixed_ctx, resp3) + + # c1 不变,c2 被添加 + assert ltm._persisted_tool_call_ids[umo] == {"c1", "c2"} + assert ltm._persisted_tool_result_ids[umo] == {"c1", "c2"} + + @pytest.mark.asyncio + async def test_tool_result_not_truncated_when_disabled(self, mock_event): + """history_tool_result_truncate=False 时工具结果不被截断。""" + from collections import deque + from unittest.mock import MagicMock + from astrbot.builtin_stars.astrbot.long_term_memory import LongTermMemory + + ctx = MagicMock() + ctx.get_config.return_value = { + "provider_ltm_settings": { + "image_caption": False, "image_caption_provider_id": "", + "active_reply": {"enable": False, "method": "possibility_reply", + "possibility_reply": 0.0, "prompt": "", "whitelist": []}, + "history_tool_result_truncate": False, + "history_tool_result_max_chars": 10, # 如果截断就会生效 + "ltm_compaction_strategy": "truncate", + "ltm_max_rounds": 80, + "ltm_truncate_drop_rounds": 50, + }, + "provider_settings": {"image_caption_prompt": ""}, + } + + ltm = LongTermMemory(MagicMock(), ctx) + umo = mock_event.unified_msg_origin + ltm.raw_records[umo] = deque() + + long_content = "TOOL-RESULT-" + ("X" * 5000) + + tool_msg = MagicMock() + tool_msg.role = "tool" + tool_msg.tool_call_id = "tool-1" + tool_msg.content = long_content + + run_ctx = MagicMock(messages=[tool_msg]) + + await ltm.on_agent_done(mock_event, run_ctx, None) + + # raw_records 被 _trim_raw_records 清空,检查 contexts + ctxs = ltm.contexts[umo] + tool_msgs = [c for c in ctxs if c.get("role") == "tool"] + assert len(tool_msgs) > 0 + assert long_content in tool_msgs[0]["content"] + assert "[TRUNCATED" not in tool_msgs[0]["content"] + + +# ============================================================================= +# 极端数据 +# ============================================================================= + +class TestExtremeData: + """边界和极端输入。""" + + @pytest.fixture + def ltm(self): + from unittest.mock import MagicMock + from astrbot.builtin_stars.astrbot.long_term_memory import LongTermMemory + ctx = MagicMock() + ctx.get_config.return_value = { + "provider_ltm_settings": { + "image_caption": False, "image_caption_provider_id": "", + "image_caption_prompt": "", + "active_reply": {"enable": False, "method": "possibility_reply", + "possibility_reply": 0.0, "prompt": "", "whitelist": []}, + }, + "provider_settings": {"image_caption_prompt": ""}, + } + return LongTermMemory(MagicMock(), ctx) + + def test_emoji_and_unicode_user_segment(self): + """emoji 和 Unicode 在 user 段中正确处理。""" + lines = [ + "[小明/14:30]: 😂😂😂 哈哈哈哈", + "[小红/14:31]: 你好世界 🌍 !", + "[小刚/14:32]: 日本語テスト \u3067\u3059", + ] + result = _build_segments(lines) + assert len(result) == 1 + assert result[0]["role"] == "user" + assert "😂😂😂" in result[0]["content"] + assert "🌍" in result[0]["content"] + assert "日本語テスト" in result[0]["content"] + + def test_mixed_image_at_plain(self): + """Image + At + Plain 混合行。""" + # 模拟 handle_message 构建的 raw line + raw_lines = [ + "[小明/14:30]: hi [Image: 一只猫] [At: 小红] ok", + "[小红/14:31]: 收到", + ] + result = _build_segments(raw_lines) + assert len(result) == 1 + assert "[Image: 一只猫]" in result[0]["content"] + assert "[At: 小红]" in result[0]["content"] + assert "ok" in result[0]["content"] + + def test_max_raw_bytes_triggers_cursor_trim(self, ltm): + """MAX_RAW_BYTES 超限时淘汰已消费条目。""" + umo = "big_group" + # 塞入 60KB 已消费 + 10KB 未消费 + consumed = "x" * 6000 # ~6KB per line + for i in range(10): + ltm.raw_records[umo].append(consumed + f" consumed{i}") + ltm._raw_cursor[umo] += 1 + unconsumed = "y" * 1000 # ~1KB per line + for i in range(10): + ltm.raw_records[umo].append(unconsumed + f" unconsumed{i}") + + ltm._trim_raw_records(umo) + + remaining = list(ltm.raw_records[umo]) + assert len(remaining) > 0 + assert all("unconsumed" in s or "consumed" in s for s in remaining) + # cursor 归零(所有已消费全部清掉) + assert ltm._raw_cursor[umo] == 0 + + def test_size_based_trim_actually_activates(self, ltm): + """size-based 淘汰在超限时真正触发,且不依赖 cursor > 0。""" + umo = "overflow_group" + # 20 条 unconsumed,每条 ~55 bytes → ~1100 bytes >> 100 + for i in range(20): + ltm.raw_records[umo].append("x" * 50 + f" msg{i}") + ltm._raw_cursor[umo] = 0 + + ltm._trim_raw_records(umo, max_bytes=100) + + remaining = list(ltm.raw_records[umo]) + total = sum(len(s.encode()) for s in remaining) + assert total <= 100, f"expected ≤100 bytes, got {total}" + assert ltm._raw_cursor[umo] == 0 + + def test_size_based_trim_with_mixed_consumed(self, ltm): + """size-based 淘汰在混合 consumed/unconsumed 时正确工作。""" + umo = "mixed_group" + # 5 consumed (cursor=5) + 15 unconsumed → 先清 consumed,再按 size 淘汰 unconsumed + for i in range(5): + ltm.raw_records[umo].append("x" * 50 + f" consumed{i}") + ltm._raw_cursor[umo] = i + 1 + for i in range(15): + ltm.raw_records[umo].append("y" * 50 + f" unconsumed{i}") + + ltm._trim_raw_records(umo, max_bytes=100) + + remaining = list(ltm.raw_records[umo]) + total = sum(len(s.encode()) for s in remaining) + assert total <= 100, f"expected ≤100 bytes, got {total}" + assert ltm._raw_cursor[umo] == 0 + # consumed 条目不应残留(排除 unconsumed 子串匹配) + assert not any(" consumed" in s for s in remaining) + + +# ============================================================================= +# Persona begin_dialogs 前置保留 +# ============================================================================= + +class TestPersonaBeginDialogs: + """验证 req.contexts 已有内容时 LTM 前置保留。""" + + @pytest.fixture + def mock_event(self): + from unittest.mock import MagicMock + from astrbot.api.event import AstrMessageEvent + from astrbot.api.platform import MessageType + event = MagicMock(spec=AstrMessageEvent) + event.unified_msg_origin = "group_123" + event.get_message_type.return_value = MessageType.GROUP_MESSAGE + event.message_obj = MagicMock() + event.message_obj.sender.nickname = "小明" + event.get_extra.return_value = 0 + event.get_messages.return_value = [MagicMock()] + return event + + @pytest.fixture + def ltm(self): + from unittest.mock import MagicMock + from astrbot.builtin_stars.astrbot.long_term_memory import LongTermMemory + ctx = MagicMock() + ctx.get_config.return_value = { + "provider_ltm_settings": { + "image_caption": False, "image_caption_provider_id": "", + "image_caption_prompt": "", + "active_reply": {"enable": False, "method": "possibility_reply", + "possibility_reply": 0.0, "prompt": "", "whitelist": []}, + }, + "provider_settings": {"image_caption_prompt": ""}, + } + return LongTermMemory(MagicMock(), ctx) + + @pytest.mark.asyncio + async def test_existing_contexts_preserved(self, ltm, mock_event): + """Persona 注入的 begin_dialogs 在 contexts 之前。""" + from astrbot.api.provider import ProviderRequest + from astrbot.api.message_components import Plain + + mock_event.get_messages.return_value = [Plain(text="@bot hi")] + await ltm.handle_message(mock_event) + + # 模拟 Persona 已注入的内容 + persona_dialogs = [ + {"role": "system", "content": "sample-only"}, + {"role": "user", "content": "sample-only"}, + {"role": "assistant", "content": "sample-only"}, + ] + req = ProviderRequest(contexts=persona_dialogs) + await ltm.on_req_llm(mock_event, req) + + # Persona 内容在 LTM 内容之前 + assert req.contexts[:3] == persona_dialogs + + +# ============================================================================= +# 并发安全 +# ============================================================================= + +class TestConcurrentSafety: + """验证 asyncio.Lock 下的并发安全性。""" + + @pytest.fixture + def mock_event_factory(self): + from unittest.mock import MagicMock + from astrbot.api.event import AstrMessageEvent + from astrbot.api.platform import MessageType + + def _make(umo="group_123", raw_idx=0, text="hi"): + event = MagicMock(spec=AstrMessageEvent) + event.unified_msg_origin = umo + event.get_message_type.return_value = MessageType.GROUP_MESSAGE + event.message_obj = MagicMock() + event.message_obj.sender.nickname = "小明" + + def _ge(key, default=None): + return raw_idx if key == "_ltm_raw_idx" else default + event.get_extra = _ge + from astrbot.api.message_components import Plain + event.get_messages.return_value = [Plain(text=text)] + return event + return _make + + @pytest.fixture + def ltm(self): + from unittest.mock import MagicMock + from astrbot.builtin_stars.astrbot.long_term_memory import LongTermMemory + ctx = MagicMock() + ctx.get_config.return_value = { + "provider_ltm_settings": { + "image_caption": False, "image_caption_provider_id": "", + "image_caption_prompt": "", + "active_reply": {"enable": False, "method": "possibility_reply", + "possibility_reply": 0.0, "prompt": "", "whitelist": []}, + }, + "provider_settings": {"image_caption_prompt": ""}, + } + return LongTermMemory(MagicMock(), ctx) + + @pytest.mark.asyncio + async def test_concurrent_handle_same_umo(self, ltm, mock_event_factory): + """同一群并发 handle_message 不会丢失消息。""" + import asyncio + + texts = ["msg1", "msg2", "msg3", "msg4", "msg5"] + tasks = [ + ltm.handle_message(mock_event_factory(raw_idx=i, text=t)) + for i, t in enumerate(texts) + ] + await asyncio.gather(*tasks) + + raw = list(ltm.raw_records["group_123"]) + assert len(raw) == 5 + for t in texts: + assert any(t in s for s in raw) + + @pytest.mark.asyncio + async def test_concurrent_handle_keeps_lock_integrity(self, ltm, mock_event_factory): + """并发 handle 后 raw_records 无交错损坏。""" + import asyncio + + async def record_with_delay(text, delay: float = 0): + await asyncio.sleep(delay) + await ltm.handle_message( + mock_event_factory(raw_idx=0, text=text) + ) + + await asyncio.gather( + record_with_delay("a", 0.01), + record_with_delay("b", 0.02), + record_with_delay("c", 0.0), + ) + + raw = list(ltm.raw_records["group_123"]) + assert len(raw) == 3 + assert all(any(t in s for s in raw) for t in ["a", "b", "c"]) + + +# ============================================================================= +# _split_into_rounds +# ============================================================================= + + +class TestSplitIntoRounds: + def test_empty(self): + assert _split_into_rounds([]) == [] + + def test_single_user(self): + rounds = _split_into_rounds([{"role": "user", "content": "hi"}]) + assert len(rounds) == 1 + assert len(rounds[0]) == 1 + + def test_user_assistant_single_round(self): + ctxs = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + rounds = _split_into_rounds(ctxs) + assert len(rounds) == 1 + assert len(rounds[0]) == 2 + + def test_multi_round(self): + ctxs = [ + {"role": "user", "content": "r1"}, + {"role": "assistant", "content": "a1"}, + {"role": "user", "content": "r2"}, + {"role": "assistant", "content": "a2"}, + ] + rounds = _split_into_rounds(ctxs) + assert len(rounds) == 2 + assert len(rounds[0]) == 2 + assert len(rounds[1]) == 2 + assert rounds[0][0]["content"] == "r1" + assert rounds[1][0]["content"] == "r2" + + def test_tool_chain_single_round(self): + ctxs = [ + {"role": "user", "content": "@bot weather"}, + {"role": "assistant", "content": None, "tool_calls": [{"id": "c1"}]}, + {"role": "tool", "tool_call_id": "c1", "content": "sunny"}, + {"role": "assistant", "content": "it's sunny"}, + ] + rounds = _split_into_rounds(ctxs) + assert len(rounds) == 1 + assert len(rounds[0]) == 4 # tool chain stays together + + def test_multi_step_tool_chain(self): + ctxs = [ + {"role": "user", "content": "@bot complex"}, + {"role": "assistant", "content": None, "tool_calls": [{"id": "c1"}]}, + {"role": "tool", "tool_call_id": "c1", "content": "r1"}, + {"role": "assistant", "content": None, "tool_calls": [{"id": "c2"}]}, + {"role": "tool", "tool_call_id": "c2", "content": "r2"}, + {"role": "assistant", "content": "done"}, + ] + rounds = _split_into_rounds(ctxs) + assert len(rounds) == 1 + assert len(rounds[0]) == 6 # multi-step tool chain in one round + + def test_two_rounds_with_tools(self): + ctxs = [ + {"role": "user", "content": "@bot weather"}, + {"role": "assistant", "content": None, "tool_calls": [{"id": "c1"}]}, + {"role": "tool", "tool_call_id": "c1", "content": "sunny"}, + {"role": "assistant", "content": "it's sunny"}, + {"role": "user", "content": "@bot search"}, + {"role": "assistant", "content": None, "tool_calls": [{"id": "c2"}]}, + {"role": "tool", "tool_call_id": "c2", "content": "results"}, + {"role": "assistant", "content": "done"}, + ] + rounds = _split_into_rounds(ctxs) + assert len(rounds) == 2 + assert len(rounds[0]) == 4 + assert len(rounds[1]) == 4 + + def test_starts_with_assistant(self): + """Defensive: first segment isn't user.""" + rounds = _split_into_rounds([{"role": "assistant", "content": "orphan"}]) + assert len(rounds) == 1 + assert rounds[0][0]["role"] == "assistant" + + def test_consecutive_users(self): + """Two user segments in a row → second starts new round.""" + ctxs = [ + {"role": "user", "content": "u1"}, + {"role": "user", "content": "u2"}, + {"role": "assistant", "content": "a1"}, + ] + rounds = _split_into_rounds(ctxs) + assert len(rounds) == 2 + assert rounds[0] == [{"role": "user", "content": "u1"}] + assert rounds[1] == [ + {"role": "user", "content": "u2"}, + {"role": "assistant", "content": "a1"}, + ] + + +# ============================================================================= +# _rounds_to_text +# ============================================================================= + + +class TestRoundsToText: + def test_empty(self): + assert _rounds_to_text([]) == "" + + def test_single_round(self): + rounds = [ + [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ], + ] + text = _rounds_to_text(rounds) + assert "--- Round 1 ---" in text + assert "[user] hi" in text + assert "[assistant] hello" in text + + def test_multi_round(self): + rounds = [ + [{"role": "user", "content": "r1"}], + [{"role": "assistant", "content": "a1"}], + ] + text = _rounds_to_text(rounds) + assert "--- Round 1 ---" in text + assert "--- Round 2 ---" in text + + def test_tool_calls_serialized(self): + """tool_calls (list) is json.dumps-ed, not crashing.""" + rounds = [ + [ + { + "role": "assistant", + "content": None, + "tool_calls": [{"id": "c1", "function": {"name": "f"}}], + }, + ], + ] + text = _rounds_to_text(rounds) + assert '"id"' in text # json-serialized + assert "c1" in text + + +# ============================================================================= +# LTM truncation compaction +# ============================================================================= + + +class TestLTMTruncationCompaction: + @pytest.fixture + def mock_event(self): + from unittest.mock import MagicMock + from astrbot.api.event import AstrMessageEvent + from astrbot.api.platform import MessageType + event = MagicMock(spec=AstrMessageEvent) + event.unified_msg_origin = "group_123" + event.get_message_type.return_value = MessageType.GROUP_MESSAGE + event.get_extra.return_value = 0 + event.get_messages.return_value = [] + event.message_obj = MagicMock() + event.message_obj.sender_nickname = "小明" + return event + + def make_contexts(self, n_rounds: int) -> list[dict]: + """Build N simple user→assistant rounds.""" + ctxs = [] + for i in range(n_rounds): + ctxs.append({"role": "user", "content": f"q{i}"}) + ctxs.append({"role": "assistant", "content": f"a{i}"}) + return ctxs + + @pytest.mark.asyncio + async def test_no_truncation_when_under_limit(self, mock_event): + from astrbot.builtin_stars.astrbot.long_term_memory import LongTermMemory + from unittest.mock import MagicMock + + ctx = MagicMock() + ctx.get_config.return_value = { + "provider_ltm_settings": { + "image_caption": False, "image_caption_provider_id": "", + "active_reply": {"enable": False, "method": "possibility_reply", + "possibility_reply": 0.0, "prompt": "", "whitelist": []}, + "ltm_compaction_strategy": "truncate", + "ltm_max_rounds": 10, + "ltm_truncate_drop_rounds": 5, + }, + "provider_settings": {"image_caption_prompt": ""}, + } + ltm = LongTermMemory(MagicMock(), ctx) + umo = mock_event.unified_msg_origin + + # 5 rounds → not over 10 limit + ltm.contexts[umo] = self.make_contexts(5) + rounds_before = _split_into_rounds(ltm.contexts[umo]) + + cfg = ltm.cfg(mock_event) + if len(rounds_before) > cfg["ltm_max_rounds"]: + kept = rounds_before[cfg["ltm_truncate_drop_rounds"] :] + ltm.contexts[umo] = [seg for rnd in kept for seg in rnd] + + rounds_after = _split_into_rounds(ltm.contexts[umo]) + assert len(rounds_after) == 5 + assert rounds_after[0][0]["content"] == "q0" + + @pytest.mark.asyncio + async def test_truncation_burst_drop(self, mock_event): + """超过 ltm_max_rounds 时从前面弹掉 ltm_truncate_drop_rounds 轮。""" + from astrbot.builtin_stars.astrbot.long_term_memory import LongTermMemory + from unittest.mock import MagicMock + + ctx = MagicMock() + ctx.get_config.return_value = { + "provider_ltm_settings": { + "image_caption": False, "image_caption_provider_id": "", + "active_reply": {"enable": False, "method": "possibility_reply", + "possibility_reply": 0.0, "prompt": "", "whitelist": []}, + "ltm_compaction_strategy": "truncate", + "ltm_max_rounds": 10, + "ltm_truncate_drop_rounds": 4, + }, + "provider_settings": {"image_caption_prompt": ""}, + } + ltm = LongTermMemory(MagicMock(), ctx) + umo = mock_event.unified_msg_origin + + # 12 rounds → over 10 → drop 4 from front → 8 remain + ltm.contexts[umo] = self.make_contexts(12) + rounds_before = _split_into_rounds(ltm.contexts[umo]) + + cfg = ltm.cfg(mock_event) + if len(rounds_before) > cfg["ltm_max_rounds"]: + kept = rounds_before[cfg["ltm_truncate_drop_rounds"] :] + ltm.contexts[umo] = [seg for rnd in kept for seg in rnd] + + rounds_after = _split_into_rounds(ltm.contexts[umo]) + assert len(rounds_after) == 8 + # first retained should be q4 (index 4 after dropping 0-3) + assert rounds_after[0][0]["content"] == "q4" + + @pytest.mark.asyncio + async def test_truncation_burst_drop_huge_drop(self, mock_event): + """drop_rounds >= total 时保留最后 1 轮(防御边界)。""" + from astrbot.builtin_stars.astrbot.long_term_memory import LongTermMemory + from unittest.mock import MagicMock + + ctx = MagicMock() + ctx.get_config.return_value = { + "provider_ltm_settings": { + "image_caption": False, "image_caption_provider_id": "", + "active_reply": {"enable": False, "method": "possibility_reply", + "possibility_reply": 0.0, "prompt": "", "whitelist": []}, + "ltm_compaction_strategy": "truncate", + "ltm_max_rounds": 5, + "ltm_truncate_drop_rounds": 50, + }, + "provider_settings": {"image_caption_prompt": ""}, + } + ltm = LongTermMemory(MagicMock(), ctx) + umo = mock_event.unified_msg_origin + + ltm.contexts[umo] = self.make_contexts(10) + rounds_before = _split_into_rounds(ltm.contexts[umo]) + + cfg = ltm.cfg(mock_event) + if len(rounds_before) > cfg["ltm_max_rounds"]: + safe_drop = min(cfg["ltm_truncate_drop_rounds"], len(rounds_before) - 1) + kept = rounds_before[safe_drop:] + ltm.contexts[umo] = [seg for rnd in kept for seg in rnd] + + rounds_after = _split_into_rounds(ltm.contexts[umo]) + # drop=50 but only 10 exist → safe_drop=9, keeps last 1 round + assert len(rounds_after) == 1 + + @pytest.mark.asyncio + async def test_tool_chain_not_split(self, mock_event): + """截断不应拆散工具链。""" + from astrbot.builtin_stars.astrbot.long_term_memory import LongTermMemory + from unittest.mock import MagicMock + + ctx = MagicMock() + ctx.get_config.return_value = { + "provider_ltm_settings": { + "image_caption": False, "image_caption_provider_id": "", + "active_reply": {"enable": False, "method": "possibility_reply", + "possibility_reply": 0.0, "prompt": "", "whitelist": []}, + "ltm_compaction_strategy": "truncate", + "ltm_max_rounds": 2, + "ltm_truncate_drop_rounds": 2, + }, + "provider_settings": {"image_caption_prompt": ""}, + } + ltm = LongTermMemory(MagicMock(), ctx) + umo = mock_event.unified_msg_origin + + # 3 rounds, last one has tool chain + ctxs = [ + {"role": "user", "content": "q0"}, + {"role": "assistant", "content": "a0"}, + {"role": "user", "content": "q1"}, + {"role": "assistant", "content": "a1"}, + {"role": "user", "content": "q2"}, + {"role": "assistant", "content": None, "tool_calls": [{"id": "c1"}]}, + {"role": "tool", "tool_call_id": "c1", "content": "result"}, + {"role": "assistant", "content": "final"}, + ] + ltm.contexts[umo] = ctxs + + rounds_before = _split_into_rounds(ltm.contexts[umo]) + cfg = ltm.cfg(mock_event) + if len(rounds_before) > cfg["ltm_max_rounds"]: + kept = rounds_before[cfg["ltm_truncate_drop_rounds"] :] + ltm.contexts[umo] = [seg for rnd in kept for seg in rnd] + + rounds_after = _split_into_rounds(ltm.contexts[umo]) + assert len(rounds_after) == 1 + # round preserved should have all 4 tool-chain segs intact + assert len(rounds_after[0]) == 4 + # verify the tool message is there + assert rounds_after[0][2]["tool_call_id"] == "c1" + + +# ============================================================================= +# Summary injection (on_req_llm) +# ============================================================================= + + +class TestSummaryInjection: + @pytest.mark.asyncio + async def test_summary_injected_when_present(self): + from unittest.mock import MagicMock + from astrbot.api.provider import ProviderRequest + from astrbot.builtin_stars.astrbot.long_term_memory import LongTermMemory + + ctx = MagicMock() + ctx.get_config.return_value = { + "provider_ltm_settings": { + "image_caption": False, "image_caption_provider_id": "", + "active_reply": {"enable": False, "method": "possibility_reply", + "possibility_reply": 0.0, "prompt": "", "whitelist": []}, + "ltm_compaction_strategy": "truncate", + "ltm_max_rounds": 80, + }, + "provider_settings": {"image_caption_prompt": ""}, + } + ltm = LongTermMemory(MagicMock(), ctx) + umo = "group_123" + ltm.raw_records[umo].append("[小明/14:30]: @bot hi") + ltm._raw_cursor[umo] = 0 + ltm.summaries[umo] = "Test summary text" + + event = MagicMock() + event.unified_msg_origin = umo + event.get_extra.return_value = 0 + + req = ProviderRequest() + req.contexts = [{"role": "user", "content": "persona dial"}] + + # simulate on_req_llm injection + existing = req.contexts or [] + ctxs = list(existing) + summary = ltm.summaries.get(umo, "") + if summary: + ctxs.append({ + "role": "system", + "content": "Long-term group memory summary:\n" + summary, + }) + ctxs.extend(ltm.contexts.get(umo, [])) + req.contexts = ctxs + + # persona dialog still present + assert req.contexts[0]["role"] == "user" + assert req.contexts[0]["content"] == "persona dial" + # summary injected after persona, before LTM contexts + assert req.contexts[1]["role"] == "system" + assert "Test summary text" in req.contexts[1]["content"] + + @pytest.mark.asyncio + async def test_no_summary_when_empty(self): + from unittest.mock import MagicMock + from astrbot.api.provider import ProviderRequest + from astrbot.builtin_stars.astrbot.long_term_memory import LongTermMemory + + ctx = MagicMock() + ctx.get_config.return_value = { + "provider_ltm_settings": { + "image_caption": False, "image_caption_provider_id": "", + "active_reply": {"enable": False, "method": "possibility_reply", + "possibility_reply": 0.0, "prompt": "", "whitelist": []}, + "ltm_compaction_strategy": "truncate", + "ltm_max_rounds": 80, + }, + "provider_settings": {"image_caption_prompt": ""}, + } + ltm = LongTermMemory(MagicMock(), ctx) + umo = "group_123" + ltm.raw_records[umo].append("[小明/14:30]: @bot hi") + ltm._raw_cursor[umo] = 0 + + event = MagicMock() + event.unified_msg_origin = umo + event.get_extra.return_value = 0 + + req = ProviderRequest() + req.contexts = [{"role": "user", "content": "persona dial"}] + + existing = req.contexts or [] + ctxs = list(existing) + summary = ltm.summaries.get(umo, "") + if summary: + ctxs.append({"role": "system", "content": "..."}) + ctxs.extend(ltm.contexts.get(umo, [])) + req.contexts = ctxs + + # no system summary injected + assert all(s["role"] != "system" for s in req.contexts) + + +# ============================================================================= +# remove_session cleanup +# ============================================================================= + + +class TestRemoveSessionCleanup: + @pytest.mark.asyncio + async def test_summaries_cleaned(self, mock_event): + from astrbot.builtin_stars.astrbot.long_term_memory import LongTermMemory + from unittest.mock import MagicMock + + ctx = MagicMock() + ctx.get_config.return_value = { + "provider_ltm_settings": { + "image_caption": False, "image_caption_provider_id": "", + "active_reply": {"enable": False, "method": "possibility_reply", + "possibility_reply": 0.0, "prompt": "", "whitelist": []}, + "ltm_compaction_strategy": "truncate", + "ltm_max_rounds": 80, + }, + "provider_settings": {"image_caption_prompt": ""}, + } + ltm = LongTermMemory(MagicMock(), ctx) + umo = mock_event.unified_msg_origin + ltm.summaries[umo] = "test" + ltm._persisted_tool_call_ids[umo].add("c1") + ltm._persisted_tool_result_ids[umo].add("c1") + ltm.raw_records[umo].append("[小明/14:30]: hi") + + await ltm.remove_session(mock_event) + + assert umo not in ltm.summaries + assert umo not in ltm._persisted_tool_call_ids + assert umo not in ltm._persisted_tool_result_ids + assert umo not in ltm.raw_records + + +# ============================================================================= +# LLM summary error paths +# ============================================================================= + + +class TestLLMSummaryErrorPath: + # test_missing_provider_does_not_crash removed: provider-availability + # check moved to on_agent_done which resolves the provider before + # calling _compact_with_llm_summary. + + @pytest.mark.asyncio + async def test_below_keep_recent_no_op(self, mock_event): + """rounds <= keep_recent 时不触发压缩。""" + from astrbot.builtin_stars.astrbot.long_term_memory import LongTermMemory + from astrbot.api.provider import Provider + from unittest.mock import MagicMock + + ctx = MagicMock() + ctx.get_config.return_value = {} + ltm = LongTermMemory(MagicMock(), ctx) + umo = mock_event.unified_msg_origin + + ctxs = [ + {"role": "user", "content": "q0"}, + {"role": "assistant", "content": "a0"}, + ] + ltm.contexts[umo] = ctxs + original_len = len(ltm.contexts[umo]) + + rounds = _split_into_rounds(ctxs) + await ltm._compact_with_llm_summary( + event=mock_event, + provider=MagicMock(spec=Provider), + keep_recent=5, + prompt="", + rounds=rounds, + ) + # 1 round < 5 keep_recent → no-op + assert len(ltm.contexts[umo]) == original_len + + @pytest.mark.asyncio + async def test_empty_summary_response_is_no_op(self, mock_event): + """LLM 返回空文本时不得覆盖 context/summary,并设置冷却期。""" + from astrbot.builtin_stars.astrbot.long_term_memory import ( + LongTermMemory, + SUMMARY_RETRY_COOLDOWN, + ) + from astrbot.api.provider import Provider + from unittest.mock import MagicMock, AsyncMock + + fake_resp = MagicMock() + fake_resp.completion_text = " " # whitespace-only + + fake_provider = MagicMock(spec=Provider) + fake_provider.text_chat = AsyncMock(return_value=fake_resp) + + ctx = MagicMock() + ctx.get_config.return_value = {} + + ltm = LongTermMemory(MagicMock(), ctx) + umo = mock_event.unified_msg_origin + + old_ctxs = [ + {"role": "user", "content": "old"}, + {"role": "assistant", "content": "old reply"}, + {"role": "user", "content": "new"}, + {"role": "assistant", "content": "new reply"}, + ] + ltm.contexts[umo] = old_ctxs + ltm.summaries[umo] = "existing summary" + + rounds = _split_into_rounds(old_ctxs) # 2 rounds + + # keep_recent=1 → old_rounds has 1 round, provider will be called + await ltm._compact_with_llm_summary( + event=mock_event, + provider=fake_provider, + keep_recent=1, + prompt="", + rounds=rounds, + ) + + # Both must be untouched + assert ltm.contexts[umo] is old_ctxs + assert ltm.summaries[umo] == "existing summary" + # Cooldown set + assert ltm._summary_next_retry[umo] == len(rounds) + SUMMARY_RETRY_COOLDOWN + + @pytest.mark.asyncio + async def test_summary_exception_sets_cooldown(self, mock_event): + """LLM 调用抛异常时设置冷却期。""" + from astrbot.builtin_stars.astrbot.long_term_memory import ( + LongTermMemory, + SUMMARY_RETRY_COOLDOWN, + ) + from astrbot.api.provider import Provider + from unittest.mock import MagicMock, AsyncMock + + fake_provider = MagicMock(spec=Provider) + fake_provider.text_chat = AsyncMock(side_effect=RuntimeError("boom")) + + ctx = MagicMock() + ctx.get_config.return_value = {} + + ltm = LongTermMemory(MagicMock(), ctx) + umo = mock_event.unified_msg_origin + + ctxs = [ + {"role": "user", "content": "q0"}, + {"role": "assistant", "content": "a0"}, + {"role": "user", "content": "q1"}, + {"role": "assistant", "content": "a1"}, + ] + ltm.contexts[umo] = ctxs + ltm.summaries[umo] = "existing summary" + + rounds = _split_into_rounds(ctxs) # 2 rounds + + await ltm._compact_with_llm_summary( + event=mock_event, + provider=fake_provider, + keep_recent=1, + prompt="", + rounds=rounds, + ) + + assert ltm.contexts[umo] is ctxs + assert ltm.summaries[umo] == "existing summary" + assert ltm._summary_next_retry[umo] == len(rounds) + SUMMARY_RETRY_COOLDOWN + + @pytest.mark.asyncio + async def test_summary_success_clears_cooldown(self, mock_event): + """LLM 调用成功时清除冷却标记。""" + from astrbot.builtin_stars.astrbot.long_term_memory import LongTermMemory + from astrbot.api.provider import Provider + from unittest.mock import MagicMock, AsyncMock + + fake_resp = MagicMock() + fake_resp.completion_text = "good summary" + + fake_provider = MagicMock(spec=Provider) + fake_provider.text_chat = AsyncMock(return_value=fake_resp) + + ctx = MagicMock() + ctx.get_config.return_value = {} + + ltm = LongTermMemory(MagicMock(), ctx) + umo = mock_event.unified_msg_origin + # Pre-set cooldown to simulate a previous failure + ltm._summary_next_retry[umo] = 999 + + ctxs = [ + {"role": "user", "content": "q0"}, + {"role": "assistant", "content": "a0"}, + {"role": "user", "content": "q1"}, + {"role": "assistant", "content": "a1"}, + ] + rounds = _split_into_rounds(ctxs) # 2 rounds + + await ltm._compact_with_llm_summary( + event=mock_event, + provider=fake_provider, + keep_recent=1, + prompt="", + rounds=rounds, + ) + + # Cooldown cleared + assert umo not in ltm._summary_next_retry + assert ltm.summaries[umo] == "good summary" + + +# ============================================================================= +# Config defaults +# ============================================================================= + + +class TestConfigDefaults: + @pytest.mark.asyncio + async def test_defaults(self, mock_event): + from astrbot.builtin_stars.astrbot.long_term_memory import LongTermMemory + from unittest.mock import MagicMock + + ctx = MagicMock() + ctx.get_config.return_value = { + "provider_ltm_settings": { + "image_caption": False, "image_caption_provider_id": "", + "active_reply": {"enable": False, "method": "possibility_reply", + "possibility_reply": 0.0, "prompt": "", "whitelist": []}, + }, + "provider_settings": {"image_caption_prompt": ""}, + } + ltm = LongTermMemory(MagicMock(), ctx) + cfg = ltm.cfg(mock_event) + + assert cfg["ltm_compaction_strategy"] == "truncate" + assert cfg["ltm_max_rounds"] == 80 + assert cfg["ltm_truncate_drop_rounds"] == 50 + assert cfg["ltm_summary_trigger_rounds"] == 80 + assert cfg["ltm_summary_keep_recent_rounds"] == 30 + assert cfg["ltm_summary_provider_id"] == "" + assert cfg["ltm_summary_prompt"] == "" + assert cfg["ltm_raw_records_max_bytes"] == 500000