ONNX 模型库
返回模型

说明文档

概述

这是一个多标签、多类别的情感线性分类器,与 BGE-small-en 嵌入 配合使用,在 go_emotions 数据集上进行了训练。

标签

来自 go_emotions 数据集的 28 个标签为:

['admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring', 'confusion', 'curiosity', 'desire', 'disappointment', 'disapproval', 'disgust', 'embarrassment', 'excitement', 'fear', 'gratitude', 'grief', 'joy', 'love', 'nervousness', 'optimism', 'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise', 'neutral']

指标(每条数据的标签精确匹配)

这是一个多标签、多类别数据集,因此每个标签实际上是一个独立的二分类问题。在 go_emotions 测试集中,对所有标签逐项评估的指标如下所示。

针对每个标签优化阈值以最大化 F1 指标,指标(在 go_emotions 测试集上评估)为:

  • 精确率:0.429
  • 召回率:0.483
  • F1:0.439

按数据集中每个标签的相对支持度加权后,结果为:

  • 精确率:0.457
  • 召回率:0.585
  • F1:0.502

使用固定阈值 0.5 将分数转换为每个标签的二分类预测时,指标(在 go_emotions 测试集上评估,未按支持度加权)为:

  • 精确率:0.650
  • 召回率:0.189
  • F1:0.249

指标(按标签)

这是一个多标签、多类别数据集,因此每个标签实际上是一个独立的二分类问题,按标签分别衡量指标更为准确。

针对每个标签优化阈值以最大化 F1 指标,指标(在 go_emotions 测试集上评估)为:

f1 precision recall support threshold
admiration 0.561 0.517 0.613 504 0.25
amusement 0.647 0.663 0.633 264 0.20
anger 0.324 0.238 0.510 198 0.10
annoyance 0.292 0.200 0.541 320 0.10
approval 0.335 0.297 0.385 351 0.15
caring 0.306 0.221 0.496 135 0.10
confusion 0.360 0.400 0.327 153 0.20
curiosity 0.461 0.392 0.560 284 0.15
desire 0.411 0.476 0.361 83 0.25
disappointment 0.204 0.150 0.318 151 0.10
disapproval 0.357 0.291 0.461 267 0.15
disgust 0.403 0.417 0.390 123 0.20
embarrassment 0.424 0.483 0.378 37 0.30
excitement 0.298 0.255 0.359 103 0.15
fear 0.609 0.590 0.628 78 0.25
gratitude 0.801 0.819 0.784 352 0.30
grief 0.500 0.500 0.500 6 0.75
joy 0.437 0.453 0.422 161 0.20
love 0.641 0.693 0.597 238 0.30
nervousness 0.356 0.364 0.348 23 0.45
optimism 0.416 0.538 0.339 186 0.25
pride 0.500 0.750 0.375 16 0.65
realization 0.247 0.228 0.269 145 0.10
relief 0.364 0.273 0.545 11 0.30
remorse 0.581 0.529 0.643 56 0.25
sadness 0.525 0.519 0.532 156 0.20
surprise 0.301 0.235 0.418 141 0.10
neutral 0.626 0.519 0.786 1787 0.30

阈值存储在 thresholds.json 中。

使用 ONNXRuntime

模型的输入称为 logits,每个标签对应一个输出。每个输出生成一个二维数组,输入有多少行就有多少行,每行有两列——第一列是负类别的概率输出,第二列是正类别的概率输出。

# 假设你已经有了输入句子的 BAAI/bge-small-en 嵌入
# 例如通过 sentence-transformers 生成,如 huggingface.co/BAAI/bge-small-en
# 或通过 ONNX 版本生成,如 huggingface.co/Xenova/bge-small-en

print(embeddings.shape)  # 例如:1 个句子的批次
> (1, 384)

import onnxruntime as ort

sess = ort.InferenceSession("path_to_model_dot_onnx", providers=['CPUExecutionProvider'])

outputs = [o.name for o in sess.get_outputs()]  # 标签列表,按输出顺序排列
preds_onnx = sess.run(_outputs, {'logits': embeddings})
# preds_onnx 是一个包含 28 个条目的列表,每个标签一个,
# 每个条目是一个形状为 (1, 2) 的 numpy 数组(假设输入是 1 个句子的批次)

print(outputs[0])
> surprise
print(preds_onnx[0])
> array([[0.97136074, 0.02863926]], dtype=float32)

# 加载 thresholds.json 并使用它(按标签)将正类别分数转换为二分类预测

关于数据集的评论

某些标签(如 gratitude)单独考虑时表现非常好,而其他标签(如 relief)表现则很差。

这是一个具有挑战性的数据集。诸如 relief 之类的标签在训练数据中的样本确实很少(40k+ 条数据中不到 100 条,测试集中只有 11 条)。

但 go_emotions 的训练数据中也存在一些歧义和/或标注错误,这被认为是限制模型性能的因素。对数据集进行数据清洗以减少标注中的一些错误、歧义、冲突和重复,将产生性能更高的模型。

SamLowe/bge-small-en-go_emotions-classifier-onnx

作者 SamLowe

text-classification
↓ 0 ♥ 0

创建时间: 2023-10-06 20:34:53+00:00

更新时间: 2023-10-06 22:38:06+00:00

在 Hugging Face 上查看

文件 (4)

.gitattributes
README.md
model.onnx ONNX
thresholds.json