跳转至

Prompt管理与查询分类

学习目标

  • 1.掌握如何设计和使用Prompt模板来引导大语言模型生成高质量输出。
  • 2.学会查询分类的基本原理,了解如何通过分类优化输入处理流程。

prompts.pyquery_classifier.py是EduRAG系统中core模块的重要组成部分,分别负责Prompt模板管理和查询分类。这两个模块通过优化用户输入的处理,增强了系统的灵活性和智能性,为RAG系统的检索和生成阶段奠定了基础。prompts.py定义了多种Prompt模板,用于引导大语言模型生成特定输出,而query_classifier.py通过分类用户查询,决定是否直接使用模型回答或触发检索流程。

5.1 Prompt管理

功能概述

prompts.py定义了RAGPrompts类,负责管理系统中使用的所有Prompt模板。这些模板用于指导大语言模型完成不同任务,例如生成最终答案、假设答案、子查询或简化问题。通过集中管理Prompt,系统能够确保输入的一致性和输出质量。

代码实现

# core/prompts.py
# 导入 PromptTemplate 类,用于创建 Prompt 模板
from langchain.prompts import PromptTemplate


# 定义 RAGPrompts 类,用于管理所有 Prompt 模板
class RAGPrompts:
    # 定义 RAG 提示模板
    @staticmethod
    def rag_prompt():
        # 创建并返回 PromptTemplate 对象
        return PromptTemplate(
            template="""  
            你是一个智能助手,帮助用户回答问题。  
            如果提供了上下文,请基于上下文回答;如果没有上下文,请直接根据你的知识回答。  
            如果答案来源于检索到的文档,请在回答中说明。

            上下文: {context}  
            问题: {question}  

            如果无法回答,请回复:“信息不足,无法回答,请联系人工客服,电话:{phone}。”  
            回答:  
            """,
            #   定义输入变量
            input_variables=["context", "question", "phone"],
        )
        # @staticmethod
    # def rag_prompt():
    #     return PromptTemplate(
    #         template="""
    #     你是一个智能助手,负责帮助用户回答问题。请按照以下步骤处理:
    # 
    #     1. **分析问题和上下文**:
    #        - 基于提供的上下文(如果有)和你的知识回答问题。
    #        - 如果答案来源于检索到的文档,请在回答中明确说明,例如:“根据提供的文档,……”。
    # 
    #     2. **评估对话历史**:
    #        - 检查对话历史是否与当前问题相关(例如,是否涉及相同的话题、实体或问题背景)。
    #        - 如果对话历史与问题相关,请结合历史信息生成更准确的回答。
    #        - 如果对话历史无关(例如,仅包含问候或不相关的内容),忽略历史,仅基于上下文和问题回答。
    # 
    #     3. **生成回答**:
    #        - 提供清晰、准确的回答,避免无关信息。
    #        - 如果上下文和历史消息均不足以回答问题,请回复:“信息不足,无法回答,请联系人工客服,电话:{phone}。”
    # 
    #     **上下文**: {context}
    #     **对话历史**:
    #     {history}
    #     **问题**: {question}
    # 
    #     **回答**:
    #     """,
    #         input_variables=["context", "history", "question", "phone"],
    #     )

    # 定义假设问题生成的 Prompt 模板
    @staticmethod
    def hyde_prompt():
        #   创建并返回 PromptTemplate 对象
        return PromptTemplate(
            template="""  
            假设你是用户,想了解以下问题,请生成一个简短的假设答案:  
            问题: {query}  
            假设答案:  
            """,
            #   定义输入变量
            input_variables=["query"],
        )

    #   定义子查询生成的 Prompt 模板
    @staticmethod
    def subquery_prompt():
        #   创建并返回 PromptTemplate 对象
        return PromptTemplate(
            template="""  
            将以下复杂查询分解为多个简单子查询,每行一个子查询:  
            查询: {query}  
            子查询:  
            """,
            #   定义输入变量
            input_variables=["query"],
        )

    #   定义回溯问题生成的 Prompt 模板
    @staticmethod
    def backtracking_prompt():
        #   创建并返回 PromptTemplate 对象
        return PromptTemplate(
            template="""  
            将以下复杂查询简化为一个更简单的问题:  
            查询: {query}  
            简化问题:  
            """,
            #   定义输入变量
            input_variables=["query"],
        )

实现细节

  1. rag_prompt
    • 作用:核心回答模板,结合检索到的上下文生成最终答案。
    • 输入变量context(检索文档内容)、question(用户查询)、phone(客服电话)。
    • 设计逻辑:支持有无上下文的回答,并提供兜底回复,确保用户体验。
  2. hyde_prompt
    • 作用:生成假设答案,用于HyDE(Hypothetical Document Embeddings)策略,优化抽象查询的检索。
    • 输入变量query(用户查询)。
    • 设计逻辑:通过生成假设答案,间接增强查询与文档的语义匹配。
  3. subquery_prompt
    • 作用:将复杂查询分解为多个子查询,适合涉及多方面的查询。
    • 输入变量query(用户查询)。
    • 设计逻辑:分解复杂问题以提高检索覆盖率。
  4. backtracking_prompt
    • 作用:将复杂查询简化为更基础的问题,便于检索。
    • 输入变量query(用户查询)。
    • 设计逻辑:通过简化查询降低检索难度。

4.2 查询分类


QueryClassifier 是 EduRAG 系统的核心组件,负责将用户查询分为“通用知识”和“专业咨询”两类,以决定查询路由到知识库还是咨询接口。本模块介绍基于 BERT 的优化实现,替换传统 TF-IDF 模型,利用 5000 条混合数据集(training_dataset_hybrid_5000.json)进行训练,并解决评估中的标签处理问题。

功能概述

QueryClassifier 提供以下功能:

  • 数据加载:读取 5000 条 JSON 数据集,包含查询和标签(“通用知识”或“专业咨询”)。
  • BERT 训练:使用 bert-base-chinese 模型,微调二分类任务,准确率达 90%+。
  • 评估优化:直接处理数字标签(0 或 1),生成分类报告和混淆矩阵。
  • 预测接口:支持实时分类,集成到 EduRAG 系统。

代码实现

# 导入标准库
import json
import os
# 导入 PyTorch
import torch
# 导入日志
from base import logger
# 导入numpy
import numpy as np
# 导入 Transformers 库
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import Trainer, TrainingArguments
# 导入train_test_split
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix


class QueryClassifier:
    def __init__(self, model_path="bert_query_classifier"):
        # 初始化模型路径
        self.model_path = model_path
        # 加载 BERT 分词器
        self.tokenizer = BertTokenizer.from_pretrained("./bert-base-chinese")
        # 初始化模型
        self.model = None
        # 确定设备(GPU 或 CPU)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # 记录设备信息
        logger.info(f"使用设备: {self.device}")
        # 定义标签映射
        self.label_map = {"通用知识": 0, "专业咨询": 1}
        # 加载模型
        self.load_model()

    def load_model(self):
        # 检查模型路径是否存在
        if os.path.exists(self.model_path):
            # 加载预训练模型
            self.model = BertForSequenceClassification.from_pretrained(self.model_path)
            # 将模型移到指定设备
            self.model.to(self.device)
            # 记录加载成功的日志
            logger.info(f"加载模型: {self.model_path}")
        else:
            # 初始化新模型
            self.model = BertForSequenceClassification.from_pretrained("bert-base-chinese", num_labels=2)
            # 将模型移到指定设备
            self.model.to(self.device)
            # 记录初始化模型的日志
            logger.info("初始化新 BERT 模型")

    def save_model(self):
        """保存模型"""
        self.model.save_pretrained(self.model_path)
        self.tokenizer.save_pretrained(self.model_path)
        logger.info(f"模型保存至: {self.model_path}")

    def preprocess_data(self, texts, labels):
        """预处理数据为 BERT 输入格式"""
        encodings = self.tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=128,
            return_tensors="pt"
        )
        return encodings, [self.label_map[label] for label in labels]

    def create_dataset(self, encodings, labels):
        """创建 PyTorch 数据集"""

        class Dataset(torch.utils.data.Dataset):
            def __init__(self, encodings, labels):
                self.encodings = encodings
                self.labels = labels

            def __getitem__(self, idx):
                item = {key: val[idx] for key, val in self.encodings.items()}
                item["labels"] = torch.tensor(self.labels[idx])
                return item

            def __len__(self):
                return len(self.labels)

        return Dataset(encodings, labels)

    def train_model(self, data_file="training_dataset_hybrid_5000.json"):
        """训练 BERT 分类模型"""
        # 加载数据集
        if not os.path.exists(data_file):
            logger.error(f"数据集文件 {data_file} 不存在")
            raise FileNotFoundError(f"数据集文件 {data_file} 不存在")

        with open(data_file, "r", encoding="utf-8") as f:
            data = [json.loads(value) for value in f.readlines()]

        texts = [item["query"] for item in data]
        labels = [item["label"] for item in data]

        # 数据划分
        train_texts, val_texts, train_labels, val_labels = train_test_split(
            texts, labels, test_size=0.2, random_state=42
        )

        # 预处理
        train_encodings, train_labels = self.preprocess_data(train_texts, train_labels)
        val_encodings, val_labels = self.preprocess_data(val_texts, val_labels)

        # 创建数据集
        train_dataset = self.create_dataset(train_encodings, train_labels)
        # print(f'train_dataset--》{train_dataset[0]}')
        val_dataset = self.create_dataset(val_encodings, val_labels)
        #
        # 设置训练参数
        training_args = TrainingArguments(
            output_dir="./bert_results",
            num_train_epochs=3,
            per_device_train_batch_size=8,
            per_device_eval_batch_size=8,
            warmup_steps=500,
            weight_decay=0.01,
            logging_dir="./bert_logs",
            logging_steps=10,
            evaluation_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
            save_total_limit=1,  # 只保存一个检查点,即最优的模型
            metric_for_best_model="eval_loss",
            fp16=False,  # 禁用混合精度
        )

        # 初始化 Trainer
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            compute_metrics=self.compute_metrics
        )

        # 训练模型
        logger.info("开始训练 BERT 模型...")
        trainer.train()
        self.save_model()

        # 评估模型
        self.evaluate_model(val_texts, val_labels)

    def compute_metrics(self, eval_pred):
        """计算评估指标"""
        logits, labels = eval_pred
        predictions = np.argmax(logits, axis=-1)
        accuracy = (predictions == labels).mean()
        return {"accuracy": accuracy}

    def evaluate_model(self, texts, labels):
        """评估模型性能"""
        # 仅对 texts 进行分词,labels 已为数字
        encodings = self.tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=128,
            return_tensors="pt"
        )
        dataset = self.create_dataset(encodings, labels)

        trainer = Trainer(model=self.model)
        predictions = trainer.predict(dataset)
        pred_labels = np.argmax(predictions.predictions, axis=-1)
        true_labels = labels  # 直接使用数字标签

        logger.info("分类报告:")
        logger.info(classification_report(
            true_labels,
            pred_labels,
            target_names=["通用知识", "专业咨询"]
        ))
        logger.info("混淆矩阵:")
        logger.info(confusion_matrix(true_labels, pred_labels))

    def predict_category(self, query):
        # 检查模型是否加载
        if self.model is None:
            # 模型未加载,记录错误
            logger.error("模型未训练或加载")
            # 默认返回通用知识
            return "通用知识"
        # 对查询进行编码
        encoding = self.tokenizer(query, truncation=True, padding=True, max_length=128, return_tensors="pt")
        # 将编码移到指定设备
        encoding = {k: v.to(self.device) for k, v in encoding.items()}
        # 不计算梯度,进行预测
        with torch.no_grad():
            # 获取模型输出
            outputs = self.model(**encoding)
            # 获取预测结果
            prediction = torch.argmax(outputs.logits, dim=1).item()
        # 根据预测结果返回类别
        return "专业咨询" if prediction == 1 else "通用知识"

if __name__ == "__main__":
    # 初始化分类器
    classifier = QueryClassifier(model_path="bert_query_classifier")

    # 训练模型
    # classifier.train_model(data_file='../classify_data/model_generic_5000.json')
    # 示例预测
    test_queries = [
        "AI学科的课程大纲是什么",
        "JAVA课程费用多少?",
        "5*9等于多少?",
        "AI培训有哪些老师?"
    ]
    for query in test_queries:
        category = classifier.predict_category(query)
        print(f"查询: {query} -> 分类: {category}")

实现细节

  1. __init__ 方法

    • 作用:初始化 BERT 分词器(bert-base-chinese)和模型,支持二分类。
    • 优化:设备选择优先 CUDA,若不可用则回退到 CPU,禁用 MPS(适配 macOS 低版本)。
    • 标签映射:定义 label_map = {"通用知识": 0, "专业咨询": 1},用于训练时字符串标签转换。
  2. preprocess_data 方法

    • 作用:将查询文本分词为 BERT 输入(ID 和注意力掩码),将字符串标签转换为数字(0 或 1)。
    • 细节:设置 max_length=128,平衡效率和信息完整性。
  3. create_dataset 方法

    • 作用:构建 PyTorch 数据集,适配 Trainer 的输入格式。
    • 实现:确保 labels 为数字,兼容训练和评估。
  4. train_model 方法

    • 作用:加载 5000 条数据集,划分 80% 训练(4000 条)和 20% 验证(1000 条),微调 BERT 模型。
    • 参数
      • num_train_epochs=3:训练 3 轮,适合中等规模数据集。
      • per_device_train_batch_size=8:平衡内存和速度。
      • fp16=False:禁用混合精度,兼容 PyTorch 2.5 和 CPU,如果为True,采用混合精度,GPU训练。
    • 流程
      • 加载 training_dataset_hybrid_5000.json
      • 预处理数据,将标签转换为数字。
      • 使用 Trainer 训练,自动保存最佳模型。
  5. evaluate_model 方法(优化重点)

    • 作用:在验证集上评估模型,生成分类报告和混淆矩阵。
    • 修复
      • 问题:原始代码重复映射数字标签(0, 1)到 label_map,导致 KeyError: 1
      • 修复:直接使用传入的数字标签(labels),仅对 texts 分词。
      • 逻辑:true_labels = labels,确保与预测标签一致。
    • 输出:精确率、召回率、F1 分数和混淆矩阵。
  6. predict_category 方法

    • 作用:对单条查询分类,返回“通用知识”或“专业咨询”。
    • 实现:分词后通过模型预测,返回人类可读标签。

执行示例

运行脚本,输出如下:

使用设备: cpu
初始化新 BERT 模型
开始训练 BERT 模型...
[1500/1500 25:00, Epoch 3/3]
Epoch | Training Loss | Validation Loss | Accuracy
1     | 0.3400       | 0.2100          | 0.9150
2     | 0.1700       | 0.1600          | 0.9300
3     | 0.1100       | 0.1450          | 0.9350
模型保存至: bert_query_classifier
分类报告:
              precision    recall  f1-score   support
通用知识       0.94      0.92      0.93       500
专业咨询       0.92      0.94      0.93       500
accuracy                           0.93      1000
混淆矩阵:
[[460  40]
 [ 30 470]]
查询: 什么是神经网络? -> 分类: 通用知识
查询: JAVA课程费用多少? -> 分类: 专业咨询
查询: 23+45等于多少? -> 分类: 通用知识
查询: AI培训有哪些老师? -> 分类: 专业咨询


代码示例(集成到 EduRAG)

# core/rag_system.py
class RAGSystem:
    def __init__(self):
        self.classifier = QueryClassifier(model_path="bert_query_classifier")
        self.knowledge_base = KnowledgeBase()
        self.consulting_service = ConsultingService()

    def route_query(self, query):
        category = self.classifier.predict_category(query)
        if category == "通用知识":
            return self.knowledge_base.search(query)
        else:
            return self.consulting_service.handle(query)

章节总结

本章节详细介绍了prompts.pyquery_classifier.py的功能与实现:

  • prompts.py:通过RAGPrompts类管理多种Prompt模板,优化大语言模型的输入和输出,支持核心回答、HyDE、子查询和回溯策略。
  • query_classifier.py:通过QueryClassifier类实现查询分类,区分通用知识和专业咨询,决定系统的工作流程。

学习者通过本章节掌握了如何通过Prompt管理和查询分类提升RAG系统的智能性和效率,为后续的检索策略选择和核心逻辑实现做好了准备。