ONNX 模型库
返回模型

说明文档

概述

这是一个多标签、多类别的情感线性分类器,与 BGE-small-en-v1.5 嵌入 配合使用,基于 go_emotions 数据集训练。

标签

来自 go_emotions 数据集的 28 个标签为:

['admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring', 'confusion', 'curiosity', 'desire', 'disappointment', 'disapproval', 'disgust', 'embarrassment', 'excitement', 'fear', 'gratitude', 'grief', 'joy', 'love', 'nervousness', 'optimism', 'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise', 'neutral']

指标(每项标签精确匹配)

这是一个多标签、多类别数据集,因此每个标签实际上是一个独立的二分类问题。在 go_emotions 测试集中,对每项的所有标签进行评估,指标如下所示。

优化每个标签的阈值以优化 F1 指标,指标(在 go_emotions 测试集上评估)为:

  • 精确率:0.445
  • 召回率:0.476
  • F1:0.449

按数据集中每个标签的相对支持度加权后,结果为:

  • 精确率:0.472
  • 召回率:0.582
  • F1:0.514

使用固定阈值 0.5 将分数转换为每个标签的二分类预测,指标(在 go_emotions 测试集上评估,未按支持度加权)为:

  • 精确率:0.602
  • 召回率:0.250
  • F1:0.303

指标(按标签)

这是一个多标签、多类别数据集,因此每个标签实际上是一个独立的二分类问题,按标签分别衡量指标更为准确。

优化每个标签的阈值以优化 F1 指标,指标(在 go_emotions 测试集上评估)为:

f1 precision recall support threshold
admiration 0.583 0.574 0.593 504 0.30
amusement 0.668 0.722 0.621 264 0.25
anger 0.350 0.309 0.404 198 0.15
annoyance 0.299 0.318 0.281 320 0.20
approval 0.338 0.281 0.425 351 0.15
caring 0.321 0.323 0.319 135 0.20
confusion 0.384 0.313 0.497 153 0.15
curiosity 0.467 0.432 0.507 284 0.20
desire 0.426 0.381 0.482 83 0.20
disappointment 0.210 0.147 0.364 151 0.10
disapproval 0.366 0.288 0.502 267 0.15
disgust 0.416 0.409 0.423 123 0.20
embarrassment 0.370 0.341 0.405 37 0.30
excitement 0.313 0.368 0.272 103 0.25
fear 0.615 0.677 0.564 78 0.40
gratitude 0.828 0.810 0.847 352 0.25
grief 0.545 0.600 0.500 6 0.85
joy 0.455 0.429 0.484 161 0.20
love 0.642 0.673 0.613 238 0.30
nervousness 0.350 0.412 0.304 23 0.60
optimism 0.439 0.417 0.462 186 0.20
pride 0.480 0.667 0.375 16 0.70
realization 0.232 0.191 0.297 145 0.10
relief 0.353 0.500 0.273 11 0.50
remorse 0.643 0.529 0.821 56 0.20
sadness 0.526 0.497 0.558 156 0.20
surprise 0.329 0.318 0.340 141 0.15
neutral 0.634 0.528 0.794 1787 0.30

阈值存储在 thresholds.json 中。

与 ONNXRuntime 配合使用

模型的输入称为 logits,每个标签有一个输出。每个输出产生一个二维数组,每个输入行对应 1 行,每行有 2 列——第一列是负例的概率输出,第二列是正例的概率输出。

# 假设你已经从 BAAI/bge-small-en-v1.5 获得了输入句子的嵌入
# 例如通过 sentence-transformers 生成,如 huggingface.co/BAAI/bge-small-en-v1.5
#      或通过 ONNX 版本,如 huggingface.co/Xenova/bge-small-en-v1.5

print(embeddings.shape)  # 例如:1 个句子的批次
> (1, 384)

import onnxruntime as ort

sess = ort.InferenceSession("path_to_model_dot_onnx", providers=['CPUExecutionProvider'])

outputs = [o.name for o in sess.get_outputs()]  # 标签列表,按输出顺序排列
preds_onnx = sess.run(_outputs, {'logits': embeddings})
# preds_onnx 是一个包含 28 个条目的列表,每个标签一个,
# 每个都是一个形状为 (1, 2) 的 numpy 数组(假设输入是 1 个句子的批次)

print(outputs[0])
> surprise
print(preds_onnx[0])
> array([[0.97136074, 0.02863926]], dtype=float32)

# 加载 thresholds.json 并使用它(按标签)将正例分数转换为二分类预测

关于数据集的评论

某些标签(如 gratitude)在单独考虑时表现非常出色,而其他标签(如 relief)表现非常差。

这是一个具有挑战性的数据集。诸如 relief 之类的标签在训练数据中确实有更少的示例(40k+ 中不到 100 个,测试集中只有 11 个)。

但在 go_emotions 的训练数据中也存在一些歧义和/或标注错误,这被认为是限制性能的因素。对数据集进行数据清洗以减少标注中的一些错误、歧义、冲突和重复,将产生性能更高的模型。

SamLowe/bge-small-en-v1.5-go_emotions-classifier-onnx

作者 SamLowe

text-classification
↓ 0 ♥ 0

创建时间: 2023-10-06 21:04:14+00:00

更新时间: 2023-10-06 22:40:55+00:00

在 Hugging Face 上查看

文件 (4)

.gitattributes
README.md
model.onnx ONNX
thresholds.json