1. 什么是 CrossEncoder? #
CrossEncoder 是一种深度学习模型,它同时处理两个输入文本(如句子对),并直接输出它们之间的相关性分数。与 Bi-Encoder(双编码器)不同,CrossEncoder 在编码时会让两个输入进行充分的交互。
2.主要特点 #
- 高准确性:由于两个输入在模型内部可以充分交互,通常比 Bi-Encoder 表现更好
- 计算代价高:不适合大规模检索,因为需要为每对组合单独计算
- 适用于:重排序(reranking)、文本对分类、语义相似度判断等任务
3.典型应用场景 #
- 信息检索中的结果重排序
- 问答系统中的答案选择
- 文本相似度计算
- 自然语言推理(NLI)
- paraphrase 识别
4.基本用法示例 #
# 从sentence_transformers库中导入CrossEncoder类
from sentence_transformers import CrossEncoder
# 加载预训练的Cross-Encoder模型,用于句子对相关性评分
rerank_model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
# 构造待评估的句子对列表,每个元素是一个(query, document)元组
sentence_pairs = [
(
"How many people live in Berlin?",
"Berlin has a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.",
),
("How many people live in Berlin?", "Berlin is well known for its museums."),
]
# 使用Cross-Encoder模型对句子对进行相关性打分,返回分数列表
scores = rerank_model.predict(sentence_pairs)
# 打印相关性分数数组 [ 8.845852 -4.320078]
print(scores)
# 将句子对和分数打包成一个zip对象(迭代器) <zip object at 0x000001FE559525C0>
zipped = zip(sentence_pairs, scores)
# 打印zip对象本身(显示为对象地址)
print(zipped)
# 将zip对象转换为列表并打印,显示每个句子对及其分数
print(list(zipped))
# 逐个遍历句子对和对应分数,格式化输出每一对及其相关性分数
for pair, score in zip(sentence_pairs, scores):
print(f"{pair} => Score: {score:.4f}")
# names = ['Alice', 'Bob', 'Charlie']
# ages = [25, 30, 35]
# zipped = zip(names, ages)
# print(list(zipped))
# # 输出: [('Alice', 25), ('Bob', 30), ('Charlie', 35)]
5.可用预训练模型 #
HuggingFace Model Hub 提供了多种预训练 CrossEncoder 模型:
cross-encoder/ms-marco-MiniLM-L-6-v2:适用于信息检索cross-encoder/stsb-roberta-large:适用于语义文本相似度cross-encoder/nli-deberta-v3-base:适用于自然语言推理
6.与 BiEncoder 对比 #
| 特性 | CrossEncoder | BiEncoder |
|---|---|---|
| 输入处理 | 同时编码两个输入并交互 | 分别编码两个输入 |
| 速度 | 慢(适合小规模数据) | 快(适合大规模检索) |
| 准确性 | 更高 | 相对较低 |
| 适用场景 | 重排序、精细匹配 | 语义搜索、聚类、大规模检索 |
7.高级用法 #
# 自定义推理
model = CrossEncoder('cross-encoder/stsb-roberta-large',
max_length=512, # 最大序列长度
device='cuda') # 使用GPU
# 带标签的训练数据
train_examples = [
(("句子1", "句子2"), 1), # 1表示相关
(("句子1", "不相关句子"), 0) # 0表示不相关
]
# 可以继续训练模型
model.fit(train_examples, epochs=3)CrossEncoder 是 Sentence-Transformers 库中处理成对文本任务的强大工具,特别适合需要高准确性的场景。