ONNX 模型库
返回模型

说明文档

概述

这是一个多标签、多类别的情感线性分类器,可与 sentence-transformers/all-MiniLM-L12-v2 配合使用,该模型在 go_emotions 数据集上进行了训练。

标签

来自 go_emotions 数据集的 28 个标签为:

['admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring', 'confusion', 'curiosity', 'desire', 'disappointment', 'disapproval', 'disgust', 'embarrassment', 'excitement', 'fear', 'gratitude', 'grief', 'joy', 'love', 'nervousness', 'optimism', 'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise', 'neutral']

指标(每项标签精确匹配)

这是一个多标签、多类别数据集,因此每个标签实际上都是一个独立的二分类问题。在 go_emotions 测试集中对所有标签进行评估,指标如下所示。

针对每个标签优化阈值以最大化 F1 指标,指标(在 go_emotions 测试集上评估)为:

  • Precision(精确率): 0.378
  • Recall(召回率): 0.438
  • F1: 0.394

按数据集中每个标签的相对支持度加权后,结果为:

  • Precision(精确率): 0.424
  • Recall(召回率): 0.590
  • F1: 0.481

使用固定阈值 0.5 将分数转换为每个标签的二分类预测,指标(在 go_emotions 测试集上评估,未按支持度加权)为:

  • Precision(精确率): 0.568
  • Recall(召回率): 0.214
  • F1: 0.260

指标(按标签)

这是一个多标签、多类别数据集,因此每个标签实际上都是一个独立的二分类问题,按标签分别测量指标更为准确。

针对每个标签优化阈值以最大化 F1 指标,指标(在 go_emotions 测试集上评估)为:

f1 precision recall support threshold
admiration 0.540 0.463 0.649 504 0.20
amusement 0.686 0.669 0.705 264 0.25
anger 0.419 0.373 0.480 198 0.15
annoyance 0.276 0.189 0.512 320 0.10
approval 0.299 0.260 0.350 351 0.15
caring 0.303 0.219 0.489 135 0.10
confusion 0.284 0.269 0.301 153 0.15
curiosity 0.365 0.310 0.444 284 0.15
desire 0.274 0.237 0.325 83 0.15
disappointment 0.188 0.292 0.139 151 0.20
disapproval 0.305 0.257 0.375 267 0.15
disgust 0.450 0.462 0.439 123 0.20
embarrassment 0.348 0.375 0.324 37 0.30
excitement 0.313 0.306 0.320 103 0.20
fear 0.550 0.505 0.603 78 0.25
gratitude 0.776 0.774 0.778 352 0.30
grief 0.353 0.273 0.500 6 0.70
joy 0.370 0.361 0.379 161 0.20
love 0.626 0.717 0.555 238 0.35
nervousness 0.308 0.276 0.348 23 0.55
optimism 0.436 0.432 0.441 186 0.20
pride 0.444 0.545 0.375 16 0.60
realization 0.171 0.146 0.207 145 0.10
relief 0.133 0.250 0.091 11 0.60
remorse 0.468 0.426 0.518 56 0.30
sadness 0.413 0.409 0.417 156 0.20
surprise 0.314 0.303 0.326 141 0.15
neutral 0.622 0.482 0.879 1787 0.25

阈值存储在 thresholds.json 中。

使用 ONNXRuntime

模型的输入称为 logits,每个标签有一个输出。每个输出产生一个二维数组,每行对应一个输入行,每行有两列——第一列是负类的概率输出,第二列是正类的概率输出。

# 假设你已经从 all-MiniLM-L12-v2 获得了输入句子的嵌入向量
# 例如通过 sentence-transformers 生成:
#      huggingface.co/sentence-transformers/all-MiniLM-L12-v2
#      或通过 ONNX 版本,例如 huggingface.co/Xenova/all-MiniLM-L12-v2

print(embeddings.shape)  # 例如:一个包含 1 个句子的批次
> (1, 384)

import onnxruntime as ort

sess = ort.InferenceSession("path_to_model_dot_onnx", providers=['CPUExecutionProvider'])

outputs = [o.name for o in sess.get_outputs()]  # 标签列表,按输出顺序排列
preds_onnx = sess.run(_outputs, {'logits': embeddings})
# preds_onnx 是一个包含 28 个条目的列表,每个标签一个,
# 每个条目是一个形状为 (1, 2) 的 numpy 数组(假设输入是一个包含 1 个句子的批次)

print(outputs[0])
> surprise
print(preds_onnx[0])
> array([[0.97136074, 0.02863926]], dtype=float32)

# 加载 thresholds.json 并使用它(按标签)将正类分数转换为二分类预测

关于数据集的评论

某些标签(如 gratitude/感激)在独立考虑时表现非常出色,而其他标签(如 relief/宽慰)表现则非常差。

这是一个具有挑战性的数据集。诸如 relief 之类的标签在训练数据中确实有更少的样本(40k+ 样本中不到 100 个,测试集中仅有 11 个)。

但是,go_emotions 的训练数据中也存在一些歧义和/或标注错误,这被认为是限制性能的因素。对数据集进行数据清洗以减少标注中的一些错误、歧义、冲突和重复,将产生性能更高的模型。

SamLowe/all-MiniLM-L12-v2-go_emotions-classifier-onnx

作者 SamLowe

text-classification
↓ 0 ♥ 1

创建时间: 2023-10-06 20:29:25+00:00

更新时间: 2023-10-06 22:39:00+00:00

在 Hugging Face 上查看

文件 (4)

.gitattributes
README.md
model.onnx ONNX
thresholds.json