Utils API
工具函数模块,提供常用的辅助功能。
概述
Utils 模块包含以下工具:
- 随机种子: 确保实验可复现
- 设备管理: GPU/CPU 自动检测和管理
- 检查点: 模型保存和加载
- 日志: 统一的日志系统
- 梯度检查点: 内存优化技术
随机种子
set_seed
设置所有随机种子以确保可复现性。
参数:
seed(int): 随机种子值deterministic(bool): 是否使用确定性算法,默认 True
影响范围:
- Python
random模块 - NumPy 随机数生成器
- PyTorch CPU 和 CUDA 随机数生成器
- cuDNN 后端行为
示例:
python
from med_core.utils import set_seed
# 设置随机种子
set_seed(42, deterministic=True)
# 现在所有随机操作都是可复现的注意:
deterministic=True会降低性能(10-30%)- 生产环境可以设置为
False以提高速度 - 分布式训练需要在每个进程中设置
设备管理
get_device
获取可用的计算设备。
参数:
device(str): 设备规格,默认 "auto""auto"- 自动检测(优先 CUDA > MPS > CPU)"cuda"- 强制使用 CUDA"cpu"- 强制使用 CPU"mps"- 强制使用 Apple Silicon GPU
返回:
torch.device: PyTorch 设备对象
示例:
python
from med_core.utils import get_device
# 自动检测
device = get_device("auto")
print(device) # cuda:0 或 cpu
# 强制使用 CPU
device = get_device("cpu")
# 将模型移到设备
model = model.to(device)get_device_info
获取设备信息。
返回:
dict[str, Any]: 设备信息字典
包含信息:
cuda_available- CUDA 是否可用cuda_device_count- CUDA 设备数量cuda_device_name- GPU 名称cuda_memory_total- GPU 总显存mps_available- Apple MPS 是否可用
示例:
python
from med_core.utils import get_device_info
info = get_device_info()
print(f"CUDA 可用: {info['cuda_available']}")
print(f"GPU 数量: {info['cuda_device_count']}")
if info['cuda_available']:
print(f"GPU 名称: {info['cuda_device_name']}")
print(f"显存: {info['cuda_memory_total'] / 1e9:.2f} GB")move_to_device
将张量或模型移动到指定设备。
参数:
obj(Any): 张量、模型、字典或列表device(torch.device): 目标设备
返回:
Any: 移动后的对象
支持类型:
torch.Tensor- 单个张量nn.Module- 模型dict- 张量字典(递归处理)list/tuple- 张量列表(递归处理)
示例:
python
from med_core.utils import move_to_device, get_device
device = get_device("cuda")
# 移动单个张量
tensor = torch.randn(10, 20)
tensor = move_to_device(tensor, device)
# 移动字典
batch = {
'images': torch.randn(32, 3, 224, 224),
'tabular': torch.randn(32, 64),
'labels': torch.randint(0, 2, (32,))
}
batch = move_to_device(batch, device)
# 移动模型
model = move_to_device(model, device)检查点管理
save_checkpoint
保存模型检查点。
参数:
model(nn.Module): 模型optimizer(Optimizer): 优化器epoch(int): 当前轮次filepath(str | Path): 保存路径metrics(dict[str, float]): 可选的指标字典scheduler(Any): 可选的学习率调度器**kwargs: 其他要保存的内容
示例:
python
from med_core.utils import save_checkpoint
save_checkpoint(
model=model,
optimizer=optimizer,
epoch=50,
filepath="checkpoints/epoch_50.pth",
metrics={
'train_loss': 0.234,
'val_loss': 0.267,
'val_auc': 0.892
},
scheduler=scheduler,
config=config.to_dict()
)load_checkpoint
加载模型检查点。
参数:
filepath(str | Path): 检查点路径model(nn.Module): 模型optimizer(Optimizer): 可选的优化器scheduler(Any): 可选的调度器device(str): 设备,默认 "cpu"
返回:
dict[str, Any]: 检查点元数据
示例:
python
from med_core.utils import load_checkpoint
checkpoint = load_checkpoint(
filepath="checkpoints/best.pth",
model=model,
optimizer=optimizer,
scheduler=scheduler,
device="cuda"
)
print(f"从轮次 {checkpoint['epoch']} 恢复")
print(f"验证 AUC: {checkpoint['metrics']['val_auc']:.4f}")find_best_checkpoint
查找最佳检查点。
参数:
checkpoint_dir(str | Path): 检查点目录metric(str): 指标名称,默认 "val_loss"mode(str): 模式 ("min" 或 "max"),默认 "min"
返回:
Path: 最佳检查点路径
示例:
python
from med_core.utils import find_best_checkpoint
best_ckpt = find_best_checkpoint(
checkpoint_dir="outputs/checkpoints",
metric="val_auc",
mode="max"
)
print(f"最佳检查点: {best_ckpt}")cleanup_checkpoints
清理旧的检查点文件。
参数:
checkpoint_dir(str | Path): 检查点目录keep_last_n(int): 保留最近 N 个检查点,默认 5keep_best(bool): 是否保留 best.pth,默认 True
示例:
python
from med_core.utils import cleanup_checkpoints
# 只保留最近 3 个检查点和 best.pth
cleanup_checkpoints(
checkpoint_dir="outputs/checkpoints",
keep_last_n=3,
keep_best=True
)日志系统
setup_logging
配置日志系统。
参数:
log_level(str): 日志级别,默认 "INFO"log_file(str | Path): 可选的日志文件路径format_string(str): 可选的日志格式
示例:
python
from med_core.utils import setup_logging
# 基础配置
setup_logging(log_level="INFO")
# 保存到文件
setup_logging(
log_level="DEBUG",
log_file="outputs/training.log"
)get_logger
获取日志记录器。
参数:
name(str): 日志记录器名称
返回:
logging.Logger: 日志记录器对象
示例:
python
from med_core.utils import get_logger
logger = get_logger(__name__)
logger.info("开始训练")
logger.debug("批次大小: 32")
logger.warning("学习率较高,可能不稳定")
logger.error("加载数据失败")梯度检查点
梯度检查点是一种内存优化技术,通过重新计算中间激活值来减少显存占用。
apply_gradient_checkpointing
为模型应用梯度检查点。
参数:
model(nn.Module): 模型checkpoint_segments(int): 检查点段数,默认 2
示例:
python
from med_core.utils import apply_gradient_checkpointing
# 应用梯度检查点
model = apply_gradient_checkpointing(model, checkpoint_segments=4)
# 显存占用减少 50-70%,训练速度降低 20-30%checkpoint_sequential
对顺序模块应用梯度检查点。
参数:
functions(list[nn.Module]): 模块列表segments(int): 分段数input(torch.Tensor): 输入张量
返回:
torch.Tensor: 输出张量
示例:
python
from med_core.utils import checkpoint_sequential
layers = [layer1, layer2, layer3, layer4]
output = checkpoint_sequential(layers, segments=2, input=x)estimate_memory_savings
估算梯度检查点的显存节省。
参数:
model(nn.Module): 模型input_size(tuple): 输入大小
返回:
dict[str, float]: 显存统计
示例:
python
from med_core.utils import estimate_memory_savings
savings = estimate_memory_savings(
model=model,
input_size=(32, 3, 224, 224)
)
print(f"原始显存: {savings['original_memory_mb']:.2f} MB")
print(f"优化后显存: {savings['optimized_memory_mb']:.2f} MB")
print(f"节省: {savings['savings_percent']:.1f}%")使用示例
完整训练脚本
python
from med_core.utils import (
set_seed,
get_device,
get_device_info,
setup_logging,
get_logger,
save_checkpoint,
load_checkpoint
)
# 1. 设置日志
setup_logging(log_level="INFO", log_file="training.log")
logger = get_logger(__name__)
# 2. 设置随机种子
set_seed(42, deterministic=True)
logger.info("随机种子已设置")
# 3. 检查设备
device_info = get_device_info()
logger.info(f"设备信息: {device_info}")
device = get_device("auto")
logger.info(f"使用设备: {device}")
# 4. 训练循环
for epoch in range(num_epochs):
# 训练代码...
# 保存检查点
if epoch % 10 == 0:
save_checkpoint(
model=model,
optimizer=optimizer,
epoch=epoch,
filepath=f"checkpoints/epoch_{epoch}.pth",
metrics=metrics
)
logger.info(f"检查点已保存: epoch {epoch}")
# 5. 加载最佳模型
checkpoint = load_checkpoint(
filepath="checkpoints/best.pth",
model=model,
device=device
)
logger.info(f"加载最佳模型: epoch {checkpoint['epoch']}")显存优化
python
from med_core.utils import (
apply_gradient_checkpointing,
estimate_memory_savings
)
# 估算显存节省
savings = estimate_memory_savings(
model=model,
input_size=(32, 3, 224, 224)
)
print(f"预计节省显存: {savings['savings_percent']:.1f}%")
# 应用梯度检查点
if savings['savings_percent'] > 30:
model = apply_gradient_checkpointing(model, checkpoint_segments=4)
print("已启用梯度检查点")最佳实践
随机种子:
- 研究实验使用
deterministic=True - 生产环境使用
deterministic=False - 记录种子值到配置文件
设备管理:
- 使用
get_device("auto")自动检测 - 训练前检查
get_device_info() - 使用
move_to_device()统一处理
检查点:
- 定期保存检查点(每 5-10 轮)
- 始终保存 best.pth
- 定期清理旧检查点
日志:
- 训练使用 INFO 级别
- 调试使用 DEBUG 级别
- 保存日志到文件
梯度检查点:
- 显存不足时使用
- 权衡显存和速度
- 大模型推荐使用
参考
完整实现请参考:
med_core/utils/seed.py- 随机种子med_core/utils/device.py- 设备管理med_core/utils/checkpoint.py- 检查点管理med_core/utils/logging.py- 日志系统med_core/utils/gradient_checkpointing.py- 梯度检查点