返回模型
说明文档
模型卡片
jaeyong2/gte-multilingual-base-Ja-embedding 的 ONNX 版本。在 jaeyong2/Ja-emb-PreView 数据集上对 Alibaba-NLP/gte-multilingual-base 模型进行微调,以更好地适应日语。
模型详情
Alibaba-NLP/gte-multilingual-base
训练
- 数据:jaeyong2/Ja-emb-PreView
import torch
import datasets
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel
from torch.optim import AdamW
from tqdm import tqdm
from torch import nn
class TripletLoss(nn.Module):
    """Triplet margin loss on squared Euclidean distances.

    Encourages ``anchor`` to lie closer to ``positive`` than to ``negative``
    by at least ``margin``; pairs already satisfying the margin contribute 0.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        # Squared L2 distance per row (no sqrt — matches the original).
        pos_dist = torch.sum((anchor - positive) ** 2, dim=1)
        neg_dist = torch.sum((anchor - negative) ** 2, dim=1)
        # clamp(min=0) is equivalent to relu: hinge at the margin.
        hinge = torch.clamp(pos_dist - neg_dist + self.margin, min=0.0)
        return hinge.mean()
def batch_to_device(batch, device):
    """Return a copy of *batch* with every tensor value moved to *device*."""
    moved = {}
    for key, value in batch.items():
        moved[key] = value.to(device)
    return moved
# Training setup: fine-tune the multilingual GTE base checkpoint on the
# Japanese triplet dataset (context / Title / Fake Title).
model_name = "Alibaba-NLP/gte-multilingual-base"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset = datasets.load_dataset("jaeyong2/Ja-emb-PreView")
train_dataloader = DataLoader(dataset['train'], batch_size=8, shuffle=True)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
model.to(device)
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    # Run in bfloat16 on GPUs that support it to reduce memory use.
    model = model.to(torch.bfloat16)

triplet_loss = TripletLoss(margin=1.0)
optimizer = AdamW(model.parameters(), lr=5e-5)
def _encode(text, max_length):
    """Tokenize one text to fixed-length tensors and move them to `device`."""
    enc = tokenizer(
        [text],
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )
    return batch_to_device(enc, device)


# Triplet fine-tuning: pull each context embedding toward its real Title
# and push it away from the Fake Title by at least the loss margin.
NUM_EPOCHS = 3  # single source of truth (was duplicated in range(3) and the print)

for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0
    count = 0
    print(f"\nEpoch {epoch + 1}/{NUM_EPOCHS}")
    for batch in tqdm(train_dataloader, desc=f"Training Epoch {epoch + 1}"):
        optimizer.zero_grad()
        loss = None
        # NOTE(review): each triplet is tokenized and encoded individually
        # (three forward passes of batch size 1 per sample); correct but slow —
        # batching the whole DataLoader batch per forward would be faster.
        for index in range(len(batch["context"])):
            anchor_encodings = _encode(batch["context"][index], 1024)
            positive_encodings = _encode(batch["Title"][index], 256)
            negative_encodings = _encode(batch["Fake Title"][index], 256)
            # [0] is the last hidden state; [:, 0, :] takes the first-token
            # ([CLS]) embedding as the sentence representation.
            anchor_output = model(**anchor_encodings)[0][:, 0, :]
            positive_output = model(**positive_encodings)[0][:, 0, :]
            negative_output = model(**negative_encodings)[0][:, 0, :]
            # `is None`, not `== None` (the original used `==`): identity
            # check is the correct idiom and avoids tensor `__eq__` surprises.
            if loss is None:
                loss = triplet_loss(anchor_output, positive_output, negative_output)
            else:
                loss += triplet_loss(anchor_output, positive_output, negative_output)
        # Average the accumulated loss over the samples in this batch.
        loss /= len(batch["context"])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        count += 1
    avg_loss = total_loss / count
    print(f"Epoch {epoch + 1} - Average Loss: {avg_loss:.4f}")
推理
代码:
# Demo: embed the same sentence in several languages and show that the
# cross-lingual cosine similarities are high.
model = ONNXEmbeddingModel(model_path)
multilingual_texts = [
    "Machine learning is fascinating",
    "機械学習は魅力的です",  # Japanese: Machine learning is fascinating
    "L'apprentissage automatique est fascinant",  # French
]
ml_embeddings = model.encode(multilingual_texts, normalize=True)
ml_similarities = model.similarity(ml_embeddings, ml_embeddings)
print("Cross-lingual similarities:")
# Walk the strict upper triangle so each unordered pair prints once.
n_texts = len(multilingual_texts)
for i in range(n_texts):
    for j in range(i + 1, n_texts):
        text1 = multilingual_texts[i]
        text2 = multilingual_texts[j]
        sim = ml_similarities[i, j]
        print(f" {sim:.4f}: '{text1[:30]}...' <-> '{text2[:30]}...'")
许可证
- Alibaba-NLP/gte-multilingual-base : https://choosealicense.com/licenses/apache-2.0/
bunbohue/Japanese-gte-multilingual-base-ONNX
作者 bunbohue
sentence-similarity
transformers
↓ 0
♥ 0
创建时间: 2025-12-16 11:14:59+00:00
更新时间: 2025-12-17 07:04:41+00:00
在 Hugging Face 上查看文件 (9)
.gitattributes
README.md
config.json
configuration.py
model.onnx
ONNX
model.onnx.data
special_tokens_map.json
tokenizer.json
tokenizer_config.json