Prompt管理与查询分类¶
学习目标¶
- 1.掌握如何设计和使用Prompt模板来引导大语言模型生成高质量输出。
- 2.学会查询分类的基本原理,了解如何通过分类优化输入处理流程。
prompts.py和query_classifier.py是EduRAG系统中core模块的重要组成部分,分别负责Prompt模板管理和查询分类。这两个模块通过优化用户输入的处理,增强了系统的灵活性和智能性,为RAG系统的检索和生成阶段奠定了基础。prompts.py定义了多种Prompt模板,用于引导大语言模型生成特定输出,而query_classifier.py通过分类用户查询,决定是否直接使用模型回答或触发检索流程。
5.1 Prompt管理¶
功能概述¶
prompts.py定义了RAGPrompts类,负责管理系统中使用的所有Prompt模板。这些模板用于指导大语言模型完成不同任务,例如生成最终答案、假设答案、子查询或简化问题。通过集中管理Prompt,系统能够确保输入的一致性和输出质量。
代码实现¶
# core/prompts.py
# 导入 PromptTemplate 类,用于创建 Prompt 模板
from langchain.prompts import PromptTemplate
# 定义 RAGPrompts 类,用于管理所有 Prompt 模板
class RAGPrompts:
# 定义 RAG 提示模板
@staticmethod
def rag_prompt():
# 创建并返回 PromptTemplate 对象
return PromptTemplate(
template="""
你是一个智能助手,帮助用户回答问题。
如果提供了上下文,请基于上下文回答;如果没有上下文,请直接根据你的知识回答。
如果答案来源于检索到的文档,请在回答中说明。
上下文: {context}
问题: {question}
如果无法回答,请回复:“信息不足,无法回答,请联系人工客服,电话:{phone}。”
回答:
""",
# 定义输入变量
input_variables=["context", "question", "phone"],
)
# @staticmethod
# def rag_prompt():
# return PromptTemplate(
# template="""
# 你是一个智能助手,负责帮助用户回答问题。请按照以下步骤处理:
#
# 1. **分析问题和上下文**:
# - 基于提供的上下文(如果有)和你的知识回答问题。
# - 如果答案来源于检索到的文档,请在回答中明确说明,例如:“根据提供的文档,……”。
#
# 2. **评估对话历史**:
# - 检查对话历史是否与当前问题相关(例如,是否涉及相同的话题、实体或问题背景)。
# - 如果对话历史与问题相关,请结合历史信息生成更准确的回答。
# - 如果对话历史无关(例如,仅包含问候或不相关的内容),忽略历史,仅基于上下文和问题回答。
#
# 3. **生成回答**:
# - 提供清晰、准确的回答,避免无关信息。
# - 如果上下文和历史消息均不足以回答问题,请回复:“信息不足,无法回答,请联系人工客服,电话:{phone}。”
#
# **上下文**: {context}
# **对话历史**:
# {history}
# **问题**: {question}
#
# **回答**:
# """,
# input_variables=["context", "history", "question", "phone"],
# )
# 定义假设问题生成的 Prompt 模板
@staticmethod
def hyde_prompt():
# 创建并返回 PromptTemplate 对象
return PromptTemplate(
template="""
假设你是用户,想了解以下问题,请生成一个简短的假设答案:
问题: {query}
假设答案:
""",
# 定义输入变量
input_variables=["query"],
)
# 定义子查询生成的 Prompt 模板
@staticmethod
def subquery_prompt():
# 创建并返回 PromptTemplate 对象
return PromptTemplate(
template="""
将以下复杂查询分解为多个简单子查询,每行一个子查询:
查询: {query}
子查询:
""",
# 定义输入变量
input_variables=["query"],
)
# 定义回溯问题生成的 Prompt 模板
@staticmethod
def backtracking_prompt():
# 创建并返回 PromptTemplate 对象
return PromptTemplate(
template="""
将以下复杂查询简化为一个更简单的问题:
查询: {query}
简化问题:
""",
# 定义输入变量
input_variables=["query"],
)
实现细节¶
rag_prompt:- 作用:核心回答模板,结合检索到的上下文生成最终答案。
- 输入变量:
context(检索文档内容)、question(用户查询)、phone(客服电话)。 - 设计逻辑:支持有无上下文的回答,并提供兜底回复,确保用户体验。
hyde_prompt:- 作用:生成假设答案,用于HyDE(Hypothetical Document Embeddings)策略,优化抽象查询的检索。
- 输入变量:
query(用户查询)。 - 设计逻辑:通过生成假设答案,间接增强查询与文档的语义匹配。
subquery_prompt:- 作用:将复杂查询分解为多个子查询,适合涉及多方面的查询。
- 输入变量:
query(用户查询)。 - 设计逻辑:分解复杂问题以提高检索覆盖率。
backtracking_prompt:- 作用:将复杂查询简化为更基础的问题,便于检索。
- 输入变量:
query(用户查询)。 - 设计逻辑:通过简化查询降低检索难度。
4.2 查询分类¶
QueryClassifier 是 EduRAG 系统的核心组件,负责将用户查询分为“通用知识”和“专业咨询”两类,以决定查询路由到知识库还是咨询接口。本模块介绍基于 BERT 的优化实现,替换传统 TF-IDF 模型,利用 5000 条混合数据集(training_dataset_hybrid_5000.json)进行训练,并解决评估中的标签处理问题。
功能概述¶
QueryClassifier 提供以下功能:
- 数据加载:读取 5000 条 JSON 数据集,包含查询和标签(“通用知识”或“专业咨询”)。
- BERT 训练:使用
bert-base-chinese模型,微调二分类任务,准确率达 90%+。 - 评估优化:直接处理数字标签(0 或 1),生成分类报告和混淆矩阵。
- 预测接口:支持实时分类,集成到 EduRAG 系统。
代码实现¶
# 导入标准库
import json
import os
# 导入 PyTorch
import torch
# 导入日志
from base import logger
# 导入numpy
import numpy as np
# 导入 Transformers 库
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import Trainer, TrainingArguments
# 导入train_test_split
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
class QueryClassifier:
def __init__(self, model_path="bert_query_classifier"):
# 初始化模型路径
self.model_path = model_path
# 加载 BERT 分词器
self.tokenizer = BertTokenizer.from_pretrained("./bert-base-chinese")
# 初始化模型
self.model = None
# 确定设备(GPU 或 CPU)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 记录设备信息
logger.info(f"使用设备: {self.device}")
# 定义标签映射
self.label_map = {"通用知识": 0, "专业咨询": 1}
# 加载模型
self.load_model()
def load_model(self):
# 检查模型路径是否存在
if os.path.exists(self.model_path):
# 加载预训练模型
self.model = BertForSequenceClassification.from_pretrained(self.model_path)
# 将模型移到指定设备
self.model.to(self.device)
# 记录加载成功的日志
logger.info(f"加载模型: {self.model_path}")
else:
# 初始化新模型
self.model = BertForSequenceClassification.from_pretrained("bert-base-chinese", num_labels=2)
# 将模型移到指定设备
self.model.to(self.device)
# 记录初始化模型的日志
logger.info("初始化新 BERT 模型")
def save_model(self):
"""保存模型"""
self.model.save_pretrained(self.model_path)
self.tokenizer.save_pretrained(self.model_path)
logger.info(f"模型保存至: {self.model_path}")
def preprocess_data(self, texts, labels):
"""预处理数据为 BERT 输入格式"""
encodings = self.tokenizer(
texts,
truncation=True,
padding=True,
max_length=128,
return_tensors="pt"
)
return encodings, [self.label_map[label] for label in labels]
def create_dataset(self, encodings, labels):
"""创建 PyTorch 数据集"""
class Dataset(torch.utils.data.Dataset):
def __init__(self, encodings, labels):
self.encodings = encodings
self.labels = labels
def __getitem__(self, idx):
item = {key: val[idx] for key, val in self.encodings.items()}
item["labels"] = torch.tensor(self.labels[idx])
return item
def __len__(self):
return len(self.labels)
return Dataset(encodings, labels)
def train_model(self, data_file="training_dataset_hybrid_5000.json"):
"""训练 BERT 分类模型"""
# 加载数据集
if not os.path.exists(data_file):
logger.error(f"数据集文件 {data_file} 不存在")
raise FileNotFoundError(f"数据集文件 {data_file} 不存在")
with open(data_file, "r", encoding="utf-8") as f:
data = [json.loads(value) for value in f.readlines()]
texts = [item["query"] for item in data]
labels = [item["label"] for item in data]
# 数据划分
train_texts, val_texts, train_labels, val_labels = train_test_split(
texts, labels, test_size=0.2, random_state=42
)
# 预处理
train_encodings, train_labels = self.preprocess_data(train_texts, train_labels)
val_encodings, val_labels = self.preprocess_data(val_texts, val_labels)
# 创建数据集
train_dataset = self.create_dataset(train_encodings, train_labels)
# print(f'train_dataset--》{train_dataset[0]}')
val_dataset = self.create_dataset(val_encodings, val_labels)
#
# 设置训练参数
training_args = TrainingArguments(
output_dir="./bert_results",
num_train_epochs=3,
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
warmup_steps=500,
weight_decay=0.01,
logging_dir="./bert_logs",
logging_steps=10,
evaluation_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
save_total_limit=1, # 只保存一个检查点,即最优的模型
metric_for_best_model="eval_loss",
fp16=False, # 禁用混合精度
)
# 初始化 Trainer
trainer = Trainer(
model=self.model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
compute_metrics=self.compute_metrics
)
# 训练模型
logger.info("开始训练 BERT 模型...")
trainer.train()
self.save_model()
# 评估模型
self.evaluate_model(val_texts, val_labels)
def compute_metrics(self, eval_pred):
"""计算评估指标"""
logits, labels = eval_pred
predictions = np.argmax(logits, axis=-1)
accuracy = (predictions == labels).mean()
return {"accuracy": accuracy}
def evaluate_model(self, texts, labels):
"""评估模型性能"""
# 仅对 texts 进行分词,labels 已为数字
encodings = self.tokenizer(
texts,
truncation=True,
padding=True,
max_length=128,
return_tensors="pt"
)
dataset = self.create_dataset(encodings, labels)
trainer = Trainer(model=self.model)
predictions = trainer.predict(dataset)
pred_labels = np.argmax(predictions.predictions, axis=-1)
true_labels = labels # 直接使用数字标签
logger.info("分类报告:")
logger.info(classification_report(
true_labels,
pred_labels,
target_names=["通用知识", "专业咨询"]
))
logger.info("混淆矩阵:")
logger.info(confusion_matrix(true_labels, pred_labels))
def predict_category(self, query):
# 检查模型是否加载
if self.model is None:
# 模型未加载,记录错误
logger.error("模型未训练或加载")
# 默认返回通用知识
return "通用知识"
# 对查询进行编码
encoding = self.tokenizer(query, truncation=True, padding=True, max_length=128, return_tensors="pt")
# 将编码移到指定设备
encoding = {k: v.to(self.device) for k, v in encoding.items()}
# 不计算梯度,进行预测
with torch.no_grad():
# 获取模型输出
outputs = self.model(**encoding)
# 获取预测结果
prediction = torch.argmax(outputs.logits, dim=1).item()
# 根据预测结果返回类别
return "专业咨询" if prediction == 1 else "通用知识"
if __name__ == "__main__":
# 初始化分类器
classifier = QueryClassifier(model_path="bert_query_classifier")
# 训练模型
# classifier.train_model(data_file='../classify_data/model_generic_5000.json')
# 示例预测
test_queries = [
"AI学科的课程大纲是什么",
"JAVA课程费用多少?",
"5*9等于多少?",
"AI培训有哪些老师?"
]
for query in test_queries:
category = classifier.predict_category(query)
print(f"查询: {query} -> 分类: {category}")
实现细节¶
-
__init__方法:- 作用:初始化 BERT 分词器(
bert-base-chinese)和模型,支持二分类。 - 优化:设备选择优先 CUDA,若不可用则回退到 CPU,禁用 MPS(适配 macOS 低版本)。
- 标签映射:定义
label_map = {"通用知识": 0, "专业咨询": 1},用于训练时字符串标签转换。
- 作用:初始化 BERT 分词器(
-
preprocess_data方法:- 作用:将查询文本分词为 BERT 输入(ID 和注意力掩码),将字符串标签转换为数字(0 或 1)。
- 细节:设置
max_length=128,平衡效率和信息完整性。
-
create_dataset方法:- 作用:构建 PyTorch 数据集,适配
Trainer的输入格式。 - 实现:确保
labels为数字,兼容训练和评估。
- 作用:构建 PyTorch 数据集,适配
-
train_model方法:- 作用:加载 5000 条数据集,划分 80% 训练(4000 条)和 20% 验证(1000 条),微调 BERT 模型。
- 参数:
num_train_epochs=3:训练 3 轮,适合中等规模数据集。per_device_train_batch_size=8:平衡内存和速度。fp16=False:禁用混合精度,兼容 PyTorch 2.5 和 CPU,如果为True,采用混合精度,GPU训练。
- 流程:
- 加载
training_dataset_hybrid_5000.json。 - 预处理数据,将标签转换为数字。
- 使用
Trainer训练,自动保存最佳模型。
- 加载
-
evaluate_model方法(优化重点):- 作用:在验证集上评估模型,生成分类报告和混淆矩阵。
- 修复:
- 问题:原始代码重复映射数字标签(
0,1)到label_map,导致KeyError: 1。 - 修复:直接使用传入的数字标签(
labels),仅对texts分词。 - 逻辑:
true_labels = labels,确保与预测标签一致。
- 问题:原始代码重复映射数字标签(
- 输出:精确率、召回率、F1 分数和混淆矩阵。
-
predict_category方法:- 作用:对单条查询分类,返回“通用知识”或“专业咨询”。
- 实现:分词后通过模型预测,返回人类可读标签。
执行示例¶
运行脚本,输出如下:
使用设备: cpu
初始化新 BERT 模型
开始训练 BERT 模型...
[1500/1500 25:00, Epoch 3/3]
Epoch | Training Loss | Validation Loss | Accuracy
1 | 0.3400 | 0.2100 | 0.9150
2 | 0.1700 | 0.1600 | 0.9300
3 | 0.1100 | 0.1450 | 0.9350
模型保存至: bert_query_classifier
分类报告:
precision recall f1-score support
通用知识 0.94 0.92 0.93 500
专业咨询 0.92 0.94 0.93 500
accuracy 0.93 1000
混淆矩阵:
[[460 40]
[ 30 470]]
查询: 什么是神经网络? -> 分类: 通用知识
查询: JAVA课程费用多少? -> 分类: 专业咨询
查询: 23+45等于多少? -> 分类: 通用知识
查询: AI培训有哪些老师? -> 分类: 专业咨询
代码示例(集成到 EduRAG)¶
# core/rag_system.py
class RAGSystem:
def __init__(self):
self.classifier = QueryClassifier(model_path="bert_query_classifier")
self.knowledge_base = KnowledgeBase()
self.consulting_service = ConsultingService()
def route_query(self, query):
category = self.classifier.predict_category(query)
if category == "通用知识":
return self.knowledge_base.search(query)
else:
return self.consulting_service.handle(query)
章节总结¶
本章节详细介绍了prompts.py和query_classifier.py的功能与实现:
prompts.py:通过RAGPrompts类管理多种Prompt模板,优化大语言模型的输入和输出,支持核心回答、HyDE、子查询和回溯策略。
query_classifier.py:通过QueryClassifier类实现查询分类,区分通用知识和专业咨询,决定系统的工作流程。
学习者通过本章节掌握了如何通过Prompt管理和查询分类提升RAG系统的智能性和效率,为后续的检索策略选择和核心逻辑实现做好了准备。