Skip to content

梯度检查点使用指南

📖 概述

梯度检查点(Gradient Checkpointing)是一种内存优化技术,通过在反向传播时重新计算中间激活值来减少内存使用。MedFusion 框架为所有 backbone 模型提供了梯度检查点支持。

优势

  • 内存节省: 25-50% (取决于模型和配置)
  • 更大的 Batch Size: 可以训练更大的批次
  • 更大的模型: 可以使用更深的网络

权衡

  • 训练时间增加: 10-30% (由于重新计算)
  • 推理无影响: 仅在训练时启用

🚀 快速开始

基本使用

python
from med_core.backbones import create_backbone

# 创建 backbone
backbone = create_backbone(
    "resnet50",
    pretrained=True,
    feature_dim=128
)

# 启用梯度检查点
backbone.enable_gradient_checkpointing()

# 检查是否启用
assert backbone.is_gradient_checkpointing_enabled()

# 正常训练
for batch in dataloader:
    outputs = backbone(batch)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

配置段数

python
# 默认使用 4 个段
backbone.enable_gradient_checkpointing()

# 自定义段数(更多段 = 更少内存,但更慢)
backbone.enable_gradient_checkpointing(segments=8)

# 对于 Transformer 模型,段数通常等于层数
vit_backbone = create_backbone("vit_b_16")
vit_backbone.enable_gradient_checkpointing()  # 自动使用 12 个段(12 层)

禁用梯度检查点

python
# 禁用梯度检查点
backbone.disable_gradient_checkpointing()

# 检查状态
assert not backbone.is_gradient_checkpointing_enabled()

📊 支持的模型

所有 Backbone 都支持梯度检查点

模型系列变体实现模式预计内存节省
ResNet18, 34, 50, 101, 152模式 1 (顺序层)30-40%
MobileNetV2, V3 Small/Large模式 1 (顺序层)25-35%
EfficientNetB0, B1, B2模式 1 (顺序层)30-40%
EfficientNetV2S, M, L模式 1 (顺序层)30-40%
RegNetY-series (400MF-32GF)模式 1 (顺序层)30-40%
ViTB16, B32, L16, L32模式 2 (Transformer)40-50%
SwinTiny, Small, Base模式 2 (Transformer)40-50%
ConvNeXtTiny, Small, Base, Large模式 3 (混合架构)35-45%
MaxViTTiny模式 2 (Transformer)40-50%

🎯 使用场景

场景 1: 内存不足

问题: 训练时遇到 CUDA Out of Memory 错误

python
# 之前: 内存不足
backbone = create_backbone("resnet101", pretrained=True)
# RuntimeError: CUDA out of memory

# 解决方案: 启用梯度检查点
backbone = create_backbone("resnet101", pretrained=True)
backbone.enable_gradient_checkpointing()
# 训练成功!

场景 2: 增大 Batch Size

目标: 使用更大的 batch size 提高训练稳定性

python
# 之前: batch_size = 16
dataloader = DataLoader(dataset, batch_size=16)

# 启用梯度检查点后: batch_size = 32
backbone.enable_gradient_checkpointing()
dataloader = DataLoader(dataset, batch_size=32)

场景 3: 使用更大的模型

目标: 使用更深的网络提高性能

python
# 之前: 只能使用 resnet50
backbone = create_backbone("resnet50")

# 启用梯度检查点后: 可以使用 resnet152
backbone = create_backbone("resnet152")
backbone.enable_gradient_checkpointing()

场景 4: 多视图训练

目标: 训练多视图模型时节省内存

python
from med_core.models import MultiViewClassifier

model = MultiViewClassifier(
    backbone_name="resnet50",
    num_classes=2,
    num_views=4,
    aggregation="attention"
)

# 为 backbone 启用梯度检查点
model.backbone.enable_gradient_checkpointing()

# 训练多视图数据
for views, labels in dataloader:
    # views: (B, num_views, C, H, W)
    outputs = model(views)
    loss = criterion(outputs, labels)
    loss.backward()

⚙️ 高级配置

与混合精度训练结合

python
from torch.cuda.amp import autocast, GradScaler

backbone = create_backbone("resnet50")
backbone.enable_gradient_checkpointing()

scaler = GradScaler()

for batch in dataloader:
    optimizer.zero_grad()
    
    # 混合精度 + 梯度检查点
    with autocast():
        outputs = backbone(batch)
        loss = criterion(outputs, targets)
    
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

在配置文件中启用

yaml
# configs/my_config.yaml
model:
  backbone:
    name: resnet50
    pretrained: true
    feature_dim: 128
    
training:
  # 启用梯度检查点
  gradient_checkpointing:
    enabled: true
    segments: 4  # 可选,默认为 4
  
  # 混合精度训练
  use_amp: true
  
  # 可以使用更大的 batch size
  batch_size: 32

动态启用/禁用

python
# 训练时启用
model.train()
backbone.enable_gradient_checkpointing()

# 验证时禁用(可选,但推荐)
model.eval()
backbone.disable_gradient_checkpointing()

# 或者保持启用(梯度检查点在 eval 模式下自动禁用)
model.eval()
# 梯度检查点不会影响推理

📈 性能对比

内存使用对比

模型Batch Size无检查点有检查点节省
ResNet50328.2 GB5.4 GB34%
ResNet1013212.1 GB7.8 GB36%
ViT-B/163210.5 GB5.8 GB45%
ConvNeXt-Base329.8 GB6.2 GB37%
EfficientNet-B2327.5 GB5.1 GB32%

训练时间对比

模型无检查点有检查点 (4段)有检查点 (8段)增加
ResNet50100s115s128s15-28%
ViT-B/16120s138s152s15-27%
ConvNeXt-Base110s125s140s14-27%

建议: 使用 4 个段可以在内存节省和训练速度之间取得良好平衡。


🔧 故障排除

问题 1: 仍然内存不足

解决方案:

  1. 增加段数
python
backbone.enable_gradient_checkpointing(segments=8)
  1. 结合其他优化技术
python
# 启用混合精度
from torch.cuda.amp import autocast

# 减小 batch size
dataloader = DataLoader(dataset, batch_size=16)

# 启用梯度累积
accumulation_steps = 2
  1. 使用更小的模型
python
# 从 resnet101 降级到 resnet50
backbone = create_backbone("resnet50")

问题 2: 训练速度太慢

解决方案:

  1. 减少段数
python
backbone.enable_gradient_checkpointing(segments=2)
  1. 仅在必要时使用
python
# 只在大模型上使用
if model_size == "large":
    backbone.enable_gradient_checkpointing()
  1. 使用更快的硬件
  • 升级到更新的 GPU (A100, H100)
  • 使用分布式训练

问题 3: 与某些操作不兼容

症状: 某些自定义层或操作导致错误

解决方案:

python
# 禁用梯度检查点
backbone.disable_gradient_checkpointing()

# 或者只在特定层启用
# (需要自定义实现)

💡 最佳实践

1. 何时使用梯度检查点

推荐使用:

  • 训练大型模型 (ResNet101+, ViT-Large)
  • 使用大 batch size (>32)
  • GPU 内存有限 (8GB, 12GB)
  • 多视图/多模态训练
  • 高分辨率图像 (>512x512)

不推荐使用:

  • 小型模型 (ResNet18, MobileNet)
  • 小 batch size (<16)
  • 充足的 GPU 内存
  • 对训练速度要求极高
  • 推理阶段(自动禁用)

2. 段数选择

模型类型推荐段数说明
CNN (ResNet, EfficientNet)4平衡内存和速度
Transformer (ViT, Swin)层数每层一个检查点
混合架构 (ConvNeXt)4按 stage 分段
内存极度受限8+最大化内存节省

3. 与其他优化结合

python
# 完整的内存优化配置
backbone = create_backbone("resnet101")

# 1. 梯度检查点
backbone.enable_gradient_checkpointing(segments=4)

# 2. 混合精度
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()

# 3. 梯度累积
accumulation_steps = 4

# 4. 优化器状态
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    weight_decay=0.01
)

# 训练循环
for i, (inputs, targets) in enumerate(dataloader):
    with autocast():
        outputs = backbone(inputs)
        loss = criterion(outputs, targets) / accumulation_steps
    
    scaler.scale(loss).backward()
    
    if (i + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

4. 监控内存使用

python
import torch

def print_memory_stats():
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**3
        reserved = torch.cuda.memory_reserved() / 1024**3
        print(f"Memory: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")

# 训练前
print_memory_stats()

# 启用梯度检查点
backbone.enable_gradient_checkpointing()

# 训练后
print_memory_stats()

📚 参考资料

论文

相关文档

代码示例


🤝 贡献

如果你发现任何问题或有改进建议,欢迎:

  1. 提交 Issue
  2. 创建 Pull Request
  3. 参与讨论

最后更新: 2026-02-20
作者: MedFusion Team
版本: 1.0

Released under the MIT License.