说明文档
概述
这是一个多标签、多类别的情感线性分类器,可与 sentence-transformers/all-MiniLM-L12-v2 配合使用,该模型在 go_emotions 数据集上进行了训练。
标签
来自 go_emotions 数据集的 28 个标签为:
['admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring', 'confusion', 'curiosity', 'desire', 'disappointment', 'disapproval', 'disgust', 'embarrassment', 'excitement', 'fear', 'gratitude', 'grief', 'joy', 'love', 'nervousness', 'optimism', 'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise', 'neutral']
指标(每项标签精确匹配)
这是一个多标签、多类别数据集,因此每个标签实际上都是一个独立的二分类问题。在 go_emotions 测试集中对所有标签进行评估,指标如下所示。
针对每个标签优化阈值以最大化 F1 指标,指标(在 go_emotions 测试集上评估)为:
- Precision(精确率): 0.378
- Recall(召回率): 0.438
- F1: 0.394
按数据集中每个标签的相对支持度加权后,结果为:
- Precision(精确率): 0.424
- Recall(召回率): 0.590
- F1: 0.481
使用固定阈值 0.5 将分数转换为每个标签的二分类预测,指标(在 go_emotions 测试集上评估,未按支持度加权)为:
- Precision(精确率): 0.568
- Recall(召回率): 0.214
- F1: 0.260
指标(按标签)
这是一个多标签、多类别数据集,因此每个标签实际上都是一个独立的二分类问题,按标签分别测量指标更为准确。
针对每个标签优化阈值以最大化 F1 指标,指标(在 go_emotions 测试集上评估)为:
| f1 | precision | recall | support | threshold | |
|---|---|---|---|---|---|
| admiration | 0.540 | 0.463 | 0.649 | 504 | 0.20 |
| amusement | 0.686 | 0.669 | 0.705 | 264 | 0.25 |
| anger | 0.419 | 0.373 | 0.480 | 198 | 0.15 |
| annoyance | 0.276 | 0.189 | 0.512 | 320 | 0.10 |
| approval | 0.299 | 0.260 | 0.350 | 351 | 0.15 |
| caring | 0.303 | 0.219 | 0.489 | 135 | 0.10 |
| confusion | 0.284 | 0.269 | 0.301 | 153 | 0.15 |
| curiosity | 0.365 | 0.310 | 0.444 | 284 | 0.15 |
| desire | 0.274 | 0.237 | 0.325 | 83 | 0.15 |
| disappointment | 0.188 | 0.292 | 0.139 | 151 | 0.20 |
| disapproval | 0.305 | 0.257 | 0.375 | 267 | 0.15 |
| disgust | 0.450 | 0.462 | 0.439 | 123 | 0.20 |
| embarrassment | 0.348 | 0.375 | 0.324 | 37 | 0.30 |
| excitement | 0.313 | 0.306 | 0.320 | 103 | 0.20 |
| fear | 0.550 | 0.505 | 0.603 | 78 | 0.25 |
| gratitude | 0.776 | 0.774 | 0.778 | 352 | 0.30 |
| grief | 0.353 | 0.273 | 0.500 | 6 | 0.70 |
| joy | 0.370 | 0.361 | 0.379 | 161 | 0.20 |
| love | 0.626 | 0.717 | 0.555 | 238 | 0.35 |
| nervousness | 0.308 | 0.276 | 0.348 | 23 | 0.55 |
| optimism | 0.436 | 0.432 | 0.441 | 186 | 0.20 |
| pride | 0.444 | 0.545 | 0.375 | 16 | 0.60 |
| realization | 0.171 | 0.146 | 0.207 | 145 | 0.10 |
| relief | 0.133 | 0.250 | 0.091 | 11 | 0.60 |
| remorse | 0.468 | 0.426 | 0.518 | 56 | 0.30 |
| sadness | 0.413 | 0.409 | 0.417 | 156 | 0.20 |
| surprise | 0.314 | 0.303 | 0.326 | 141 | 0.15 |
| neutral | 0.622 | 0.482 | 0.879 | 1787 | 0.25 |
阈值存储在 thresholds.json 中。
使用 ONNXRuntime
模型的输入称为 logits,每个标签有一个输出。每个输出产生一个二维数组,每行对应一个输入行,每行有两列——第一列是负类的概率输出,第二列是正类的概率输出。
# 假设你已经从 all-MiniLM-L12-v2 获得了输入句子的嵌入向量
# 例如通过 sentence-transformers 生成:
# huggingface.co/sentence-transformers/all-MiniLM-L12-v2
# 或通过 ONNX 版本,例如 huggingface.co/Xenova/all-MiniLM-L12-v2
print(embeddings.shape) # 例如:一个包含 1 个句子的批次
> (1, 384)
import onnxruntime as ort
sess = ort.InferenceSession("path_to_model_dot_onnx", providers=['CPUExecutionProvider'])
outputs = [o.name for o in sess.get_outputs()] # 标签列表,按输出顺序排列
preds_onnx = sess.run(_outputs, {'logits': embeddings})
# preds_onnx 是一个包含 28 个条目的列表,每个标签一个,
# 每个条目是一个形状为 (1, 2) 的 numpy 数组(假设输入是一个包含 1 个句子的批次)
print(outputs[0])
> surprise
print(preds_onnx[0])
> array([[0.97136074, 0.02863926]], dtype=float32)
# 加载 thresholds.json 并使用它(按标签)将正类分数转换为二分类预测
关于数据集的评论
某些标签(如 gratitude/感激)在独立考虑时表现非常出色,而其他标签(如 relief/宽慰)表现则非常差。
这是一个具有挑战性的数据集。诸如 relief 之类的标签在训练数据中确实有更少的样本(40k+ 样本中不到 100 个,测试集中仅有 11 个)。
但是,go_emotions 的训练数据中也存在一些歧义和/或标注错误,这被认为是限制性能的因素。对数据集进行数据清洗以减少标注中的一些错误、歧义、冲突和重复,将产生性能更高的模型。
SamLowe/all-MiniLM-L12-v2-go_emotions-classifier-onnx
作者 SamLowe
创建时间: 2023-10-06 20:29:25+00:00
更新时间: 2023-10-06 22:39:00+00:00
在 Hugging Face 上查看