跳转至

检索策略选择与RAG核心逻辑

学习目标

  • 掌握如何根据查询特点动态选择最优策略。
  • 理解RAG系统如何整合查询分类、检索和生成阶段。
  • 能够应用策略选择和RAG逻辑,构建一个从输入处理到答案生成的高效系统。

strategy_selector.pyrag_system.py是EduRAG系统中core模块的核心组件,分别负责检索策略的选择和RAG系统的整体逻辑整合。strategy_selector.py通过大语言模型动态选择适合用户查询的检索策略,而rag_system.py则将前几章介绍的模块(如向量存储、Prompt管理和查询分类)整合起来,完成从查询输入到答案生成的完整流程。这两个模块共同确保系统能够高效、准确地响应用户的各种查询。

6.1 检索策略选择

功能概述

strategy_selector.py定义了StrategySelector类,通过调用大语言模型根据用户查询选择最合适的检索策略。支持的策略包括直接检索、HyDE(假设问题检索)、子查询检索和回溯问题检索,旨在优化检索阶段的输入处理。

代码示例

# core/strategy_selector.py 源码
# 导入 LangChain 提示模板
from langchain.prompts import PromptTemplate
# 导入日志和配置
from base import logger, Config
# 导入 OpenAI
from openai import OpenAI


class StrategySelector:
    def __init__(self):
        # 初始化 OpenAI 客户端
        self.client = OpenAI(api_key=Config().DASHSCOPE_API_KEY,
                             base_url=Config().DASHSCOPE_BASE_URL)
        # 获取策略选择提示模板
        self.strategy_prompt_template = self._get_strategy_prompt()

    def call_dashscope(self, prompt):
        # 调用 DashScope API
        try:
            # 创建聊天完成请求
            completion = self.client.chat.completions.create(
                model=Config().LLM_MODEL,
                messages=[
                    {"role": "system", "content": "你是一个有用的助手。"},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.1
            )
            # 返回完成结果
            return completion.choices[0].message.content if completion.choices else "直接检索"
        except Exception as e:
            # 记录 API 调用失败
            logger.error(f"DashScope API 调用失败: {e}")
            # 默认返回直接检索
            return "直接检索"


    def _get_strategy_prompt(self):
        #   定义私有方法,获取策略选择 Prompt 模板
        return PromptTemplate(
            template="""
            你是一个智能助手,负责分析用户查询 {query},并从以下四种检索增强策略中选择一个最适合的策略,直接返回策略名称,不需要解释过程。

            以下是几种检索增强策略及其适用场景:

            1.  **直接检索:**
                * 描述:对用户查询直接进行检索,不进行任何增强处理。
                * 适用场景:适用于查询意图明确,需要从知识库中检索**特定信息**的问题,例如:
                    * 示例:
                        * 查询:AI 学科学费是多少?
                        * 策略:直接检索
                    * 查询:JAVA的课程大纲是什么?
                        * 策略:直接检索
            2.  **假设问题检索(HyDE):**
                * 描述:使用 LLM 生成一个假设的答案,然后基于假设答案进行检索。
                * 适用场景:适用于查询较为抽象,直接检索效果不佳的问题,例如:
                    * 示例:
                        * 查询:人工智能在教育领域的应用有哪些?
                        * 策略:假设问题检索
            3.  **子查询检索:**
                * 描述:将复杂的用户查询拆分为多个简单的子查询,分别检索并合并结果。
                * 适用场景:适用于查询涉及多个实体或方面,需要分别检索不同信息的问题,例如:
                    * 示例:
                        * 查询:比较 Milvus 和 Zilliz Cloud 的优缺点。
                        * 策略:子查询检索
            4.  **回溯问题检索:**
                * 描述:将复杂的用户查询转化为更基础、更易于检索的问题,然后进行检索。
                * 适用场景:适用于查询较为复杂,需要简化后才能有效检索的问题,例如:
                    * 示例:
                        * 查询:我有一个包含 100 亿条记录的数据集,想把它存储到 Milvus 中进行查询。可以吗?
                        * 策略:回溯问题检索

            根据用户查询 {query},直接返回最适合的策略名称,例如 "直接检索"。不要输出任何分析过程或其他内容。
            """
            ,
            input_variables=["query"],
        )

    #   定义方法,选择检索策略
    def select_strategy(self, query):
        #   调用 LLM 获取检索策略
        strategy = self.call_dashscope(self.strategy_prompt_template.format(query=query)).strip()
        logger.info(f"为查询 '{query}' 选择的检索策略:{strategy}")
        return strategy

if __name__ == '__main__':
    ss = StrategySelector()
    ss.select_strategy('你好吗')

实现细节

  1. __init__
    • 作用:初始化DashScope客户端和策略选择Prompt。
    • 逻辑:连接大语言模型API,准备Prompt模板。
  2. call_dashscope
    • 作用:封装DashScope API调用,处理异常并返回模型输出。
    • 逻辑:确保API调用的鲁棒性,记录错误日志。
  3. _get_strategy_prompt
    • 作用:定义用于策略选择的Prompt模板。
    • 设计逻辑:简洁描述四种策略及其适用场景,要求模型直接返回策略名称。
  4. select_strategy
    • 作用:根据查询调用模型选择策略并返回。
    • 逻辑:记录选择的策略,便于调试。

说明

  • 动态选择:利用大语言模型的语义理解能力,灵活适应不同查询。
  • 策略多样性:支持四种策略,覆盖常见查询场景。
  • 效率:直接返回策略名称,避免冗余输出。

6.2 RAG核心逻辑

功能概述

rag_system.py定义了RAGSystem类,整合系统的各个模块,完成从查询输入到答案生成的完整流程。它通过查询分类选择处理路径,利用检索策略优化文档检索,并结合上下文生成最终答案。

代码示例

# core/rag_system.py 源码
from prompts import RAGPrompts
#   导入 time 模块,用于计算时间
import time
from base import logger, Config
from query_classifier import QueryClassifier  #   导入查询分类器
from strategy_selector import StrategySelector  #   导入策略选择器

conf = Config()

#   定义 RAGSystem 类,封装 RAG 系统的核心逻辑
class RAGSystem:
    #   初始化方法,设置 RAG 系统的基本参数
    def __init__(self, vector_store, llm):
        #   设置向量数据库对象
        self.vector_store = vector_store
        #   设置大语言模型调用函数
        self.llm = llm
        #   获取 RAG 提示模板
        self.rag_prompt = RAGPrompts.rag_prompt()
        #   初始化查询分类器
        self.query_classifier = QueryClassifier(model_path='./core/bert_query_classifier')
        #   初始化策略选择器
        self.strategy_selector = StrategySelector()

    #   定义私有方法,使用假设文档进行检索(HyDE)
    def _retrieve_with_hyde(self, query):
        logger.info(f"使用 HyDE 策略进行检索 (查询: '{query}')")
        #   获取假设问题生成的 Prompt 模板
        hyde_prompt_template = RAGPrompts.hyde_prompt() # 使用 template 后缀区分
        #   调用大语言模型生成假设答案
        try:
            hypo_answer = self.llm(hyde_prompt_template.format(query=query)).strip()
            logger.info(f"HyDE 生成的假设答案: '{hypo_answer}'")
            #   使用假设答案进行检索,并返回检索结果
            #   注意:HyDE 通常只用于生成检索向量,不一定需要 rerank 这一步,但这里复用了
            return self.vector_store.hybrid_search_with_rerank(
                hypo_answer, k=conf.RETRIEVAL_K # 使用 K 而非 M
            )
        except Exception as e:
            logger.error(f"HyDE 策略执行失败: {e}")
            return []


    #   定义私有方法,使用子查询进行检索
    def _retrieve_with_subqueries(self, query):
        logger.info(f"使用子查询策略进行检索 (查询: '{query}')")
        #   获取子查询生成的 Prompt 模板
        subquery_prompt_template = RAGPrompts.subquery_prompt() # 使用 template 后缀区分
        try:
            #   调用大语言模型生成子查询列表
            subqueries_text = self.llm(subquery_prompt_template.format(query=query)).strip()
            subqueries = [q.strip() for q in subqueries_text.split("\n") if q.strip()]
            logger.info(f"生成的子查询: {subqueries}")
            if not subqueries:
                 logger.warning("未能生成有效的子查询")
                 return []

            #   初始化空列表,用于存储所有子查询的检索结果
            all_docs = []
            #   遍历每个子查询
            for sub_q in subqueries:
                #   使用子查询进行检索,并将结果添加到列表中
                #   这里对每个子查询都执行了 hybrid search + rerank,开销可能较大
                docs = self.vector_store.hybrid_search_with_rerank(
                    sub_q, k=conf.RETRIEVAL_K # 使用 K
                )
                all_docs.extend(docs)
                logger.info(f"子查询 '{sub_q}' 检索到 {len(docs)} 个文档")

            #   对所有检索结果进行去重 (基于对象内存地址,如果 Document 内容相同但对象不同则无法去重)
            #   更可靠的去重方式是基于文档内容或 ID
            unique_docs_dict = {doc.page_content: doc for doc in all_docs} # 基于内容去重
            unique_docs = list(unique_docs_dict.values())

            logger.info(f"所有子查询共检索到 {len(all_docs)} 个文档, 去重后剩 {len(unique_docs)} 个")
            #   返回去重后的文档,限制数量 (是否需要在此处限制? retrieve_and_merge 末尾会限制)
            # return unique_docs[: Config.CANDIDATE_M]
            return unique_docs # 返回所有唯一文档,让 retrieve_and_merge 处理数量

        except Exception as e:
            logger.error(f"子查询策略执行失败: {e}")
            return []

    #   定义私有方法,使用回溯问题进行检索
    def _retrieve_with_backtracking(self, query):
        logger.info(f"使用回溯问题策略进行检索 (查询: '{query}')")
        #   获取回溯问题生成的 Prompt 模板
        backtrack_prompt_template = RAGPrompts.backtracking_prompt() # 使用 template 后缀区分
        try:
            #   调用大语言模型生成回溯问题
            simplified_query = self.llm(backtrack_prompt_template.format(query=query)).strip()
            logger.info(f"生成的回溯问题: '{simplified_query}'")
            #   使用回溯问题进行检索,并返回检索结果
            return self.vector_store.hybrid_search_with_rerank(
                simplified_query, k=conf.RETRIEVAL_K # 使用 K
            )
        except Exception as e:
            logger.error(f"回溯问题策略执行失败: {e}")
            return []

    #   定义方法,检索并合并相关文档
    def retrieve_and_merge(self, query, source_filter=None, strategy=None):  #   新增 strategy 参数
        #   如果未指定检索策略,则使用策略选择器选择
        if not strategy:
            strategy = self.strategy_selector.select_strategy(query)

        #   根据检索策略选择不同的检索方式
        ranked_sub_chunks = [] # 初始化
        if strategy == "回溯问题检索":
            ranked_sub_chunks = self._retrieve_with_backtracking(query)
        elif strategy == "子查询检索":
            ranked_sub_chunks = self._retrieve_with_subqueries(query) # 返回的是唯一文档列表
             # 注意:子查询返回的是已 rerank 过的父文档或子块列表,后续合并逻辑可能需要调整
             # 当前实现中,子查询返回的是初步检索(可能已rerank)的块,再进行合并
        elif strategy == "假设问题检索":
            ranked_sub_chunks = self._retrieve_with_hyde(query)
        else:  #   默认或“直接检索”
            logger.info(f"使用直接检索策略 (查询: '{query}')")
            ranked_sub_chunks = self.vector_store.hybrid_search_with_rerank(
                query, k=conf.RETRIEVAL_K, source_filter=source_filter
            ) # 注意 hybrid_search_with_rerank 返回的是 rerank 后的父文档

        logger.info(f"策略 '{strategy}' 检索到 {len(ranked_sub_chunks)} 个候选文档 (可能已是父文档)")
        final_context_docs = ranked_sub_chunks[:conf.CANDIDATE_M]
        logger.info(f"最终选取 {len(final_context_docs)} 个文档作为上下文")
        return final_context_docs

    #   定义方法,生成答案
    def generate_answer(self, query, source_filter=None):
        #   记录查询开始时间
        start_time = time.time()
        logger.info(f"开始处理查询: '{query}', 学科过滤: {source_filter}")

        #   判断查询类型
        query_category = self.query_classifier.predict_category(query)
        logger.info(f"查询分类结果:{query_category} (查询: '{query}')")

        #   如果查询属于“通用知识”类别,则直接使用 LLM 回答
        if query_category == "通用知识":
            logger.info("查询为通用知识,直接调用 LLM")
            prompt_input = self.rag_prompt.format(
                context="", question=query, phone=conf.CUSTOMER_SERVICE_PHONE
            )  #   不使用上下文
            try:
                answer = self.llm(prompt_input)
            except Exception as e:
                logger.error(f"直接调用 LLM 失败: {e}")
                answer = f"抱歉,处理您的通用知识问题时出错。请联系人工客服:{conf.CUSTOMER_SERVICE_PHONE}"
            processing_time = time.time() - start_time
            logger.info(
                f"通用知识查询处理完成 (耗时: {processing_time:.2f}s, 查询: '{query}')"
            )
            return answer

        #   否则,进行 RAG 检索并生成答案
        logger.info("查询为专业咨询,执行 RAG 流程")
        #   选择检索策略
        strategy = self.strategy_selector.select_strategy(query)

        #   检索相关文档
        context_docs = self.retrieve_and_merge(
            query, source_filter=source_filter, strategy=strategy
        )  #   传递 strategy

        #   准备上下文
        if context_docs:
            context = "\n\n".join([doc.page_content for doc in context_docs]) # 使用换行符分隔文档
            logger.info(f"构建上下文完成,包含 {len(context_docs)} 个文档块")
            # logger.debug(f"上下文内容:\n{context[:500]}...") # Debug 日志可以打印部分上下文
        else:
            context = ""
            logger.info("未检索到相关文档,上下文为空")

        #   构造 Prompt,调用大语言模型生成答案
        prompt_input = self.rag_prompt.format(
            context=context, question=query, phone=conf.CUSTOMER_SERVICE_PHONE
        )
        # logger.debug(f"最终生成的 Prompt:\n{prompt_input}") # Debug 日志

        try:
            answer = self.llm(prompt_input)
        except Exception as e:
            logger.error(f"调用 LLM 生成最终答案失败: {e}")
            answer = f"抱歉,处理您的专业咨询问题时出错。请联系人工客服:{conf.CUSTOMER_SERVICE_PHONE}"


        #   记录查询处理完成的日志
        processing_time = time.time() - start_time
        logger.info(f"查询处理完成 (耗时: {processing_time:.2f}s, 查询: '{query}')")
        return answer

实现细节

  1. init
    • 作用:初始化RAG系统,整合向量存储、大语言模型和其他核心组件。
    • 依赖:依赖VectorStore、RAGPrompts、QueryClassifier和StrategySelector。
  2. _retrieve_with_hyde
    • 作用:生成假设答案并调用混合检索,适合抽象查询。
    • 逻辑:使用hyde_prompt生成假设答案,传递给hybrid_search_with_rerank。
  3. _retrieve_with_subqueries
    • 作用:分解查询为子查询,分别检索并去重。
    • 逻辑:使用subquery_prompt分解查询,合并结果并限制数量。
  4. _retrieve_with_backtracking
    • 作用:简化查询后检索,降低复杂度。
    • 逻辑:使用backtracking_prompt简化查询,调用混合检索。
  5. retrieve_and_merge
    • 作用:根据策略选择执行检索,直接返回结果。
    • 优化:移除冗余的合并逻辑,直接使用hybrid_search_with_rerank的结果(去重的父文档)。
  6. generate_answer
    • 作用:整合分类、检索和生成,输出最终答案。
    • 流程:
      • 使用QueryClassifier判断查询类型。
      • “通用知识”直接生成答案,“专业咨询”触发检索。
      • 结合上下文调用rag_prompt生成回答。

6.3 完整流程的整合

从查询到回答的流程

  1. 输入处理
    • QueryClassifier分类查询,决定是否需要检索。
  2. 策略选择
    • StrategySelector根据查询选择最佳检索策略。
  3. 文档检索
    • 根据策略调用VectorStore的混合检索,获取相关文档。
  4. 答案生成
    • 使用RAGPrompts的模板,结合上下文调用大语言模型生成答案。
  5. 输出
    • 返回最终答案,并记录日志。

代码示例(整合逻辑)

# 示例:完整查询处理
query = "AI学科学费是多少?"
rag_system = RAGSystem(vector_store, llm)
answer = rag_system.generate_answer(query)
print(answer)

说明

  • 端到端设计:从输入到输出无缝衔接。
  • 智能优化:通过分类和策略选择减少不必要计算。

章节总结

本章节深入探讨了strategy_selector.pyrag_system.py

  • strategy_selector.py:通过大语言模型动态选择检索策略(直接检索、HyDE、子查询、回溯),优化检索输入。
  • rag_system.py:整合查询分类、检索策略、向量存储和Prompt管理,完成从查询到回答的完整RAG流程。
  • 协同作用:两者结合实现了灵活、高效的查询处理,确保系统能够准确响应用户需求。

学习者通过本章节掌握了RAG系统的核心工作机制,为后续的系统运行和扩展奠定了基础。