diff --git a/.env.example b/.env.example index 486bfe0a4..d5c8ca36e 100644 --- a/.env.example +++ b/.env.example @@ -64,6 +64,8 @@ KEYWORD_OPTIMIZER_MODEL_NAME= # ================== 网络工具配置 ==================== # Tavily API密钥,用于Tavily网络搜索,申请地址:https://www.tavily.com/ TAVILY_API_KEY= +# Adanos Market Sentiment API 密钥(可选,用于结构化股票与市场情绪研究) +ADANOS_API_KEY= # 网络搜索工具类型,支持BochaAPI或AnspireAPI两种,默认为AnspireAPI SEARCH_TOOL_TYPE=AnspireAPI @@ -74,4 +76,4 @@ ANSPIRE_API_KEY= # Bocha AI Search API(用于Bocha多模态搜索,这里密钥名称虽然是Web Search,但其实是要AI Search的,申请地址:https://open.bochaai.com/) BOCHA_BASE_URL=https://api.bocha.cn/v1/ai-search -BOCHA_WEB_SEARCH_API_KEY= \ No newline at end of file +BOCHA_WEB_SEARCH_API_KEY= diff --git a/QueryEngine/agent.py b/QueryEngine/agent.py index 10859810c..1f3bfd094 100644 --- a/QueryEngine/agent.py +++ b/QueryEngine/agent.py @@ -19,7 +19,7 @@ ReportFormattingNode ) from .state import State -from .tools import TavilyNewsAgency, TavilyResponse +from .tools import AdanosSentimentAgency, TavilyNewsAgency, TavilyResponse from .utils import Settings, format_search_results_for_prompt from loguru import logger @@ -42,6 +42,11 @@ def __init__(self, config: Optional[Settings] = None): # 初始化搜索工具集 self.search_agency = TavilyNewsAgency(api_key=self.config.TAVILY_API_KEY) + self.market_sentiment_agency = ( + AdanosSentimentAgency(api_key=self.config.ADANOS_API_KEY) + if self.config.ADANOS_API_KEY + else None + ) # 初始化节点 self._initialize_nodes() @@ -54,7 +59,10 @@ def __init__(self, config: Optional[Settings] = None): logger.info(f"Query Agent已初始化") logger.info(f"使用LLM: {self.llm_client.get_model_info()}") - logger.info(f"搜索工具集: TavilyNewsAgency (支持6种搜索工具)") + search_tools = "TavilyNewsAgency (支持6种新闻搜索工具)" + if self.market_sentiment_agency: + search_tools += " + AdanosSentimentAgency (结构化市场情绪工具)" + logger.info(f"搜索工具集: {search_tools}") def _initialize_llm(self) -> LLMClient: """初始化LLM客户端""" @@ -66,8 +74,15 @@ def _initialize_llm(self) -> LLMClient: def _initialize_nodes(self): """初始化处理节点""" - self.first_search_node = FirstSearchNode(self.llm_client) - self.reflection_node = ReflectionNode(self.llm_client) + enable_market_sentiment = bool(self.config.ADANOS_API_KEY) + self.first_search_node = FirstSearchNode( + self.llm_client, + enable_market_sentiment=enable_market_sentiment, + ) + self.reflection_node = ReflectionNode( + self.llm_client, + enable_market_sentiment=enable_market_sentiment, + ) self.first_summary_node = FirstSummaryNode(self.llm_client) self.reflection_summary_node = ReflectionSummaryNode(self.llm_client) self.report_formatting_node = ReportFormattingNode(self.llm_client) @@ -96,6 +111,30 @@ def _validate_date_format(self, date_str: str) -> bool: return True except ValueError: return False + + def _build_search_kwargs(self, tool_name: str, tool_output: Dict[str, Any], log_prefix: str) -> tuple[str, Dict[str, Any]]: + """Normalize optional tool arguments and downgrade invalid date searches safely.""" + search_kwargs: Dict[str, Any] = {} + + if tool_name != "search_news_by_date": + return tool_name, search_kwargs + + start_date = tool_output.get("start_date") + end_date = tool_output.get("end_date") + + if start_date and end_date: + if self._validate_date_format(start_date) and self._validate_date_format(end_date): + search_kwargs["start_date"] = start_date + search_kwargs["end_date"] = end_date + logger.info(f"{log_prefix}时间范围: {start_date} 到 {end_date}") + return tool_name, search_kwargs + + logger.info(f"{log_prefix}⚠️ 日期格式错误(应为YYYY-MM-DD),改用基础搜索") + logger.info(f"{log_prefix} 提供的日期: start_date={start_date}, end_date={end_date}") + return "basic_search_news", search_kwargs + + logger.info(f"{log_prefix}⚠️ search_news_by_date工具缺少时间参数,改用基础搜索") + return "basic_search_news", search_kwargs def execute_search_tool(self, tool_name: str, query: str, **kwargs) -> TavilyResponse: """ @@ -109,6 +148,7 @@ def execute_search_tool(self, tool_name: str, query: str, **kwargs) -> TavilyRes - "search_news_last_week": 本周新闻 - "search_images_for_news": 新闻图片搜索 - "search_news_by_date": 按日期范围搜索新闻 + - "search_market_sentiment": 结构化股票与市场情绪研究 query: 搜索查询 **kwargs: 额外参数(如start_date, end_date, max_results) @@ -134,6 +174,16 @@ def execute_search_tool(self, tool_name: str, query: str, **kwargs) -> TavilyRes if not start_date or not end_date: raise ValueError("search_news_by_date工具需要start_date和end_date参数") return self.search_agency.search_news_by_date(query, start_date, end_date) + elif tool_name == "search_market_sentiment": + if not self.market_sentiment_agency: + logger.info(" → Adanos市场情绪工具未配置,跳过结构化情绪研究") + return TavilyResponse( + query=query, + answer="Structured market sentiment tool is not configured.", + results=[], + ) + days = kwargs.get("days", 7) + return self.market_sentiment_agency.search_market_sentiment(query, days=days) else: logger.warning(f" ⚠️ 未知的搜索工具: {tool_name},使用默认基础搜索") return self.search_agency.basic_search_news(query) @@ -238,26 +288,7 @@ def _initial_search_and_summary(self, paragraph_index: int): # 执行搜索 logger.info(" - 执行网络搜索...") - # 处理search_news_by_date的特殊参数 - search_kwargs = {} - if search_tool == "search_news_by_date": - start_date = search_output.get("start_date") - end_date = search_output.get("end_date") - - if start_date and end_date: - # 验证日期格式 - if self._validate_date_format(start_date) and self._validate_date_format(end_date): - search_kwargs["start_date"] = start_date - search_kwargs["end_date"] = end_date - logger.info(f" - 时间范围: {start_date} 到 {end_date}") - else: - logger.info(f" ⚠️ 日期格式错误(应为YYYY-MM-DD),改用基础搜索") - logger.info(f" 提供的日期: start_date={start_date}, end_date={end_date}") - search_tool = "basic_search_news" - else: - logger.info(f" ⚠️ search_news_by_date工具缺少时间参数,改用基础搜索") - search_tool = "basic_search_news" - + search_tool, search_kwargs = self._build_search_kwargs(search_tool, search_output, " - ") search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs) # 转换为兼容格式 @@ -328,27 +359,7 @@ def _reflection_loop(self, paragraph_index: int): logger.info(f" 选择的工具: {search_tool}") logger.info(f" 反思推理: {reasoning}") - # 执行反思搜索 - # 处理search_news_by_date的特殊参数 - search_kwargs = {} - if search_tool == "search_news_by_date": - start_date = reflection_output.get("start_date") - end_date = reflection_output.get("end_date") - - if start_date and end_date: - # 验证日期格式 - if self._validate_date_format(start_date) and self._validate_date_format(end_date): - search_kwargs["start_date"] = start_date - search_kwargs["end_date"] = end_date - logger.info(f" 时间范围: {start_date} 到 {end_date}") - else: - logger.info(f" ⚠️ 日期格式错误(应为YYYY-MM-DD),改用基础搜索") - logger.info(f" 提供的日期: start_date={start_date}, end_date={end_date}") - search_tool = "basic_search_news" - else: - logger.info(f" ⚠️ search_news_by_date工具缺少时间参数,改用基础搜索") - search_tool = "basic_search_news" - + search_tool, search_kwargs = self._build_search_kwargs(search_tool, reflection_output, " ") search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs) # 转换为兼容格式 diff --git a/QueryEngine/nodes/search_node.py b/QueryEngine/nodes/search_node.py index e44ee72da..eefb1f8d2 100644 --- a/QueryEngine/nodes/search_node.py +++ b/QueryEngine/nodes/search_node.py @@ -4,12 +4,12 @@ """ import json -from typing import Dict, Any +from typing import Any, Dict from json.decoder import JSONDecodeError from loguru import logger from .base_node import BaseNode -from ..prompts import SYSTEM_PROMPT_FIRST_SEARCH, SYSTEM_PROMPT_REFLECTION +from ..prompts import build_first_search_prompt, build_reflection_prompt from ..utils.text_processing import ( remove_reasoning_from_output, clean_json_tags, @@ -21,7 +21,7 @@ class FirstSearchNode(BaseNode): """为段落生成首次搜索查询的节点""" - def __init__(self, llm_client): + def __init__(self, llm_client, enable_market_sentiment: bool = False): """ 初始化首次搜索节点 @@ -29,6 +29,7 @@ def __init__(self, llm_client): llm_client: LLM客户端 """ super().__init__(llm_client, "FirstSearchNode") + self.system_prompt = build_first_search_prompt(enable_market_sentiment) def validate_input(self, input_data: Any) -> bool: """验证输入数据""" @@ -42,7 +43,7 @@ def validate_input(self, input_data: Any) -> bool: return "title" in input_data and "content" in input_data return False - def run(self, input_data: Any, **kwargs) -> Dict[str, str]: + def run(self, input_data: Any, **kwargs) -> Dict[str, Any]: """ 调用LLM生成搜索查询和理由 @@ -66,7 +67,7 @@ def run(self, input_data: Any, **kwargs) -> Dict[str, str]: logger.info("正在生成首次搜索查询") # 调用LLM - response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_FIRST_SEARCH, message) + response = self.llm_client.stream_invoke_to_string(self.system_prompt, message) # 处理响应 processed_response = self.process_output(response) @@ -78,7 +79,7 @@ def run(self, input_data: Any, **kwargs) -> Dict[str, str]: logger.exception(f"生成首次搜索查询失败: {str(e)}") raise e - def process_output(self, output: str) -> Dict[str, str]: + def process_output(self, output: str) -> Dict[str, Any]: """ 处理LLM输出,提取搜索查询和推理 @@ -123,6 +124,9 @@ def process_output(self, output: str) -> Dict[str, str]: # 验证和清理结果 search_query = result.get("search_query", "") reasoning = result.get("reasoning", "") + search_tool = result.get("search_tool") or "basic_search_news" + start_date = result.get("start_date") + end_date = result.get("end_date") if not search_query: logger.warning("未找到搜索查询,使用默认查询") @@ -130,7 +134,10 @@ def process_output(self, output: str) -> Dict[str, str]: return { "search_query": search_query, - "reasoning": reasoning + "search_tool": search_tool, + "reasoning": reasoning, + "start_date": start_date, + "end_date": end_date, } except Exception as e: @@ -138,7 +145,7 @@ def process_output(self, output: str) -> Dict[str, str]: # 返回默认查询 return self._get_default_search_query() - def _get_default_search_query(self) -> Dict[str, str]: + def _get_default_search_query(self) -> Dict[str, Any]: """ 获取默认搜索查询 @@ -147,14 +154,17 @@ def _get_default_search_query(self) -> Dict[str, str]: """ return { "search_query": "相关主题研究", - "reasoning": "由于解析失败,使用默认搜索查询" + "search_tool": "basic_search_news", + "reasoning": "由于解析失败,使用默认搜索查询", + "start_date": None, + "end_date": None, } class ReflectionNode(BaseNode): """反思段落并生成新搜索查询的节点""" - def __init__(self, llm_client): + def __init__(self, llm_client, enable_market_sentiment: bool = False): """ 初始化反思节点 @@ -162,6 +172,7 @@ def __init__(self, llm_client): llm_client: LLM客户端 """ super().__init__(llm_client, "ReflectionNode") + self.system_prompt = build_reflection_prompt(enable_market_sentiment) def validate_input(self, input_data: Any) -> bool: """验证输入数据""" @@ -177,7 +188,7 @@ def validate_input(self, input_data: Any) -> bool: return all(field in input_data for field in required_fields) return False - def run(self, input_data: Any, **kwargs) -> Dict[str, str]: + def run(self, input_data: Any, **kwargs) -> Dict[str, Any]: """ 调用LLM反思并生成搜索查询 @@ -201,7 +212,7 @@ def run(self, input_data: Any, **kwargs) -> Dict[str, str]: logger.info("正在进行反思并生成新搜索查询") # 调用LLM - response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_REFLECTION, message) + response = self.llm_client.stream_invoke_to_string(self.system_prompt, message) # 处理响应 processed_response = self.process_output(response) @@ -213,7 +224,7 @@ def run(self, input_data: Any, **kwargs) -> Dict[str, str]: logger.exception(f"反思生成搜索查询失败: {str(e)}") raise e - def process_output(self, output: str) -> Dict[str, str]: + def process_output(self, output: str) -> Dict[str, Any]: """ 处理LLM输出,提取搜索查询和推理 @@ -258,6 +269,9 @@ def process_output(self, output: str) -> Dict[str, str]: # 验证和清理结果 search_query = result.get("search_query", "") reasoning = result.get("reasoning", "") + search_tool = result.get("search_tool") or "basic_search_news" + start_date = result.get("start_date") + end_date = result.get("end_date") if not search_query: logger.warning("未找到搜索查询,使用默认查询") @@ -265,7 +279,10 @@ def process_output(self, output: str) -> Dict[str, str]: return { "search_query": search_query, - "reasoning": reasoning + "search_tool": search_tool, + "reasoning": reasoning, + "start_date": start_date, + "end_date": end_date, } except Exception as e: @@ -273,7 +290,7 @@ def process_output(self, output: str) -> Dict[str, str]: # 返回默认查询 return self._get_default_reflection_query() - def _get_default_reflection_query(self) -> Dict[str, str]: + def _get_default_reflection_query(self) -> Dict[str, Any]: """ 获取默认反思搜索查询 @@ -282,5 +299,8 @@ def _get_default_reflection_query(self) -> Dict[str, str]: """ return { "search_query": "深度研究补充信息", - "reasoning": "由于解析失败,使用默认反思搜索查询" + "search_tool": "basic_search_news", + "reasoning": "由于解析失败,使用默认反思搜索查询", + "start_date": None, + "end_date": None, } diff --git a/QueryEngine/prompts/__init__.py b/QueryEngine/prompts/__init__.py index e395aef58..cb473a346 100644 --- a/QueryEngine/prompts/__init__.py +++ b/QueryEngine/prompts/__init__.py @@ -15,7 +15,9 @@ output_schema_first_summary, output_schema_reflection, output_schema_reflection_summary, - input_schema_report_formatting + input_schema_report_formatting, + build_first_search_prompt, + build_reflection_prompt, ) __all__ = [ @@ -30,5 +32,7 @@ "output_schema_first_summary", "output_schema_reflection", "output_schema_reflection_summary", - "input_schema_report_formatting" + "input_schema_report_formatting", + "build_first_search_prompt", + "build_reflection_prompt", ] diff --git a/QueryEngine/prompts/prompts.py b/QueryEngine/prompts/prompts.py index 3d891e958..f5bffbec8 100644 --- a/QueryEngine/prompts/prompts.py +++ b/QueryEngine/prompts/prompts.py @@ -123,30 +123,7 @@ # ===== 系统提示词定义 ===== -# 生成报告结构的系统提示词 -SYSTEM_PROMPT_REPORT_STRUCTURE = f""" -你是一位深度研究助手。给定一个查询,你需要规划一个报告的结构和其中包含的段落。最多五个段落。 -确保段落的排序合理有序。 -一旦大纲创建完成,你将获得工具来分别为每个部分搜索网络并进行反思。 -请按照以下JSON模式定义格式化输出: - - -{json.dumps(output_schema_report_structure, indent=2, ensure_ascii=False)} - - -标题和内容属性将用于更深入的研究。 -确保输出是一个符合上述输出JSON模式定义的JSON对象。 -只返回JSON对象,不要有解释或额外文本。 -""" - -# 每个段落第一次搜索的系统提示词 -SYSTEM_PROMPT_FIRST_SEARCH = f""" -你是一位深度研究助手。你将获得报告中的一个段落,其标题和预期内容将按照以下JSON模式定义提供: - - -{json.dumps(input_schema_first_search, indent=2, ensure_ascii=False)} - - +BASE_SEARCH_TOOL_SECTION = """ 你可以使用以下6种专业的新闻搜索工具: 1. **basic_search_news** - 基础新闻搜索工具 @@ -174,12 +151,42 @@ - 特点:可以指定开始和结束日期进行搜索 - 特殊要求:需要提供start_date和end_date参数,格式为'YYYY-MM-DD' - 注意:只有这个工具需要额外的时间参数 +""".strip() + +MARKET_SENTIMENT_TOOL_SECTION = """ + +7. **search_market_sentiment** - 结构化市场情绪研究工具 + - 适用于:股票、ETF、市场情绪、看多看空、舆情热度、市场关注度等金融研究问题 + - 特点:直接返回来自 Reddit、X、财经新闻、Polymarket 的结构化市场情绪指标,而非普通网页搜索结果 + - 特殊建议:如果研究特定股票,请在查询中尽量包含明确股票代码(例如 AAPL、TSLA、NVDA) +""".rstrip() + + +def _build_tool_section(enable_market_sentiment: bool) -> str: + return BASE_SEARCH_TOOL_SECTION + (MARKET_SENTIMENT_TOOL_SECTION if enable_market_sentiment else "") + + +def build_first_search_prompt(enable_market_sentiment: bool = False) -> str: + tool_section = _build_tool_section(enable_market_sentiment) + market_note = ( + "\n5. 如果主题明确是股票、ETF 或市场情绪分析,并且存在结构化市场情绪工具,请优先考虑使用该工具。" + if enable_market_sentiment + else "" + ) + return f""" +你是一位深度研究助手。你将获得报告中的一个段落,其标题和预期内容将按照以下JSON模式定义提供: + + +{json.dumps(input_schema_first_search, indent=2, ensure_ascii=False)} + + +{tool_section} 你的任务是: 1. 根据段落主题选择最合适的搜索工具 2. 制定最佳的搜索查询 3. 如果选择search_news_by_date工具,必须同时提供start_date和end_date参数(格式:YYYY-MM-DD) -4. 解释你的选择理由 +4. 解释你的选择理由{market_note} 5. 仔细核查新闻中的可疑点,破除谣言和误导,尽力还原事件原貌 注意:除了search_news_by_date工具外,其他工具都不需要额外参数。 @@ -189,10 +196,65 @@ {json.dumps(output_schema_first_search, indent=2, ensure_ascii=False)} +确保输出是一个符合上述输出JSON模式定义的JSON对象。 +只返回JSON对象,不要有解释或额外文本。 +""".strip() + + +def build_reflection_prompt(enable_market_sentiment: bool = False) -> str: + tool_section = _build_tool_section(enable_market_sentiment) + market_note = ( + "\n6. 如果当前段落明显属于股票或市场分析,并且结构化市场情绪工具可用,应优先用它补齐定量证据。" + if enable_market_sentiment + else "" + ) + return f""" +你是一位深度研究助手。你负责为研究报告构建全面的段落。你将获得段落标题、计划内容摘要,以及你已经创建的段落最新状态,所有这些都将按照以下JSON模式定义提供: + + +{json.dumps(input_schema_reflection, indent=2, ensure_ascii=False)} + + +{tool_section} + +你的任务是: +1. 反思段落文本的当前状态,思考是否遗漏了主题的某些关键方面 +2. 选择最合适的搜索工具来补充缺失信息 +3. 制定精确的搜索查询 +4. 如果选择search_news_by_date工具,必须同时提供start_date和end_date参数(格式:YYYY-MM-DD) +5. 解释你的选择和推理{market_note} +6. 仔细核查新闻中的可疑点,破除谣言和误导,尽力还原事件原貌 + +注意:除了search_news_by_date工具外,其他工具都不需要额外参数。 +请按照以下JSON模式定义格式化输出: + + +{json.dumps(output_schema_reflection, indent=2, ensure_ascii=False)} + + +确保输出是一个符合上述输出JSON模式定义的JSON对象。 +只返回JSON对象,不要有解释或额外文本。 +""".strip() + +# 生成报告结构的系统提示词 +SYSTEM_PROMPT_REPORT_STRUCTURE = f""" +你是一位深度研究助手。给定一个查询,你需要规划一个报告的结构和其中包含的段落。最多五个段落。 +确保段落的排序合理有序。 +一旦大纲创建完成,你将获得工具来分别为每个部分搜索网络并进行反思。 +请按照以下JSON模式定义格式化输出: + + +{json.dumps(output_schema_report_structure, indent=2, ensure_ascii=False)} + + +标题和内容属性将用于更深入的研究。 确保输出是一个符合上述输出JSON模式定义的JSON对象。 只返回JSON对象,不要有解释或额外文本。 """ +# 每个段落第一次搜索的系统提示词 +SYSTEM_PROMPT_FIRST_SEARCH = build_first_search_prompt(enable_market_sentiment=False) + # 每个段落第一次总结的系统提示词 SYSTEM_PROMPT_FIRST_SUMMARY = f""" 你是一位专业的新闻分析师和深度内容创作专家。你将获得搜索查询、搜索结果以及你正在研究的报告段落,数据将按照以下JSON模式定义提供: @@ -268,40 +330,7 @@ """ # 反思(Reflect)的系统提示词 -SYSTEM_PROMPT_REFLECTION = f""" -你是一位深度研究助手。你负责为研究报告构建全面的段落。你将获得段落标题、计划内容摘要,以及你已经创建的段落最新状态,所有这些都将按照以下JSON模式定义提供: - - -{json.dumps(input_schema_reflection, indent=2, ensure_ascii=False)} - - -你可以使用以下6种专业的新闻搜索工具: - -1. **basic_search_news** - 基础新闻搜索工具 -2. **deep_search_news** - 深度新闻分析工具 -3. **search_news_last_24_hours** - 24小时最新新闻工具 -4. **search_news_last_week** - 本周新闻工具 -5. **search_images_for_news** - 图片搜索工具 -6. **search_news_by_date** - 按日期范围搜索工具(需要时间参数) - -你的任务是: -1. 反思段落文本的当前状态,思考是否遗漏了主题的某些关键方面 -2. 选择最合适的搜索工具来补充缺失信息 -3. 制定精确的搜索查询 -4. 如果选择search_news_by_date工具,必须同时提供start_date和end_date参数(格式:YYYY-MM-DD) -5. 解释你的选择和推理 -6. 仔细核查新闻中的可疑点,破除谣言和误导,尽力还原事件原貌 - -注意:除了search_news_by_date工具外,其他工具都不需要额外参数。 -请按照以下JSON模式定义格式化输出: - - -{json.dumps(output_schema_reflection, indent=2, ensure_ascii=False)} - - -确保输出是一个符合上述输出JSON模式定义的JSON对象。 -只返回JSON对象,不要有解释或额外文本。 -""" +SYSTEM_PROMPT_REFLECTION = build_reflection_prompt(enable_market_sentiment=False) # 总结反思的系统提示词 SYSTEM_PROMPT_REFLECTION_SUMMARY = f""" diff --git a/QueryEngine/tools/__init__.py b/QueryEngine/tools/__init__.py index aa055685d..944ab9677 100644 --- a/QueryEngine/tools/__init__.py +++ b/QueryEngine/tools/__init__.py @@ -10,9 +10,11 @@ ImageResult, print_response_summary ) +from .market_sentiment import AdanosSentimentAgency __all__ = [ "TavilyNewsAgency", + "AdanosSentimentAgency", "SearchResult", "TavilyResponse", "ImageResult", diff --git a/QueryEngine/tools/market_sentiment.py b/QueryEngine/tools/market_sentiment.py new file mode 100644 index 000000000..7aeb93cd8 --- /dev/null +++ b/QueryEngine/tools/market_sentiment.py @@ -0,0 +1,335 @@ +""" +Optional structured market sentiment research tools powered by Adanos. + +The QueryEngine consumes external tools through a news-like result interface. +This module adapts Adanos stock and market sentiment data into that shape so the +existing summary/report nodes can reuse it without special handling. +""" + +from __future__ import annotations + +import os +import re +import sys +from statistics import mean +from typing import Any, Dict, Iterable, List, Optional + +import requests + +current_dir = os.path.dirname(os.path.abspath(__file__)) +root_dir = os.path.dirname(os.path.dirname(current_dir)) +utils_dir = os.path.join(root_dir, "utils") +if utils_dir not in sys.path: + sys.path.append(utils_dir) + +from retry_helper import SEARCH_API_RETRY_CONFIG, with_graceful_retry + +from .search import SearchResult, TavilyResponse + + +class AdanosSentimentAgency: + """Optional adapter around the Adanos Market Sentiment API.""" + + _BASE_URL = "https://api.adanos.org" + _DOCS_URL = "https://api.adanos.org/docs" + _SOURCES = ("news", "reddit", "x", "polymarket") + _NON_TICKER_TOKENS = { + "A", + "AN", + "AND", + "CPI", + "EPS", + "ETF", + "ETFS", + "EUR", + "FED", + "FOMC", + "FOR", + "GDP", + "IPO", + "NEWS", + "PCE", + "PE", + "SEC", + "STOCK", + "THE", + "USD", + "USA", + "WITH", + } + + def __init__(self, api_key: Optional[str] = None): + api_key = api_key or os.getenv("ADANOS_API_KEY") + if not api_key: + raise ValueError("Adanos API Key未找到!请设置ADANOS_API_KEY环境变量或在初始化时提供") + + self._api_key = api_key + self._session = requests.Session() + + def search_market_sentiment(self, query: str, days: int = 7) -> TavilyResponse: + """ + Return stock-specific or market-wide sentiment as SearchResult-like items. + + If the query contains explicit tickers (AAPL, $TSLA, BRK.A), the tool returns + a cross-source stock snapshot and source-level breakdown. Otherwise it returns + service-level market sentiment across the supported sources. + """ + tickers = self._extract_tickers(query) + if tickers: + return self._search_stock_sentiment(query=query, tickers=tickers[:3], days=days) + return self._search_market_overview(query=query, days=days) + + @with_graceful_retry(SEARCH_API_RETRY_CONFIG, default_return=None) + def _request_json(self, path: str, params: Dict[str, Any]) -> Optional[Dict[str, Any]]: + response = self._session.get( + f"{self._BASE_URL}{path}", + params=params, + headers={"X-API-Key": self._api_key}, + timeout=20, + ) + + if response.status_code == 404: + return None + + response.raise_for_status() + return response.json() + + def _search_stock_sentiment(self, query: str, tickers: List[str], days: int) -> TavilyResponse: + results: List[SearchResult] = [] + + for ticker in tickers: + snapshots: Dict[str, Dict[str, Any]] = {} + for source in self._SOURCES: + payload = self._request_json(f"/{source}/stocks/v1/stock/{ticker}", {"days": days}) + if payload and payload.get("found"): + snapshots[source] = payload + + if not snapshots: + continue + + results.append(self._build_stock_overview_result(ticker, snapshots, days)) + for source, payload in snapshots.items(): + results.append(self._build_stock_source_result(ticker, source, payload, days)) + + answer = ( + f"Found structured Adanos sentiment coverage for {', '.join(tickers)}." + if results + else "No Adanos sentiment data found for the requested ticker(s)." + ) + return TavilyResponse(query=query, answer=answer, results=results) + + def _search_market_overview(self, query: str, days: int) -> TavilyResponse: + snapshots: Dict[str, Dict[str, Any]] = {} + for source in self._SOURCES: + payload = self._request_json(f"/{source}/stocks/v1/market-sentiment", {"days": days}) + if payload: + snapshots[source] = payload + + if not snapshots: + return TavilyResponse( + query=query, + answer="No Adanos market-wide sentiment sources were available.", + results=[], + ) + + results = [self._build_market_overview_result(snapshots, days)] + results.extend( + self._build_market_source_result(source=source, payload=payload, days=days) + for source, payload in snapshots.items() + ) + + return TavilyResponse( + query=query, + answer="Structured cross-source market sentiment snapshot retrieved from Adanos.", + results=results, + ) + + def _build_stock_overview_result( + self, ticker: str, snapshots: Dict[str, Dict[str, Any]], days: int + ) -> SearchResult: + buzz_values = [payload.get("buzz_score") for payload in snapshots.values() if payload.get("buzz_score") is not None] + bullish_values = [ + payload.get("bullish_pct") for payload in snapshots.values() if payload.get("bullish_pct") is not None + ] + sentiment_values = [ + payload.get("sentiment_score") for payload in snapshots.values() if payload.get("sentiment_score") is not None + ] + + lines = [ + f"Ticker: {ticker}", + f"Lookback window: {days} days", + f"Sources with signal: {', '.join(source.upper() for source in snapshots.keys())}", + f"Average buzz score: {mean(buzz_values):.1f}" if buzz_values else "Average buzz score: unavailable", + ( + f"Average bullish percentage: {mean(bullish_values):.1f}%" + if bullish_values + else "Average bullish percentage: unavailable" + ), + self._format_alignment(sentiment_values), + "Source snapshots:", + ] + for source, payload in snapshots.items(): + lines.append( + f"- {source.upper()}: buzz {self._format_number(payload.get('buzz_score'))}, " + f"bullish {self._format_percent(payload.get('bullish_pct'))}, " + f"sentiment {self._format_number(payload.get('sentiment_score'))}, " + f"activity {self._format_activity(payload)}" + ) + + return SearchResult( + title=f"{ticker} cross-source market sentiment snapshot", + url=self._DOCS_URL, + content="\n".join(lines), + score=float(mean(buzz_values)) if buzz_values else None, + ) + + def _build_stock_source_result( + self, ticker: str, source: str, payload: Dict[str, Any], days: int + ) -> SearchResult: + lines = [ + f"Ticker: {ticker}", + f"Source: {source.upper()}", + f"Lookback window: {days} days", + f"Buzz score: {self._format_number(payload.get('buzz_score'))}", + f"Sentiment score: {self._format_number(payload.get('sentiment_score'))}", + f"Bullish percentage: {self._format_percent(payload.get('bullish_pct'))}", + f"Trend: {payload.get('trend') or 'unknown'}", + f"Activity: {self._format_activity(payload)}", + ] + + return SearchResult( + title=f"{ticker} {source.upper()} sentiment details", + url=self._DOCS_URL, + content="\n".join(lines), + score=payload.get("buzz_score"), + ) + + def _build_market_overview_result( + self, snapshots: Dict[str, Dict[str, Any]], days: int + ) -> SearchResult: + buzz_values = [payload.get("buzz_score") for payload in snapshots.values() if payload.get("buzz_score") is not None] + bullish_values = [ + payload.get("bullish_pct") for payload in snapshots.values() if payload.get("bullish_pct") is not None + ] + sentiment_values = [ + payload.get("sentiment_score") for payload in snapshots.values() if payload.get("sentiment_score") is not None + ] + + lines = [ + f"Lookback window: {days} days", + f"Covered sources: {', '.join(source.upper() for source in snapshots.keys())}", + f"Average market buzz score: {mean(buzz_values):.1f}" if buzz_values else "Average market buzz score: unavailable", + ( + f"Average bullish percentage: {mean(bullish_values):.1f}%" + if bullish_values + else "Average bullish percentage: unavailable" + ), + self._format_alignment(sentiment_values), + "Top drivers by source:", + ] + + for source, payload in snapshots.items(): + drivers = payload.get("drivers") or [] + if drivers: + top_driver = drivers[0] + driver_text = ( + f"{top_driver.get('ticker')} (buzz {self._format_number(top_driver.get('buzz_score'))}, " + f"sentiment {self._format_number(top_driver.get('sentiment_score'))})" + ) + else: + driver_text = "no driver data" + lines.append(f"- {source.upper()}: {driver_text}") + + return SearchResult( + title="Cross-source US market sentiment overview", + url=self._DOCS_URL, + content="\n".join(lines), + score=float(mean(buzz_values)) if buzz_values else None, + ) + + def _build_market_source_result(self, source: str, payload: Dict[str, Any], days: int) -> SearchResult: + lines = [ + f"Source: {source.upper()}", + f"Lookback window: {days} days", + f"Buzz score: {self._format_number(payload.get('buzz_score'))}", + f"Sentiment score: {self._format_number(payload.get('sentiment_score'))}", + f"Bullish percentage: {self._format_percent(payload.get('bullish_pct'))}", + f"Trend: {payload.get('trend') or 'unknown'}", + f"Activity breadth: {self._format_market_activity(payload)}", + ] + drivers = payload.get("drivers") or [] + if drivers: + lines.append("Top drivers:") + for driver in drivers[:3]: + lines.append( + f"- {driver.get('ticker')}: buzz {self._format_number(driver.get('buzz_score'))}, " + f"sentiment {self._format_number(driver.get('sentiment_score'))}" + ) + + return SearchResult( + title=f"{source.upper()} market sentiment overview", + url=self._DOCS_URL, + content="\n".join(lines), + score=payload.get("buzz_score"), + ) + + @staticmethod + def _extract_tickers(query: str) -> List[str]: + candidates = re.findall(r"\$?[A-Za-z]{1,5}(?:\.[A-Za-z])?\b", query) + seen = set() + tickers = [] + for raw_candidate in candidates: + candidate = raw_candidate.lstrip("$") + if "$" not in raw_candidate and candidate.upper() != raw_candidate: + continue + candidate = candidate.upper() + if candidate in AdanosSentimentAgency._NON_TICKER_TOKENS: + continue + if candidate not in seen: + seen.add(candidate) + tickers.append(candidate) + return tickers + + @staticmethod + def _format_alignment(sentiment_values: Iterable[Optional[float]]) -> str: + numeric_values = [float(value) for value in sentiment_values if value is not None] + if len(numeric_values) < 2: + return "Cross-source alignment: insufficient signal" + spread = max(numeric_values) - min(numeric_values) + if spread < 0.15: + label = "strongly aligned" + elif spread < 0.35: + label = "moderately aligned" + else: + label = "divergent" + return f"Cross-source alignment: {label}" + + @staticmethod + def _format_market_activity(payload: Dict[str, Any]) -> str: + if "mentions" in payload: + return f"{payload.get('mentions', 0)} mentions across {payload.get('active_tickers', 0)} active tickers" + if "trade_count" in payload: + return f"{payload.get('trade_count', 0)} trades across {payload.get('active_tickers', 0)} active tickers" + return "activity unavailable" + + @staticmethod + def _format_activity(payload: Dict[str, Any]) -> str: + if payload.get("mentions") is not None: + return f"{int(payload.get('mentions', 0))} mentions" + if payload.get("trade_count") is not None: + return f"{int(payload.get('trade_count', 0))} trades" + if payload.get("unique_tweets") is not None: + return f"{int(payload.get('unique_tweets', 0))} tweets" + return "unavailable" + + @staticmethod + def _format_number(value: Any) -> str: + if value is None: + return "unavailable" + return f"{float(value):.2f}" + + @staticmethod + def _format_percent(value: Any) -> str: + if value is None: + return "unavailable" + return f"{float(value):.0f}%" diff --git a/QueryEngine/utils/config.py b/QueryEngine/utils/config.py index 19128b40c..a281bfc25 100644 --- a/QueryEngine/utils/config.py +++ b/QueryEngine/utils/config.py @@ -33,6 +33,7 @@ class Settings(BaseSettings): # ================== 网络工具配置 ==================== TAVILY_API_KEY: str = Field(..., description="Tavily API(申请地址:https://www.tavily.com/)API密钥,用于Tavily网络搜索") + ADANOS_API_KEY: Optional[str] = Field(None, description="Adanos Market Sentiment API 密钥。可选;启用后 Query Agent 可调用结构化股票/市场情绪研究工具。") # ================== 搜索参数配置 ==================== SEARCH_TIMEOUT: int = Field(240, description="搜索超时(秒)") @@ -67,6 +68,7 @@ def print_config(config: Settings): message += f"LLM 模型: {config.QUERY_ENGINE_MODEL_NAME}\n" message += f"LLM Base URL: {config.QUERY_ENGINE_BASE_URL or '(默认)'}\n" message += f"Tavily API Key: {'已配置' if config.TAVILY_API_KEY else '未配置'}\n" + message += f"Adanos API Key: {'已配置' if config.ADANOS_API_KEY else '未配置'}\n" message += f"搜索超时: {config.SEARCH_TIMEOUT} 秒\n" message += f"最长内容长度: {config.SEARCH_CONTENT_MAX_LENGTH}\n" message += f"最大反思次数: {config.MAX_REFLECTIONS}\n" diff --git a/README-EN.md b/README-EN.md index aaaab4c49..00d7128db 100644 --- a/README-EN.md +++ b/README-EN.md @@ -57,6 +57,8 @@ Beyond just report quality, compared to similar products, we have 🚀 six major > For example, you only need to simply modify the API parameters and prompts of the Agent toolset to transform it into a financial market analysis system. > +> This repository now also supports an optional `ADANOS_API_KEY` path for Query Agent, enabling structured stock and market sentiment research without changing the default public-opinion workflow. +> > Here's a relatively active Linux.do project discussion thread: https://linux.do/t/topic/1009280 > > Check out the comparison by a Linux.do fellow: [Open Source Project (BettaFish) vs manus|minimax|ChatGPT Comparison](https://linux.do/t/topic/1148040) diff --git a/README.md b/README.md index 27926cd2e..ece85f86d 100644 --- a/README.md +++ b/README.md @@ -58,6 +58,8 @@ > 举个例子. 你只需简单修改Agent工具集的api参数与prompt,就可以把他变成一个金融领域的市场分析系统 > +> 现在仓库也支持一个可选的 `ADANOS_API_KEY` 路径,让 Query Agent 在不改变默认舆情工作流的前提下,补充结构化的股票与市场情绪研究能力。 +> > 附一个比较活跃的L站项目讨论帖:https://linux.do/t/topic/1009280 > > 查看L站佬友做的测评 [开源项目(微舆)与manus|minimax|ChatGPT|Perplexity对比](https://linux.do/t/topic/1148040) diff --git a/app.py b/app.py index 8ad9422f5..1edb46875 100644 --- a/app.py +++ b/app.py @@ -111,6 +111,7 @@ def _safe_finish(self, *args, **kwargs): # pragma: no cover - 运行时才会 'KEYWORD_OPTIMIZER_BASE_URL', 'KEYWORD_OPTIMIZER_MODEL_NAME', 'TAVILY_API_KEY', + 'ADANOS_API_KEY', 'SEARCH_TOOL_TYPE', 'BOCHA_WEB_SEARCH_API_KEY', 'ANSPIRE_API_KEY' diff --git a/config.py b/config.py index d5005a5ed..61e432d7b 100644 --- a/config.py +++ b/config.py @@ -79,6 +79,7 @@ class Settings(BaseSettings): # ================== 网络工具配置 ==================== # Tavily API(申请地址:https://www.tavily.com/) TAVILY_API_KEY: Optional[str] = Field(None, description="Tavily API(申请地址:https://www.tavily.com/)API密钥,用于Tavily网络搜索") + ADANOS_API_KEY: Optional[str] = Field(None, description="Adanos Market Sentiment API 密钥。可选;启用后 Query Agent 可调用结构化股票/市场情绪研究工具。") SEARCH_TOOL_TYPE: Literal["AnspireAPI", "BochaAPI"] = Field("AnspireAPI", description="网络搜索工具类型,支持BochaAPI或AnspireAPI两种,默认为AnspireAPI") # Bocha API(申请地址:https://open.bochaai.com/) diff --git a/templates/index.html b/templates/index.html index 4219eaab3..a27edc473 100644 --- a/templates/index.html +++ b/templates/index.html @@ -2108,7 +2108,7 @@ }, { title: '外部检索工具', - subtitle: '联动搜索引擎、网站抓取等在线服务,两个都需配置', + subtitle: '联动搜索引擎、网站抓取和可选的结构化市场情绪研究服务', fields: [ { key: 'SEARCH_TOOL_TYPE', @@ -2117,6 +2117,7 @@ options: ['BochaAPI', 'AnspireAPI'] }, { key: 'TAVILY_API_KEY', label: 'Tavily API Key', type: 'password' }, + { key: 'ADANOS_API_KEY', label: 'Adanos API Key(可选)', type: 'password' }, { key: 'BOCHA_WEB_SEARCH_API_KEY', label: 'Bocha API Key', type: 'password', condition: { key: 'SEARCH_TOOL_TYPE', value: 'BochaAPI' } }, { key: 'ANSPIRE_API_KEY', label: 'Anspire API Key', type: 'password', condition: { key: 'SEARCH_TOOL_TYPE', value: 'AnspireAPI' } } ] diff --git a/tests/test_query_engine_market_sentiment.py b/tests/test_query_engine_market_sentiment.py new file mode 100644 index 000000000..95cc764f8 --- /dev/null +++ b/tests/test_query_engine_market_sentiment.py @@ -0,0 +1,231 @@ +import os +import sys +import types +from pathlib import Path + + +PROJECT_ROOT = Path(__file__).resolve().parent.parent +if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) + +os.environ.setdefault("QUERY_ENGINE_API_KEY", "test-query-key") +os.environ.setdefault("QUERY_ENGINE_MODEL_NAME", "test-model") +os.environ.setdefault("TAVILY_API_KEY", "test-tavily-key") + +if "tavily" not in sys.modules: + tavily_module = types.ModuleType("tavily") + + class DummyTavilyClient: + def __init__(self, *args, **kwargs): + pass + + def search(self, **kwargs): + return { + "query": kwargs.get("query", ""), + "results": [], + "images": [], + } + + tavily_module.TavilyClient = DummyTavilyClient + sys.modules["tavily"] = tavily_module + +if "openai" not in sys.modules: + openai_module = types.ModuleType("openai") + + class DummyOpenAI: + def __init__(self, *args, **kwargs): + pass + + openai_module.OpenAI = DummyOpenAI + sys.modules["openai"] = openai_module + +from QueryEngine.agent import DeepSearchAgent +from QueryEngine.nodes.search_node import FirstSearchNode, ReflectionNode +from QueryEngine.prompts.prompts import build_first_search_prompt, build_reflection_prompt +from QueryEngine.tools.market_sentiment import AdanosSentimentAgency +from QueryEngine.tools.search import TavilyResponse + + +def test_market_sentiment_prompt_toggle(): + first_prompt = build_first_search_prompt(enable_market_sentiment=False) + reflection_prompt = build_reflection_prompt(enable_market_sentiment=False) + + assert "search_market_sentiment" not in first_prompt + assert "search_market_sentiment" not in reflection_prompt + + enabled_first_prompt = build_first_search_prompt(enable_market_sentiment=True) + enabled_reflection_prompt = build_reflection_prompt(enable_market_sentiment=True) + + assert "search_market_sentiment" in enabled_first_prompt + assert "search_market_sentiment" in enabled_reflection_prompt + + +def test_first_search_node_preserves_tool_and_dates(): + node = FirstSearchNode(llm_client=None, enable_market_sentiment=True) + result = node.process_output( + """ + { + "search_query": "AAPL market sentiment", + "search_tool": "search_market_sentiment", + "reasoning": "需要结构化情绪数据", + "start_date": "2026-04-01", + "end_date": "2026-04-07" + } + """ + ) + + assert result["search_query"] == "AAPL market sentiment" + assert result["search_tool"] == "search_market_sentiment" + assert result["start_date"] == "2026-04-01" + assert result["end_date"] == "2026-04-07" + + +def test_reflection_node_preserves_tool_selection(): + node = ReflectionNode(llm_client=None, enable_market_sentiment=True) + result = node.process_output( + """ + { + "search_query": "TSLA bulls vs bears", + "search_tool": "search_market_sentiment", + "reasoning": "需要补齐量化情绪证据" + } + """ + ) + + assert result["search_query"] == "TSLA bulls vs bears" + assert result["search_tool"] == "search_market_sentiment" + assert result["start_date"] is None + assert result["end_date"] is None + + +def test_adanos_sentiment_agency_formats_stock_results(): + agency = AdanosSentimentAgency(api_key="test-adanos-key") + + def fake_request(path, params): + if path == "/reddit/stocks/v1/stock/AAPL": + return { + "found": True, + "buzz_score": 74.2, + "sentiment_score": 0.41, + "bullish_pct": 68, + "trend": "rising", + "mentions": 120, + } + if path == "/x/stocks/v1/stock/AAPL": + return { + "found": True, + "buzz_score": 61.8, + "sentiment_score": 0.22, + "bullish_pct": 57, + "trend": "stable", + "unique_tweets": 84, + } + return None + + agency._request_json = fake_request + + response = agency.search_market_sentiment("Evaluate AAPL sentiment", days=14) + + assert response.query == "Evaluate AAPL sentiment" + assert response.results + assert response.results[0].title == "AAPL cross-source market sentiment snapshot" + assert "Average buzz score" in response.results[0].content + assert any(result.title == "AAPL REDDIT sentiment details" for result in response.results) + assert any(result.title == "AAPL X sentiment details" for result in response.results) + assert all(result.url == agency._DOCS_URL for result in response.results) + + +def test_adanos_sentiment_agency_formats_market_overview(): + agency = AdanosSentimentAgency(api_key="test-adanos-key") + + def fake_request(path, params): + if path == "/news/stocks/v1/market-sentiment": + return { + "buzz_score": 55.0, + "sentiment_score": 0.18, + "bullish_pct": 54, + "trend": "stable", + "mentions": 540, + "active_tickers": 31, + "drivers": [{"ticker": "NVDA", "buzz_score": 82.1, "sentiment_score": 0.44}], + } + if path == "/reddit/stocks/v1/market-sentiment": + return { + "buzz_score": 63.5, + "sentiment_score": 0.27, + "bullish_pct": 61, + "trend": "rising", + "mentions": 710, + "active_tickers": 45, + "drivers": [{"ticker": "PLTR", "buzz_score": 76.0, "sentiment_score": 0.39}], + } + return None + + agency._request_json = fake_request + + response = agency.search_market_sentiment("What is the current stock market mood?", days=7) + + assert response.results + assert response.results[0].title == "Cross-source US market sentiment overview" + assert "Covered sources: NEWS, REDDIT" in response.results[0].content + assert any(result.title == "NEWS market sentiment overview" for result in response.results) + + +def test_adanos_sentiment_agency_ignores_common_finance_acronyms(): + agency = AdanosSentimentAgency(api_key="test-adanos-key") + + def fake_request(path, params): + if path == "/news/stocks/v1/market-sentiment": + return { + "buzz_score": 58.0, + "sentiment_score": 0.14, + "bullish_pct": 53, + "trend": "stable", + "mentions": 321, + "active_tickers": 20, + } + return None + + agency._request_json = fake_request + + response = agency.search_market_sentiment("What ETF sentiment drivers matter now?", days=7) + + assert response.results + assert response.results[0].title == "Cross-source US market sentiment overview" + + +def test_execute_search_tool_market_sentiment_is_fail_open(): + agent = DeepSearchAgent.__new__(DeepSearchAgent) + agent.market_sentiment_agency = None + + response = agent.execute_search_tool("search_market_sentiment", "AAPL") + + assert isinstance(response, TavilyResponse) + assert response.results == [] + assert "not configured" in (response.answer or "").lower() + + +def test_build_search_kwargs_keeps_valid_date_range(): + agent = DeepSearchAgent.__new__(DeepSearchAgent) + + tool_name, search_kwargs = agent._build_search_kwargs( + "search_news_by_date", + {"start_date": "2026-04-01", "end_date": "2026-04-07"}, + " - ", + ) + + assert tool_name == "search_news_by_date" + assert search_kwargs == {"start_date": "2026-04-01", "end_date": "2026-04-07"} + + +def test_build_search_kwargs_downgrades_invalid_date_range(): + agent = DeepSearchAgent.__new__(DeepSearchAgent) + + tool_name, search_kwargs = agent._build_search_kwargs( + "search_news_by_date", + {"start_date": "2026/04/01", "end_date": "2026-04-07"}, + " - ", + ) + + assert tool_name == "basic_search_news" + assert search_kwargs == {}