ONNX 模型库
返回模型

说明文档

RexBERT-base (ONNX)

这是 thebajajra/RexBERT-base 的 ONNX 版本。它是通过 这个 Hugging Face Space 自动转换并上传的。

使用 Transformers.js

查看 fill-mask 的流水线文档:https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.FillMaskPipeline


RexBERT-base

License: Apache2.0 Models Data GitHub

简介:一个仅编码器的 Transformer 模型(ModernBERT 风格),专为电子商务应用设计,分三个阶段训练——预训练上下文扩展衰减——用于产品搜索、属性提取、分类和嵌入等场景。该模型在 2.3T+ 通用 token 和 350B+ 电商专用 token 上进行了训练。


目录


快速开始

import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM, pipeline

MODEL_ID = "thebajajra/RexBERT-base"

# 分词器
tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)

# 1) 填充掩码(如果存在 MLM 头)
mlm = pipeline("fill-mask", model=MODEL_ID, tokenizer=tok)
print(mlm("These running shoes are great for  training."))

# 2) 特征提取(CLS 或平均池化嵌入)
enc = AutoModel.from_pretrained(MODEL_ID)
inputs = tok(["wireless mouse", "ergonomic mouse pad"], padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    out = enc(**inputs, output_hidden_states=True)
# 对最后一层隐藏状态进行平均池化以获得句子嵌入
emb = (out.last_hidden_state * inputs.attention_mask.unsqueeze(-1)).sum(dim=1) / inputs.attention_mask.sum(dim=1, keepdim=True)

预期用途与限制

适用场景

  • 产品和查询的检索/语义搜索(标题、描述、属性)
  • 属性提取/槽位填充(品牌、颜色、尺寸、材质)
  • 分类(类目分配、不安全/受限商品过滤、评论情感分析)
  • 重排序查询理解(拼写/ASR 规范化、缩写扩展)

不适用场景

  • 长文本生成(请使用解码器/序列到序列语言模型)
  • 未经人工审核的高风险决策(定价、合规、安全标记)

目标用户

  • 搜索/推荐工程师、电商数据团队、从事领域专用编码器研究的机器学习研究人员

模型描述

RexBERT-base 是一个仅编码器的 1.5 亿参数 Transformer 模型,采用掩码语言建模目标进行训练,并针对电商相关文本进行了优化。三阶段训练课程提升了通用语言理解能力,扩展了上下文处理能力,然后在超大规模电商语料上进行专业化,以捕获领域特定的术语和实体分布。


训练方案

RexBERT-base 分三个阶段进行训练:

  1. 预训练 在多样化的英文文本上进行通用 MLM 预训练,以获得稳健的语言表示。

  2. 上下文扩展 使用增加的最大序列长度进行持续训练,以更好地处理长产品页面、拼接的属性块、多轮查询和分面字符串。这保留了之前的能力,同时扩展了上下文处理能力。

  3. 在 350B+ 电商 token 上衰减350B+ 领域特定 token(产品目录、查询、评论、分类/属性)上进行最终专业化阶段。学习率和采样权重进行退火(衰减),以巩固领域知识并稳定商业任务上的性能。

训练细节(待补充):

  • 优化器 / 学习率调度:TODO
  • 每阶段的有效批量大小 / 步数:TODO
  • 每阶段的上下文长度(如 512 → 1k/2k):TODO
  • 分词器/词表:TODO
  • 硬件与训练时长:TODO
  • 检查点标签:TODO(如 pretrainextdecay

数据概览

我们识别出 9 个与电商重叠且有大量相关 token 但需要过滤的领域。以下是领域列表及其过滤后的大小:

领域 大小 (GB)
Hobby(爱好) 114
News(新闻) 66
Health(健康) 66
Entertainment(娱乐) 64
Travel(旅游) 52
Food(食品) 22
Automotive(汽车) 19
Sports(体育) 12
Music and Dance(音乐与舞蹈) 7

此外,还有 6 个几乎完全重叠的领域,直接从 FineFineWeb 中选取:

领域 大小 (GB)
Fashion(时尚) 37
Beauty(美妆) 37
Celebrity(名人) 28
Movie(电影) 26
Photo(摄影) 15
Painting(绘画) 2

通过聚焦这些领域,我们将搜索范围缩小到网络数据中可能出现购物相关文本的部分。然而,即使在选定的领域内,并非每条内容都真正与买卖相关,许多可能是信息性文章、新闻或不相关的讨论。因此,需要在每个领域内进行更细粒度的过滤,以仅提取电商特定的内容。我们通过为每个领域训练轻量级分类器来区分电商上下文与非电商内容来实现这一点。


评估

Token 分类

image/png

RexBERT 在参数量减少 2-3 倍的情况下,超越了 ModernBERT 系列的性能。

语义相似度

image/png

RexBERT 模型在其参数/尺寸类别中超越了所有其他模型。


使用示例

1) 掩码语言建模

from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline

m = AutoModelForMaskedLM.from_pretrained("thebajajra/RexBERT-base")
t = AutoTokenizer.from_pretrained("thebajajra/RexBERT-base")
fill = pipeline("fill-mask", model=m, tokenizer=t)

fill("Best  headphones under $100.")

2) 嵌入 / 特征提取

import torch
from transformers import AutoTokenizer, AutoModel

tok = AutoTokenizer.from_pretrained("thebajajra/RexBERT-base")
enc = AutoModel.from_pretrained("thebajajra/RexBERT-base")

texts = ["nike air zoom pegasus 40", "running shoes pegasus zoom nike"]
batch = tok(texts, padding=True, truncation=True, return_tensors="pt")

with torch.no_grad():
    out = enc(**batch)
# 对最后一层隐藏状态进行平均池化
attn = batch["attention_mask"].unsqueeze(-1)
emb = (out.last_hidden_state * attn).sum(1) / attn.sum(1)
# 归一化以用于余弦相似度(推荐用于检索)
emb = torch.nn.functional.normalize(emb, p=2, dim=1)

3) 文本分类微调

from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer

tok = AutoTokenizer.from_pretrained("thebajajra/RexBERT-base")
model = AutoModelForSequenceClassification.from_pretrained("thebajajra/RexBERT-base", num_labels=NUM_LABELS)

# 准备你的 Dataset 对象:train_ds, val_ds(text→label)
args = TrainingArguments(
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    learning_rate=3e-5,
    num_train_epochs=3,
    evaluation_strategy="steps",
    fp16=True,
    report_to="none",
    load_best_model_at_end=True,
)

trainer = Trainer(model=model, args=args, train_dataset=train_ds, eval_dataset=val_ds, tokenizer=tok)
trainer.train()

模型架构与兼容性

  • 架构: 仅编码器,ModernBERT 风格的 base 模型。
  • 库: 兼容 🤗 Transformers;支持 fill-maskfeature-extraction 流水线。
  • 上下文长度:上下文扩展阶段增加——确保 config.json 中的 max_position_embeddings 与你所需的最大长度匹配。
  • 文件: config.json、分词器文件,以及(可选的)MLM 或分类头。
  • 导出: 标准 PyTorch 权重;如需生产环境使用,可导出 ONNX / TorchScript。

负责任与安全使用

  • 偏见: 电商数据可能包含品牌、价格和地区偏见;审核下游分类器/检索器在不同类别/地区的错误率差异。
  • 敏感内容: 为成人/受限商品添加过滤器;如发布分类器,请记录审核阈值。
  • 隐私: 不要暴露个人身份信息(PII);确保训练数据符合相关条款和适用法律。
  • 滥用: 此模型不能替代商品上架的法律/合规审核。

许可证

  • 许可证: apache-2.0

维护者与联系方式


thomasht86/RexBERT-base-ONNX

作者 thomasht86

fill-mask transformers.js
↓ 1 ♥ 0

创建时间: 2025-11-13 09:34:16+00:00

更新时间: 2025-11-13 09:34:38+00:00

在 Hugging Face 上查看

文件 (15)

.gitattributes
README.md
config.json
onnx/model.onnx ONNX
onnx/model_bnb4.onnx ONNX
onnx/model_fp16.onnx ONNX
onnx/model_int8.onnx ONNX
onnx/model_q4.onnx ONNX
onnx/model_q4f16.onnx ONNX
onnx/model_quantized.onnx ONNX
onnx/model_uint8.onnx ONNX
quantize_config.json
special_tokens_map.json
tokenizer.json
tokenizer_config.json