返回模型
说明文档
ResNet50-APTOS-DR-ONNX 模型
本仓库包含一个 ResNet50 模型,该模型最初在 APTOS 数据集上训练用于糖尿病视网膜病变(DR)检测,现已导出为 ONNX 格式以便高效推理。
模型概述
- 架构:ResNet50
- 任务:糖尿病视网膜病变分类(5 个类别:无 DR、轻度 DR、中度 DR、重度 DR、增殖性 DR)
- 格式:ONNX(Opset 18)
使用方法(ONNX 推理)
要使用此模型进行推理,您需要 onnxruntime 库。以下是一个基本示例:
import onnxruntime as ort
import numpy as np
from PIL import Image
from torchvision import transforms
ONNX_MODEL_PATH = "mithu-vit.onnx" # 下载的 ONNX 模型路径
CLASSES = ["无 DR", "轻度 DR", "中度 DR", "重度 DR", "增殖性 DR"]
# 图像预处理(与训练流程匹配)
preprocess = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def predict_image(image_path):
img = Image.open(image_path).convert('RGB')
input_tensor = preprocess(img)
input_numpy = input_tensor.unsqueeze(0).numpy() # 添加批次维度
session = ort.InferenceSession(ONNX_MODEL_PATH)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
outputs = session.run([output_name], {input_name: input_numpy})
logits = outputs[0][0]
probs = np.exp(logits) / np.sum(np.exp(logits))
pred_index = np.argmax(probs)
print(f"预测类别: {CLASSES[pred_index]} (类别 {pred_index})")
print(f"置信度: {probs[pred_index] * 100:.2f}%")
print("所有概率:")
for i, p in enumerate(probs):
print(f" {CLASSES[i]}: {p*100:.2f}%")
# 示例用法:
# predict_image("path/to/your/image.jpg")
微调
原始模型使用 PyTorch 训练。如果您希望在自定义数据集上微调此模型,或用于略有不同的任务,可以使用原始 PyTorch 权重(如果可用),或在合适的框架中适配 ONNX 模型以进行进一步训练。
微调步骤通常包括:
- 加载预训练模型:从原始 PyTorch 模型或兼容迁移学习的版本开始。
- 准备数据集:确保图像已正确标注和预处理(调整为 224x224,使用 ImageNet 统计量进行归一化)。
- 修改头部:替换最后的分类层以匹配新数据集的类别数量。
- 定义优化器和损失函数:为您的微调任务选择合适的设置。
- 训练:微调模型,通常使用比初始训练更低的学习率,重点训练新的头部,并可能解冻较早的层以进行更细粒度的调整。
- 导出为 ONNX:微调后,按照与原始导出过程类似的步骤将更新后的模型导出为 ONNX 格式。
推荐的微调框架:
Shadow0482/ResNet50-APTOS-DR-ONNX
作者 Shadow0482
image-classification
↓ 0
♥ 0
创建时间: 2026-01-06 16:31:08+00:00
更新时间: 2026-01-06 16:56:15+00:00
在 Hugging Face 上查看文件 (4)
.gitattributes
README.md
mithu-vit.onnx
ONNX
mithu-vit.onnx.data