检索策略选择与RAG核心逻辑¶
学习目标¶
- 掌握如何根据查询特点动态选择最优策略。
- 理解RAG系统如何整合查询分类、检索和生成阶段。
- 能够应用策略选择和RAG逻辑,构建一个从输入处理到答案生成的高效系统。
strategy_selector.py和rag_system.py是EduRAG系统中core模块的核心组件,分别负责检索策略的选择和RAG系统的整体逻辑整合。strategy_selector.py通过大语言模型动态选择适合用户查询的检索策略,而rag_system.py则将前几章介绍的模块(如向量存储、Prompt管理和查询分类)整合起来,完成从查询输入到答案生成的完整流程。这两个模块共同确保系统能够高效、准确地响应用户的各种查询。
6.1 检索策略选择¶
功能概述¶
strategy_selector.py定义了StrategySelector类,通过调用大语言模型根据用户查询选择最合适的检索策略。支持的策略包括直接检索、HyDE(假设问题检索)、子查询检索和回溯问题检索,旨在优化检索阶段的输入处理。
代码示例¶
# core/strategy_selector.py 源码
# 导入 LangChain 提示模板
from langchain.prompts import PromptTemplate
# 导入日志和配置
from base import logger, Config
# 导入 OpenAI
from openai import OpenAI
class StrategySelector:
def __init__(self):
# 初始化 OpenAI 客户端
self.client = OpenAI(api_key=Config().DASHSCOPE_API_KEY,
base_url=Config().DASHSCOPE_BASE_URL)
# 获取策略选择提示模板
self.strategy_prompt_template = self._get_strategy_prompt()
def call_dashscope(self, prompt):
# 调用 DashScope API
try:
# 创建聊天完成请求
completion = self.client.chat.completions.create(
model=Config().LLM_MODEL,
messages=[
{"role": "system", "content": "你是一个有用的助手。"},
{"role": "user", "content": prompt},
],
temperature=0.1
)
# 返回完成结果
return completion.choices[0].message.content if completion.choices else "直接检索"
except Exception as e:
# 记录 API 调用失败
logger.error(f"DashScope API 调用失败: {e}")
# 默认返回直接检索
return "直接检索"
def _get_strategy_prompt(self):
# 定义私有方法,获取策略选择 Prompt 模板
return PromptTemplate(
template="""
你是一个智能助手,负责分析用户查询 {query},并从以下四种检索增强策略中选择一个最适合的策略,直接返回策略名称,不需要解释过程。
以下是几种检索增强策略及其适用场景:
1. **直接检索:**
* 描述:对用户查询直接进行检索,不进行任何增强处理。
* 适用场景:适用于查询意图明确,需要从知识库中检索**特定信息**的问题,例如:
* 示例:
* 查询:AI 学科学费是多少?
* 策略:直接检索
* 查询:JAVA的课程大纲是什么?
* 策略:直接检索
2. **假设问题检索(HyDE):**
* 描述:使用 LLM 生成一个假设的答案,然后基于假设答案进行检索。
* 适用场景:适用于查询较为抽象,直接检索效果不佳的问题,例如:
* 示例:
* 查询:人工智能在教育领域的应用有哪些?
* 策略:假设问题检索
3. **子查询检索:**
* 描述:将复杂的用户查询拆分为多个简单的子查询,分别检索并合并结果。
* 适用场景:适用于查询涉及多个实体或方面,需要分别检索不同信息的问题,例如:
* 示例:
* 查询:比较 Milvus 和 Zilliz Cloud 的优缺点。
* 策略:子查询检索
4. **回溯问题检索:**
* 描述:将复杂的用户查询转化为更基础、更易于检索的问题,然后进行检索。
* 适用场景:适用于查询较为复杂,需要简化后才能有效检索的问题,例如:
* 示例:
* 查询:我有一个包含 100 亿条记录的数据集,想把它存储到 Milvus 中进行查询。可以吗?
* 策略:回溯问题检索
根据用户查询 {query},直接返回最适合的策略名称,例如 "直接检索"。不要输出任何分析过程或其他内容。
"""
,
input_variables=["query"],
)
# 定义方法,选择检索策略
def select_strategy(self, query):
# 调用 LLM 获取检索策略
strategy = self.call_dashscope(self.strategy_prompt_template.format(query=query)).strip()
logger.info(f"为查询 '{query}' 选择的检索策略:{strategy}")
return strategy
if __name__ == '__main__':
ss = StrategySelector()
ss.select_strategy('你好吗')
实现细节¶
__init__:- 作用:初始化DashScope客户端和策略选择Prompt。
- 逻辑:连接大语言模型API,准备Prompt模板。
call_dashscope:- 作用:封装DashScope API调用,处理异常并返回模型输出。
- 逻辑:确保API调用的鲁棒性,记录错误日志。
_get_strategy_prompt:- 作用:定义用于策略选择的Prompt模板。
- 设计逻辑:简洁描述四种策略及其适用场景,要求模型直接返回策略名称。
select_strategy:- 作用:根据查询调用模型选择策略并返回。
- 逻辑:记录选择的策略,便于调试。
说明¶
- 动态选择:利用大语言模型的语义理解能力,灵活适应不同查询。
- 策略多样性:支持四种策略,覆盖常见查询场景。
- 效率:直接返回策略名称,避免冗余输出。
6.2 RAG核心逻辑¶
功能概述¶
rag_system.py定义了RAGSystem类,整合系统的各个模块,完成从查询输入到答案生成的完整流程。它通过查询分类选择处理路径,利用检索策略优化文档检索,并结合上下文生成最终答案。
代码示例¶
# core/rag_system.py 源码
from prompts import RAGPrompts
# 导入 time 模块,用于计算时间
import time
from base import logger, Config
from query_classifier import QueryClassifier # 导入查询分类器
from strategy_selector import StrategySelector # 导入策略选择器
conf = Config()
# 定义 RAGSystem 类,封装 RAG 系统的核心逻辑
class RAGSystem:
# 初始化方法,设置 RAG 系统的基本参数
def __init__(self, vector_store, llm):
# 设置向量数据库对象
self.vector_store = vector_store
# 设置大语言模型调用函数
self.llm = llm
# 获取 RAG 提示模板
self.rag_prompt = RAGPrompts.rag_prompt()
# 初始化查询分类器
self.query_classifier = QueryClassifier(model_path='./core/bert_query_classifier')
# 初始化策略选择器
self.strategy_selector = StrategySelector()
# 定义私有方法,使用假设文档进行检索(HyDE)
def _retrieve_with_hyde(self, query):
logger.info(f"使用 HyDE 策略进行检索 (查询: '{query}')")
# 获取假设问题生成的 Prompt 模板
hyde_prompt_template = RAGPrompts.hyde_prompt() # 使用 template 后缀区分
# 调用大语言模型生成假设答案
try:
hypo_answer = self.llm(hyde_prompt_template.format(query=query)).strip()
logger.info(f"HyDE 生成的假设答案: '{hypo_answer}'")
# 使用假设答案进行检索,并返回检索结果
# 注意:HyDE 通常只用于生成检索向量,不一定需要 rerank 这一步,但这里复用了
return self.vector_store.hybrid_search_with_rerank(
hypo_answer, k=conf.RETRIEVAL_K # 使用 K 而非 M
)
except Exception as e:
logger.error(f"HyDE 策略执行失败: {e}")
return []
# 定义私有方法,使用子查询进行检索
def _retrieve_with_subqueries(self, query):
logger.info(f"使用子查询策略进行检索 (查询: '{query}')")
# 获取子查询生成的 Prompt 模板
subquery_prompt_template = RAGPrompts.subquery_prompt() # 使用 template 后缀区分
try:
# 调用大语言模型生成子查询列表
subqueries_text = self.llm(subquery_prompt_template.format(query=query)).strip()
subqueries = [q.strip() for q in subqueries_text.split("\n") if q.strip()]
logger.info(f"生成的子查询: {subqueries}")
if not subqueries:
logger.warning("未能生成有效的子查询")
return []
# 初始化空列表,用于存储所有子查询的检索结果
all_docs = []
# 遍历每个子查询
for sub_q in subqueries:
# 使用子查询进行检索,并将结果添加到列表中
# 这里对每个子查询都执行了 hybrid search + rerank,开销可能较大
docs = self.vector_store.hybrid_search_with_rerank(
sub_q, k=conf.RETRIEVAL_K # 使用 K
)
all_docs.extend(docs)
logger.info(f"子查询 '{sub_q}' 检索到 {len(docs)} 个文档")
# 对所有检索结果进行去重 (基于对象内存地址,如果 Document 内容相同但对象不同则无法去重)
# 更可靠的去重方式是基于文档内容或 ID
unique_docs_dict = {doc.page_content: doc for doc in all_docs} # 基于内容去重
unique_docs = list(unique_docs_dict.values())
logger.info(f"所有子查询共检索到 {len(all_docs)} 个文档, 去重后剩 {len(unique_docs)} 个")
# 返回去重后的文档,限制数量 (是否需要在此处限制? retrieve_and_merge 末尾会限制)
# return unique_docs[: Config.CANDIDATE_M]
return unique_docs # 返回所有唯一文档,让 retrieve_and_merge 处理数量
except Exception as e:
logger.error(f"子查询策略执行失败: {e}")
return []
# 定义私有方法,使用回溯问题进行检索
def _retrieve_with_backtracking(self, query):
logger.info(f"使用回溯问题策略进行检索 (查询: '{query}')")
# 获取回溯问题生成的 Prompt 模板
backtrack_prompt_template = RAGPrompts.backtracking_prompt() # 使用 template 后缀区分
try:
# 调用大语言模型生成回溯问题
simplified_query = self.llm(backtrack_prompt_template.format(query=query)).strip()
logger.info(f"生成的回溯问题: '{simplified_query}'")
# 使用回溯问题进行检索,并返回检索结果
return self.vector_store.hybrid_search_with_rerank(
simplified_query, k=conf.RETRIEVAL_K # 使用 K
)
except Exception as e:
logger.error(f"回溯问题策略执行失败: {e}")
return []
# 定义方法,检索并合并相关文档
def retrieve_and_merge(self, query, source_filter=None, strategy=None): # 新增 strategy 参数
# 如果未指定检索策略,则使用策略选择器选择
if not strategy:
strategy = self.strategy_selector.select_strategy(query)
# 根据检索策略选择不同的检索方式
ranked_sub_chunks = [] # 初始化
if strategy == "回溯问题检索":
ranked_sub_chunks = self._retrieve_with_backtracking(query)
elif strategy == "子查询检索":
ranked_sub_chunks = self._retrieve_with_subqueries(query) # 返回的是唯一文档列表
# 注意:子查询返回的是已 rerank 过的父文档或子块列表,后续合并逻辑可能需要调整
# 当前实现中,子查询返回的是初步检索(可能已rerank)的块,再进行合并
elif strategy == "假设问题检索":
ranked_sub_chunks = self._retrieve_with_hyde(query)
else: # 默认或“直接检索”
logger.info(f"使用直接检索策略 (查询: '{query}')")
ranked_sub_chunks = self.vector_store.hybrid_search_with_rerank(
query, k=conf.RETRIEVAL_K, source_filter=source_filter
) # 注意 hybrid_search_with_rerank 返回的是 rerank 后的父文档
logger.info(f"策略 '{strategy}' 检索到 {len(ranked_sub_chunks)} 个候选文档 (可能已是父文档)")
final_context_docs = ranked_sub_chunks[:conf.CANDIDATE_M]
logger.info(f"最终选取 {len(final_context_docs)} 个文档作为上下文")
return final_context_docs
# 定义方法,生成答案
def generate_answer(self, query, source_filter=None):
# 记录查询开始时间
start_time = time.time()
logger.info(f"开始处理查询: '{query}', 学科过滤: {source_filter}")
# 判断查询类型
query_category = self.query_classifier.predict_category(query)
logger.info(f"查询分类结果:{query_category} (查询: '{query}')")
# 如果查询属于“通用知识”类别,则直接使用 LLM 回答
if query_category == "通用知识":
logger.info("查询为通用知识,直接调用 LLM")
prompt_input = self.rag_prompt.format(
context="", question=query, phone=conf.CUSTOMER_SERVICE_PHONE
) # 不使用上下文
try:
answer = self.llm(prompt_input)
except Exception as e:
logger.error(f"直接调用 LLM 失败: {e}")
answer = f"抱歉,处理您的通用知识问题时出错。请联系人工客服:{conf.CUSTOMER_SERVICE_PHONE}"
processing_time = time.time() - start_time
logger.info(
f"通用知识查询处理完成 (耗时: {processing_time:.2f}s, 查询: '{query}')"
)
return answer
# 否则,进行 RAG 检索并生成答案
logger.info("查询为专业咨询,执行 RAG 流程")
# 选择检索策略
strategy = self.strategy_selector.select_strategy(query)
# 检索相关文档
context_docs = self.retrieve_and_merge(
query, source_filter=source_filter, strategy=strategy
) # 传递 strategy
# 准备上下文
if context_docs:
context = "\n\n".join([doc.page_content for doc in context_docs]) # 使用换行符分隔文档
logger.info(f"构建上下文完成,包含 {len(context_docs)} 个文档块")
# logger.debug(f"上下文内容:\n{context[:500]}...") # Debug 日志可以打印部分上下文
else:
context = ""
logger.info("未检索到相关文档,上下文为空")
# 构造 Prompt,调用大语言模型生成答案
prompt_input = self.rag_prompt.format(
context=context, question=query, phone=conf.CUSTOMER_SERVICE_PHONE
)
# logger.debug(f"最终生成的 Prompt:\n{prompt_input}") # Debug 日志
try:
answer = self.llm(prompt_input)
except Exception as e:
logger.error(f"调用 LLM 生成最终答案失败: {e}")
answer = f"抱歉,处理您的专业咨询问题时出错。请联系人工客服:{conf.CUSTOMER_SERVICE_PHONE}"
# 记录查询处理完成的日志
processing_time = time.time() - start_time
logger.info(f"查询处理完成 (耗时: {processing_time:.2f}s, 查询: '{query}')")
return answer
实现细节¶
- init:
- 作用:初始化RAG系统,整合向量存储、大语言模型和其他核心组件。
- 依赖:依赖VectorStore、RAGPrompts、QueryClassifier和StrategySelector。
- _retrieve_with_hyde:
- 作用:生成假设答案并调用混合检索,适合抽象查询。
- 逻辑:使用hyde_prompt生成假设答案,传递给hybrid_search_with_rerank。
- _retrieve_with_subqueries:
- 作用:分解查询为子查询,分别检索并去重。
- 逻辑:使用subquery_prompt分解查询,合并结果并限制数量。
- _retrieve_with_backtracking:
- 作用:简化查询后检索,降低复杂度。
- 逻辑:使用backtracking_prompt简化查询,调用混合检索。
- retrieve_and_merge:
- 作用:根据策略选择执行检索,直接返回结果。
- 优化:移除冗余的合并逻辑,直接使用hybrid_search_with_rerank的结果(去重的父文档)。
- generate_answer:
- 作用:整合分类、检索和生成,输出最终答案。
- 流程:
- 使用QueryClassifier判断查询类型。
- “通用知识”直接生成答案,“专业咨询”触发检索。
- 结合上下文调用rag_prompt生成回答。
6.3 完整流程的整合¶
从查询到回答的流程¶
- 输入处理:
QueryClassifier分类查询,决定是否需要检索。
- 策略选择:
StrategySelector根据查询选择最佳检索策略。
- 文档检索:
- 根据策略调用
VectorStore的混合检索,获取相关文档。
- 根据策略调用
- 答案生成:
- 使用
RAGPrompts的模板,结合上下文调用大语言模型生成答案。
- 使用
- 输出:
- 返回最终答案,并记录日志。
代码示例(整合逻辑)¶
# 示例:完整查询处理
query = "AI学科学费是多少?"
rag_system = RAGSystem(vector_store, llm)
answer = rag_system.generate_answer(query)
print(answer)
说明¶
- 端到端设计:从输入到输出无缝衔接。
- 智能优化:通过分类和策略选择减少不必要计算。
章节总结¶
本章节深入探讨了strategy_selector.py和rag_system.py:
- strategy_selector.py:通过大语言模型动态选择检索策略(直接检索、HyDE、子查询、回溯),优化检索输入。
rag_system.py:整合查询分类、检索策略、向量存储和Prompt管理,完成从查询到回答的完整RAG流程。- 协同作用:两者结合实现了灵活、高效的查询处理,确保系统能够准确响应用户需求。
学习者通过本章节掌握了RAG系统的核心工作机制,为后续的系统运行和扩展奠定了基础。