ONNX 模型库
返回模型

说明文档

概述

这是一个多标签、多类别的情感线性分类器,与 sentence-transformers/all-MiniLM-L6-v2 配合使用,基于 go_emotions 数据集训练而成。

标签

go_emotions 数据集的 28 个标签如下:

['admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring', 'confusion', 'curiosity', 'desire', 'disappointment', 'disapproval', 'disgust', 'embarrassment', 'excitement', 'fear', 'gratitude', 'grief', 'joy', 'love', 'nervousness', 'optimism', 'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise', 'neutral']

指标(每条数据的标签精确匹配)

这是一个多标签、多类别数据集,因此每个标签实际上是一个独立的二分类问题。在 go_emotions 测试集中对所有标签进行评估,指标如下所示。

针对每个标签优化阈值以最大化 F1 指标,指标(在 go_emotions 测试集上评估)为:

  • Precision(精确率):0.384
  • Recall(召回率):0.438
  • F1:0.397

按数据集中每个标签的相对支持度加权后,结果为:

  • Precision(精确率):0.443
  • Recall(召回率):0.552
  • F1:0.484

使用固定阈值 0.5 将分数转换为每个标签的二分类预测,指标(在 go_emotions 测试集上评估,未经支持度加权)为:

  • Precision(精确率):0.551
  • Recall(召回率):0.211
  • F1:0.261

指标(按标签)

这是一个多标签、多类别数据集,因此每个标签实际上是一个独立的二分类问题,按标签测量指标更为准确。

针对每个标签优化阈值以最大化 F1 指标,指标(在 go_emotions 测试集上评估)如下:

f1 precision recall support threshold
admiration 0.529 0.499 0.563 504 0.25
amusement 0.733 0.672 0.807 264 0.20
anger 0.394 0.363 0.429 198 0.15
annoyance 0.293 0.252 0.350 320 0.15
approval 0.292 0.345 0.254 351 0.20
caring 0.320 0.270 0.393 135 0.15
confusion 0.291 0.276 0.307 153 0.15
curiosity 0.366 0.307 0.454 284 0.15
desire 0.317 0.269 0.386 83 0.15
disappointment 0.159 0.127 0.212 151 0.10
disapproval 0.306 0.341 0.277 267 0.20
disgust 0.405 0.412 0.398 123 0.20
embarrassment 0.364 0.414 0.324 37 0.35
excitement 0.296 0.232 0.408 103 0.15
fear 0.496 0.576 0.436 78 0.40
gratitude 0.793 0.787 0.798 352 0.30
grief 0.323 0.200 0.833 6 0.45
joy 0.402 0.341 0.491 161 0.15
love 0.640 0.679 0.605 238 0.30
nervousness 0.263 0.333 0.217 23 0.70
optimism 0.433 0.453 0.414 186 0.20
pride 0.429 0.500 0.375 16 0.50
realization 0.177 0.159 0.200 145 0.10
relief 0.182 0.182 0.182 11 0.40
remorse 0.541 0.500 0.589 56 0.30
sadness 0.461 0.467 0.455 156 0.20
surprise 0.302 0.299 0.305 141 0.15
neutral 0.620 0.505 0.803 1787 0.30

阈值存储在 thresholds.json 中。

与 ONNXRuntime 配合使用

模型的输入称为 logits,每个标签有一个输出。每个输出产生一个二维数组,每行对应一个输入行,每行有两列——第一列是负例的概率输出,第二列是正例的概率输出。

# 假设你已经从 all-MiniLM-L6-v2 获得了输入句子的嵌入向量
# 例如通过 sentence-transformers 生成,如:
#      huggingface.co/sentence-transformers/all-MiniLM-L6-v2
#      或通过 ONNX 版本,例如 huggingface.co/Xenova/all-MiniLM-L6-v2

print(embeddings.shape)  # 例如:1 个句子的批次
> (1, 384)

import onnxruntime as ort

sess = ort.InferenceSession("path_to_model_dot_onnx", providers=['CPUExecutionProvider'])

outputs = [o.name for o in sess.get_outputs()]  # 标签列表,按输出顺序排列
preds_onnx = sess.run(_outputs, {'logits': embeddings})
# preds_onnx 是一个包含 28 个条目的列表,每个标签一个,
# 每个条目是一个形状为 (1, 2) 的 numpy 数组(假设输入是 1 个句子的批次)

print(outputs[0])
> surprise
print(preds_onnx[0])
> array([[0.97136074, 0.02863926]], dtype=float32)

# 加载 thresholds.json 并使用它(按标签)将正例分数转换为二分类预测

关于数据集的评论

某些标签(如 gratitude/感激)单独考虑时表现非常出色,而其他标签(如 relief/释然)表现则非常差。

这是一个具有挑战性的数据集。像 relief 这样的标签在训练数据中的样本确实很少(40k+ 条数据中不到 100 条,测试集中只有 11 条)。

但 go_emotions 的训练数据中也存在一些明显的歧义和/或标注错误,这被认为是限制模型性能的因素。对数据集进行数据清洗,减少标注中的一些错误、歧义、冲突和重复,将产生性能更高的模型。

SamLowe/all-MiniLM-L6-v2-go_emotions-classifier-onnx

作者 SamLowe

text-classification
↓ 0 ♥ 0

创建时间: 2023-10-06 20:32:08+00:00

更新时间: 2023-10-06 22:40:14+00:00

在 Hugging Face 上查看

文件 (4)

.gitattributes
README.md
model.onnx ONNX
thresholds.json