说明文档
基于 ConvNeXt CNN 的阿尔茨海默病分类 🧠
<p align="center"> <img src="https://static.vecteezy.com/system/resources/previews/002/543/044/non_2x/world-alzheimer-day-with-brain-and-icons-vector.jpg" width="250" height="220"> </p>
<div align="center" style="display: flex; justify-content: center; gap: 4px; flex-wrap: wrap;"> <img src="https://img.shields.io/badge/PyTorch-%23EE4C2C.svg?style=for-the-badge&logo=PyTorch&logoColor=white" alt="PyTorch"> <img src="https://img.shields.io/badge/timm-%2300C4B4.svg?style=for-the-badge&logo=timm&logoColor=white" alt="timm"> <img src="https://img.shields.io/badge/ONNX-E5E5E5.svg?style=for-the-badge&logo=ONNX&logoColor=black" alt="ONNX"> <img src="https://img.shields.io/badge/MLflow-%2300B7EB.svg?style=for-the-badge&logo=MLflow&logoColor=white" alt="MLflow"> <img src="https://img.shields.io/badge/Databricks-%23FF3621.svg?style=for-the-badge&logo=Databricks&logoColor=white" alt="Databricks"> </div>
<p align="center"> <strong>基于 ConvNeXt 的 CNN 模型,用于从 MRI 图像分类阿尔茨海默病阶段</strong> </p>
🎯 模型概述
一个预训练的 ConvNeXt CNN 模型经过微调,用于从脑部 MRI 扫描图像分类阿尔茨海默病的阶段。该模型区分 4 个类别:
类别:
- 非痴呆
- 极轻度痴呆
- 轻度痴呆
- 中度痴呆
🏗️ 架构与方法
骨干模型:ConvNeXt(Meta AI,2022)——一种受视觉 Transformer 启发但完全基于卷积构建的现代 SOTA CNN 架构。
迁移学习策略:采用渐进式解冻方法,基于 ImageNet 预训练权重:
步骤 1:仅训练分类头(10 个 epoch)
↓
步骤 2:解冻 stage 3 + 分类头(5 个 epoch,降低学习率)
↓
步骤 3:解冻 stage 2 + stage 3 + 分类头(5 个 epoch,进一步降低学习率)
📊 数据集与性能
<div align="center"> <img src="https://huggingface.co/KaiSKX/Alzheimer_ConvNeXtCNN/resolve/main/Screenshot%202025-09-26%20170634.png" alt="Training Dataset Sample" width="600"> <p><em>用于训练的阿尔茨海默病 MRI 数据集样本</em></p> </div>
⚡ 混淆矩阵与准确率
<div align="center"> <img src="Screenshot%202025-09-26%20170726.png" alt="CMatrix" width="600"> <p><em>阿尔茨海默病分类的混淆矩阵</em></p> </div>
测试准确率:在 4 个类别上达到 98.78%。
⚡ 推理速度基准测试(CPU)
| 格式 | 推理时间 | 相对速度 |
|---|---|---|
| PyTorch (.pth) | 148.16 毫秒/图像 | 基准 |
| ONNX (.onnx) | 132.48 毫秒/图像 | 快 1.12 倍 |


注意:该结果基于 Kaggle notebook CPU 环境。不同环境会有不同效果,但通常 ONNX 运行时引擎具有更好的推理速度。
🚀 快速开始
下载模型
from huggingface_hub import snapshot_download
# 下载整个模型目录
model_dir = snapshot_download(repo_id="KaiSKX/Alzheimer_ConvNeXtCNN", repo_type="model")
# 或下载特定格式
from huggingface_hub import hf_hub_download
onnx_path = hf_hub_download(repo_id="KaiSKX/Alzheimer_ConvNeXtCNN", filename="onnx/convnext_model.onnx", repo_type="model")
# 或下载特定格式
from huggingface_hub import hf_hub_download
onnx_path = hf_hub_download(repo_id="KaiSKX/Alzheimer_ConvNeXtCNN", filename="data/model.pth", repo_type="model")
💻 使用方法
MLflow 模型(PyTorch 风格)
最适合: MLflow/Databricks 部署或推理
import mlflow.pytorch
import torch
from PIL import Image
import torchvision.transforms as transforms
import time
# 加载 MLflow 模型
model = mlflow.pytorch.load_model("Alzheimer_ConvNeXtCNN")
model.eval()
# 定义类别名称映射
class_names = ["Mild Demented", "Moderate Demented", "Non Demented", "Very Mild Demented"]
# 准备输入
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
image = Image.open("C:/Users/dream/Downloads/test_dataset/ModerateDemented/moderateDem39.jpg").convert("RGB")
input_data = transform(image).unsqueeze(0)
# 推理并计时 [时间测量可选]
with torch.no_grad():
start_time = time.time()
output = model(input_data)
end_time = time.time()
predicted_class = torch.argmax(output, dim=1).item()
inference_time = (end_time - start_time) * 1000
print(f"预测类别: {class_names[predicted_class]}")
print(f"推理时间: {inference_time:.2f} 毫秒")
ONNX Runtime
最适合: 跨平台兼容性和更快的 CPU 推理
import onnxruntime as ort
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import time
# 加载 ONNX 模型
session = ort.InferenceSession("onnx/convnext_model.onnx")
input_name = session.get_inputs()[0].name
# 定义类别名称映射
class_names = ["Mild Demented", "Moderate Demented", "Non Demented", "Very Mild Demented"]
# 准备输入
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
image = Image.open("C:/Users/dream/Downloads/test_dataset/ModerateDemented/moderateDem39.jpg").convert("RGB")
input_data = transform(image).unsqueeze(0).numpy().astype(np.float32)
# 推理并计时 [时间测量可选]
start_time = time.time()
output = session.run(None, {input_name: input_data})[0]
end_time = time.time()
inference_time = (end_time - start_time) * 1000
predicted_class = np.argmax(output, axis=1)[0]
print(f"预测类别: {class_names[predicted_class]}")
print(f"推理时间: {inference_time:.2f} 毫秒")
🔧 训练详情
- 数据集:公开的阿尔茨海默病 MRI 数据集(见 Kaggle 数据集)
- 预训练:通过
timm获取 ImageNet 预训练的 ConvNeXt 权重 - 多 GPU:通过 2 块
NVIDIA T4实现并行训练加速 - MLOps:集成
Databricks MLflow用于实验跟踪和模型注册
⚠️ 重要声明
仅供研究与作品集展示使用:本模型是为教育和作品集展示目的而开发的。它不适用于专业临床诊断、医疗决策或任何实际医疗应用。
有关详细的训练方法和 MLOps 实现,请参考 Kaggle notebook。
KaiSKX/Alzheimer_ConvNeXtCNN
作者 KaiSKX
创建时间: 2025-09-25 06:14:48+00:00
更新时间: 2025-09-28 09:47:52+00:00
在 Hugging Face 上查看