Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ KEYWORD_OPTIMIZER_MODEL_NAME=
# ================== 网络工具配置 ====================
# Tavily API密钥,用于Tavily网络搜索,申请地址:https://www.tavily.com/
TAVILY_API_KEY=
# Adanos Market Sentiment API 密钥(可选,用于结构化股票与市场情绪研究)
ADANOS_API_KEY=

# 网络搜索工具类型,支持BochaAPI或AnspireAPI两种,默认为AnspireAPI
SEARCH_TOOL_TYPE=AnspireAPI
Expand All @@ -74,4 +76,4 @@ ANSPIRE_API_KEY=

# Bocha AI Search API(用于Bocha多模态搜索,这里密钥名称虽然是Web Search,但其实是要AI Search的,申请地址:https://open.bochaai.com/)
BOCHA_BASE_URL=https://api.bocha.cn/v1/ai-search
BOCHA_WEB_SEARCH_API_KEY=
BOCHA_WEB_SEARCH_API_KEY=
101 changes: 56 additions & 45 deletions QueryEngine/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
ReportFormattingNode
)
from .state import State
from .tools import TavilyNewsAgency, TavilyResponse
from .tools import AdanosSentimentAgency, TavilyNewsAgency, TavilyResponse
from .utils import Settings, format_search_results_for_prompt
from loguru import logger

Expand All @@ -42,6 +42,11 @@ def __init__(self, config: Optional[Settings] = None):

# 初始化搜索工具集
self.search_agency = TavilyNewsAgency(api_key=self.config.TAVILY_API_KEY)
self.market_sentiment_agency = (
AdanosSentimentAgency(api_key=self.config.ADANOS_API_KEY)
if self.config.ADANOS_API_KEY
else None
)

# 初始化节点
self._initialize_nodes()
Expand All @@ -54,7 +59,10 @@ def __init__(self, config: Optional[Settings] = None):

logger.info(f"Query Agent已初始化")
logger.info(f"使用LLM: {self.llm_client.get_model_info()}")
logger.info(f"搜索工具集: TavilyNewsAgency (支持6种搜索工具)")
search_tools = "TavilyNewsAgency (支持6种新闻搜索工具)"
if self.market_sentiment_agency:
search_tools += " + AdanosSentimentAgency (结构化市场情绪工具)"
logger.info(f"搜索工具集: {search_tools}")

def _initialize_llm(self) -> LLMClient:
"""初始化LLM客户端"""
Expand All @@ -66,8 +74,15 @@ def _initialize_llm(self) -> LLMClient:

def _initialize_nodes(self):
"""初始化处理节点"""
self.first_search_node = FirstSearchNode(self.llm_client)
self.reflection_node = ReflectionNode(self.llm_client)
enable_market_sentiment = bool(self.config.ADANOS_API_KEY)
self.first_search_node = FirstSearchNode(
self.llm_client,
enable_market_sentiment=enable_market_sentiment,
)
self.reflection_node = ReflectionNode(
self.llm_client,
enable_market_sentiment=enable_market_sentiment,
)
self.first_summary_node = FirstSummaryNode(self.llm_client)
self.reflection_summary_node = ReflectionSummaryNode(self.llm_client)
self.report_formatting_node = ReportFormattingNode(self.llm_client)
Expand Down Expand Up @@ -96,6 +111,30 @@ def _validate_date_format(self, date_str: str) -> bool:
return True
except ValueError:
return False

def _build_search_kwargs(self, tool_name: str, tool_output: Dict[str, Any], log_prefix: str) -> tuple[str, Dict[str, Any]]:
"""Normalize optional tool arguments and downgrade invalid date searches safely."""
search_kwargs: Dict[str, Any] = {}

if tool_name != "search_news_by_date":
return tool_name, search_kwargs

start_date = tool_output.get("start_date")
end_date = tool_output.get("end_date")

if start_date and end_date:
if self._validate_date_format(start_date) and self._validate_date_format(end_date):
search_kwargs["start_date"] = start_date
search_kwargs["end_date"] = end_date
logger.info(f"{log_prefix}时间范围: {start_date} 到 {end_date}")
return tool_name, search_kwargs

logger.info(f"{log_prefix}⚠️ 日期格式错误(应为YYYY-MM-DD),改用基础搜索")
logger.info(f"{log_prefix} 提供的日期: start_date={start_date}, end_date={end_date}")
return "basic_search_news", search_kwargs

logger.info(f"{log_prefix}⚠️ search_news_by_date工具缺少时间参数,改用基础搜索")
return "basic_search_news", search_kwargs

def execute_search_tool(self, tool_name: str, query: str, **kwargs) -> TavilyResponse:
"""
Expand All @@ -109,6 +148,7 @@ def execute_search_tool(self, tool_name: str, query: str, **kwargs) -> TavilyRes
- "search_news_last_week": 本周新闻
- "search_images_for_news": 新闻图片搜索
- "search_news_by_date": 按日期范围搜索新闻
- "search_market_sentiment": 结构化股票与市场情绪研究
query: 搜索查询
**kwargs: 额外参数(如start_date, end_date, max_results)

Expand All @@ -134,6 +174,16 @@ def execute_search_tool(self, tool_name: str, query: str, **kwargs) -> TavilyRes
if not start_date or not end_date:
raise ValueError("search_news_by_date工具需要start_date和end_date参数")
return self.search_agency.search_news_by_date(query, start_date, end_date)
elif tool_name == "search_market_sentiment":
if not self.market_sentiment_agency:
logger.info(" → Adanos市场情绪工具未配置,跳过结构化情绪研究")
return TavilyResponse(
query=query,
answer="Structured market sentiment tool is not configured.",
results=[],
)
days = kwargs.get("days", 7)
return self.market_sentiment_agency.search_market_sentiment(query, days=days)
else:
logger.warning(f" ⚠️ 未知的搜索工具: {tool_name},使用默认基础搜索")
return self.search_agency.basic_search_news(query)
Expand Down Expand Up @@ -238,26 +288,7 @@ def _initial_search_and_summary(self, paragraph_index: int):
# 执行搜索
logger.info(" - 执行网络搜索...")

# 处理search_news_by_date的特殊参数
search_kwargs = {}
if search_tool == "search_news_by_date":
start_date = search_output.get("start_date")
end_date = search_output.get("end_date")

if start_date and end_date:
# 验证日期格式
if self._validate_date_format(start_date) and self._validate_date_format(end_date):
search_kwargs["start_date"] = start_date
search_kwargs["end_date"] = end_date
logger.info(f" - 时间范围: {start_date} 到 {end_date}")
else:
logger.info(f" ⚠️ 日期格式错误(应为YYYY-MM-DD),改用基础搜索")
logger.info(f" 提供的日期: start_date={start_date}, end_date={end_date}")
search_tool = "basic_search_news"
else:
logger.info(f" ⚠️ search_news_by_date工具缺少时间参数,改用基础搜索")
search_tool = "basic_search_news"

search_tool, search_kwargs = self._build_search_kwargs(search_tool, search_output, " - ")
search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs)

# 转换为兼容格式
Expand Down Expand Up @@ -328,27 +359,7 @@ def _reflection_loop(self, paragraph_index: int):
logger.info(f" 选择的工具: {search_tool}")
logger.info(f" 反思推理: {reasoning}")

# 执行反思搜索
# 处理search_news_by_date的特殊参数
search_kwargs = {}
if search_tool == "search_news_by_date":
start_date = reflection_output.get("start_date")
end_date = reflection_output.get("end_date")

if start_date and end_date:
# 验证日期格式
if self._validate_date_format(start_date) and self._validate_date_format(end_date):
search_kwargs["start_date"] = start_date
search_kwargs["end_date"] = end_date
logger.info(f" 时间范围: {start_date} 到 {end_date}")
else:
logger.info(f" ⚠️ 日期格式错误(应为YYYY-MM-DD),改用基础搜索")
logger.info(f" 提供的日期: start_date={start_date}, end_date={end_date}")
search_tool = "basic_search_news"
else:
logger.info(f" ⚠️ search_news_by_date工具缺少时间参数,改用基础搜索")
search_tool = "basic_search_news"

search_tool, search_kwargs = self._build_search_kwargs(search_tool, reflection_output, " ")
search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs)

# 转换为兼容格式
Expand Down
52 changes: 36 additions & 16 deletions QueryEngine/nodes/search_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
"""

import json
from typing import Dict, Any
from typing import Any, Dict
from json.decoder import JSONDecodeError
from loguru import logger

from .base_node import BaseNode
from ..prompts import SYSTEM_PROMPT_FIRST_SEARCH, SYSTEM_PROMPT_REFLECTION
from ..prompts import build_first_search_prompt, build_reflection_prompt
from ..utils.text_processing import (
remove_reasoning_from_output,
clean_json_tags,
Expand All @@ -21,14 +21,15 @@
class FirstSearchNode(BaseNode):
"""为段落生成首次搜索查询的节点"""

def __init__(self, llm_client):
def __init__(self, llm_client, enable_market_sentiment: bool = False):
"""
初始化首次搜索节点

Args:
llm_client: LLM客户端
"""
super().__init__(llm_client, "FirstSearchNode")
self.system_prompt = build_first_search_prompt(enable_market_sentiment)

def validate_input(self, input_data: Any) -> bool:
"""验证输入数据"""
Expand All @@ -42,7 +43,7 @@ def validate_input(self, input_data: Any) -> bool:
return "title" in input_data and "content" in input_data
return False

def run(self, input_data: Any, **kwargs) -> Dict[str, str]:
def run(self, input_data: Any, **kwargs) -> Dict[str, Any]:
"""
调用LLM生成搜索查询和理由

Expand All @@ -66,7 +67,7 @@ def run(self, input_data: Any, **kwargs) -> Dict[str, str]:
logger.info("正在生成首次搜索查询")

# 调用LLM
response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_FIRST_SEARCH, message)
response = self.llm_client.stream_invoke_to_string(self.system_prompt, message)

# 处理响应
processed_response = self.process_output(response)
Expand All @@ -78,7 +79,7 @@ def run(self, input_data: Any, **kwargs) -> Dict[str, str]:
logger.exception(f"生成首次搜索查询失败: {str(e)}")
raise e

def process_output(self, output: str) -> Dict[str, str]:
def process_output(self, output: str) -> Dict[str, Any]:
"""
处理LLM输出,提取搜索查询和推理

Expand Down Expand Up @@ -123,22 +124,28 @@ def process_output(self, output: str) -> Dict[str, str]:
# 验证和清理结果
search_query = result.get("search_query", "")
reasoning = result.get("reasoning", "")
search_tool = result.get("search_tool") or "basic_search_news"
start_date = result.get("start_date")
end_date = result.get("end_date")

if not search_query:
logger.warning("未找到搜索查询,使用默认查询")
return self._get_default_search_query()

return {
"search_query": search_query,
"reasoning": reasoning
"search_tool": search_tool,
"reasoning": reasoning,
"start_date": start_date,
"end_date": end_date,
}

except Exception as e:
self.log_error(f"处理输出失败: {str(e)}")
# 返回默认查询
return self._get_default_search_query()

def _get_default_search_query(self) -> Dict[str, str]:
def _get_default_search_query(self) -> Dict[str, Any]:
"""
获取默认搜索查询

Expand All @@ -147,21 +154,25 @@ def _get_default_search_query(self) -> Dict[str, str]:
"""
return {
"search_query": "相关主题研究",
"reasoning": "由于解析失败,使用默认搜索查询"
"search_tool": "basic_search_news",
"reasoning": "由于解析失败,使用默认搜索查询",
"start_date": None,
"end_date": None,
}


class ReflectionNode(BaseNode):
"""反思段落并生成新搜索查询的节点"""

def __init__(self, llm_client):
def __init__(self, llm_client, enable_market_sentiment: bool = False):
"""
初始化反思节点

Args:
llm_client: LLM客户端
"""
super().__init__(llm_client, "ReflectionNode")
self.system_prompt = build_reflection_prompt(enable_market_sentiment)

def validate_input(self, input_data: Any) -> bool:
"""验证输入数据"""
Expand All @@ -177,7 +188,7 @@ def validate_input(self, input_data: Any) -> bool:
return all(field in input_data for field in required_fields)
return False

def run(self, input_data: Any, **kwargs) -> Dict[str, str]:
def run(self, input_data: Any, **kwargs) -> Dict[str, Any]:
"""
调用LLM反思并生成搜索查询

Expand All @@ -201,7 +212,7 @@ def run(self, input_data: Any, **kwargs) -> Dict[str, str]:
logger.info("正在进行反思并生成新搜索查询")

# 调用LLM
response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_REFLECTION, message)
response = self.llm_client.stream_invoke_to_string(self.system_prompt, message)

# 处理响应
processed_response = self.process_output(response)
Expand All @@ -213,7 +224,7 @@ def run(self, input_data: Any, **kwargs) -> Dict[str, str]:
logger.exception(f"反思生成搜索查询失败: {str(e)}")
raise e

def process_output(self, output: str) -> Dict[str, str]:
def process_output(self, output: str) -> Dict[str, Any]:
"""
处理LLM输出,提取搜索查询和推理

Expand Down Expand Up @@ -258,22 +269,28 @@ def process_output(self, output: str) -> Dict[str, str]:
# 验证和清理结果
search_query = result.get("search_query", "")
reasoning = result.get("reasoning", "")
search_tool = result.get("search_tool") or "basic_search_news"
start_date = result.get("start_date")
end_date = result.get("end_date")

if not search_query:
logger.warning("未找到搜索查询,使用默认查询")
return self._get_default_reflection_query()

return {
"search_query": search_query,
"reasoning": reasoning
"search_tool": search_tool,
"reasoning": reasoning,
"start_date": start_date,
"end_date": end_date,
}

except Exception as e:
logger.exception(f"处理输出失败: {str(e)}")
# 返回默认查询
return self._get_default_reflection_query()

def _get_default_reflection_query(self) -> Dict[str, str]:
def _get_default_reflection_query(self) -> Dict[str, Any]:
"""
获取默认反思搜索查询

Expand All @@ -282,5 +299,8 @@ def _get_default_reflection_query(self) -> Dict[str, str]:
"""
return {
"search_query": "深度研究补充信息",
"reasoning": "由于解析失败,使用默认反思搜索查询"
"search_tool": "basic_search_news",
"reasoning": "由于解析失败,使用默认反思搜索查询",
"start_date": None,
"end_date": None,
}
8 changes: 6 additions & 2 deletions QueryEngine/prompts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
output_schema_first_summary,
output_schema_reflection,
output_schema_reflection_summary,
input_schema_report_formatting
input_schema_report_formatting,
build_first_search_prompt,
build_reflection_prompt,
)

__all__ = [
Expand All @@ -30,5 +32,7 @@
"output_schema_first_summary",
"output_schema_reflection",
"output_schema_reflection_summary",
"input_schema_report_formatting"
"input_schema_report_formatting",
"build_first_search_prompt",
"build_reflection_prompt",
]
Loading