ONNX 模型库
返回模型

说明文档

基于 ConvNeXt CNN 的阿尔茨海默病分类 🧠

<p align="center"> <img src="https://static.vecteezy.com/system/resources/previews/002/543/044/non_2x/world-alzheimer-day-with-brain-and-icons-vector.jpg" width="250" height="220"> </p>

<div align="center" style="display: flex; justify-content: center; gap: 4px; flex-wrap: wrap;"> <img src="https://img.shields.io/badge/PyTorch-%23EE4C2C.svg?style=for-the-badge&logo=PyTorch&logoColor=white" alt="PyTorch"> <img src="https://img.shields.io/badge/timm-%2300C4B4.svg?style=for-the-badge&logo=timm&logoColor=white" alt="timm"> <img src="https://img.shields.io/badge/ONNX-E5E5E5.svg?style=for-the-badge&logo=ONNX&logoColor=black" alt="ONNX"> <img src="https://img.shields.io/badge/MLflow-%2300B7EB.svg?style=for-the-badge&logo=MLflow&logoColor=white" alt="MLflow"> <img src="https://img.shields.io/badge/Databricks-%23FF3621.svg?style=for-the-badge&logo=Databricks&logoColor=white" alt="Databricks"> </div>

<p align="center"> <strong>基于 ConvNeXt 的 CNN 模型,用于从 MRI 图像分类阿尔茨海默病阶段</strong> </p>


🎯 模型概述

一个预训练的 ConvNeXt CNN 模型经过微调,用于从脑部 MRI 扫描图像分类阿尔茨海默病的阶段。该模型区分 4 个类别:

类别

  • 非痴呆
  • 极轻度痴呆
  • 轻度痴呆
  • 中度痴呆

🏗️ 架构与方法

骨干模型:ConvNeXt(Meta AI,2022)——一种受视觉 Transformer 启发但完全基于卷积构建的现代 SOTA CNN 架构。

迁移学习策略:采用渐进式解冻方法,基于 ImageNet 预训练权重:

步骤 1:仅训练分类头(10 个 epoch)
                        ↓
步骤 2:解冻 stage 3 + 分类头(5 个 epoch,降低学习率)
                        ↓
步骤 3:解冻 stage 2 + stage 3 + 分类头(5 个 epoch,进一步降低学习率)

📊 数据集与性能

<div align="center"> <img src="https://huggingface.co/KaiSKX/Alzheimer_ConvNeXtCNN/resolve/main/Screenshot%202025-09-26%20170634.png" alt="Training Dataset Sample" width="600"> <p><em>用于训练的阿尔茨海默病 MRI 数据集样本</em></p> </div>

⚡ 混淆矩阵与准确率

<div align="center"> <img src="Screenshot%202025-09-26%20170726.png" alt="CMatrix" width="600"> <p><em>阿尔茨海默病分类的混淆矩阵</em></p> </div>

测试准确率:在 4 个类别上达到 98.78%

⚡ 推理速度基准测试(CPU)

格式 推理时间 相对速度
PyTorch (.pth) 148.16 毫秒/图像 基准
ONNX (.onnx) 132.48 毫秒/图像 快 1.12 倍

benchmark

benchmark

注意:该结果基于 Kaggle notebook CPU 环境。不同环境会有不同效果,但通常 ONNX 运行时引擎具有更好的推理速度。


🚀 快速开始

下载模型

from huggingface_hub import snapshot_download

# 下载整个模型目录
model_dir = snapshot_download(repo_id="KaiSKX/Alzheimer_ConvNeXtCNN", repo_type="model")

# 或下载特定格式
from huggingface_hub import hf_hub_download
onnx_path = hf_hub_download(repo_id="KaiSKX/Alzheimer_ConvNeXtCNN", filename="onnx/convnext_model.onnx", repo_type="model")

# 或下载特定格式
from huggingface_hub import hf_hub_download
onnx_path = hf_hub_download(repo_id="KaiSKX/Alzheimer_ConvNeXtCNN", filename="data/model.pth", repo_type="model")

💻 使用方法

MLflow 模型(PyTorch 风格)

最适合: MLflow/Databricks 部署或推理

import mlflow.pytorch
import torch
from PIL import Image
import torchvision.transforms as transforms
import time

# 加载 MLflow 模型
model = mlflow.pytorch.load_model("Alzheimer_ConvNeXtCNN")
model.eval()

# 定义类别名称映射
class_names = ["Mild Demented", "Moderate Demented", "Non Demented", "Very Mild Demented"]

# 准备输入
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

image = Image.open("C:/Users/dream/Downloads/test_dataset/ModerateDemented/moderateDem39.jpg").convert("RGB")
input_data = transform(image).unsqueeze(0)

# 推理并计时 [时间测量可选]
with torch.no_grad():
    start_time = time.time()
    output = model(input_data)
    end_time = time.time()

predicted_class = torch.argmax(output, dim=1).item()
inference_time = (end_time - start_time) * 1000

print(f"预测类别: {class_names[predicted_class]}")
print(f"推理时间: {inference_time:.2f} 毫秒")

ONNX Runtime

最适合: 跨平台兼容性和更快的 CPU 推理

import onnxruntime as ort
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import time

# 加载 ONNX 模型
session = ort.InferenceSession("onnx/convnext_model.onnx")
input_name = session.get_inputs()[0].name

# 定义类别名称映射
class_names = ["Mild Demented", "Moderate Demented", "Non Demented", "Very Mild Demented"]

# 准备输入
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

image = Image.open("C:/Users/dream/Downloads/test_dataset/ModerateDemented/moderateDem39.jpg").convert("RGB")
input_data = transform(image).unsqueeze(0).numpy().astype(np.float32)

# 推理并计时 [时间测量可选]
start_time = time.time()
output = session.run(None, {input_name: input_data})[0]
end_time = time.time()

inference_time = (end_time - start_time) * 1000
predicted_class = np.argmax(output, axis=1)[0]

print(f"预测类别: {class_names[predicted_class]}")
print(f"推理时间: {inference_time:.2f} 毫秒")

🔧 训练详情

  • 数据集:公开的阿尔茨海默病 MRI 数据集(见 Kaggle 数据集
  • 预训练:通过 timm 获取 ImageNet 预训练的 ConvNeXt 权重
  • 多 GPU:通过 2 块 NVIDIA T4 实现并行训练加速
  • MLOps:集成 Databricks MLflow 用于实验跟踪和模型注册

⚠️ 重要声明

仅供研究与作品集展示使用:本模型是为教育和作品集展示目的而开发的。它不适用于专业临床诊断、医疗决策或任何实际医疗应用。

有关详细的训练方法和 MLOps 实现,请参考 Kaggle notebook

KaiSKX/Alzheimer_ConvNeXtCNN

作者 KaiSKX

image-classification
↓ 0 ♥ 0

创建时间: 2025-09-25 06:14:48+00:00

更新时间: 2025-09-28 09:47:52+00:00

在 Hugging Face 上查看

文件 (16)

.gitattributes
MLmodel
README.md
Screenshot 2025-09-26 170634.png
Screenshot 2025-09-26 170726.png
Screenshot 2025-09-26 170736.png
Screenshot 2025-09-26 170746.png
conda.yaml
data/model.pth
data/pickle_module_info.txt
input_example.json
onnx/convnext_model.onnx ONNX
python_env.yaml
registered_model_meta
requirements.txt
serving_input_example.json