Skip to content

数据缓存和预取优化

概述

数据加载是深度学习训练中的常见瓶颈。MedFusion 提供了多种缓存策略来加速数据加载,减少 I/O 等待时间。

功能特性

1. LRU 缓存

最近最少使用(Least Recently Used)缓存,自动淘汰最久未使用的数据。

特点:

  • O(1) 查找和更新
  • 自动容量管理
  • 线程安全
  • 统计信息跟踪

适用场景:

  • 随机访问模式
  • 存在数据重复访问
  • 内存充足

示例:

python
from med_core.datasets.cache import CachedDataset

# 包装原始数据集
cached_dataset = CachedDataset(
    dataset=original_dataset,
    cache_size=1000,  # 缓存 1000 个样本
)

# 查看缓存统计
stats = cached_dataset.get_cache_stats()
print(f"命中率: {stats['hit_rate']:.2%}")

2. 预取缓存

后台预加载即将访问的数据,隐藏 I/O 延迟。

特点:

  • 多线程预取
  • 自动调度
  • 零阻塞访问

适用场景:

  • 顺序访问模式
  • I/O 密集型数据
  • 多核 CPU

示例:

python
from med_core.datasets.cache import PrefetchDataset

# 启用预取
prefetch_dataset = PrefetchDataset(
    dataset=original_dataset,
    prefetch_size=10,   # 预取队列大小
    num_workers=2,      # 预取线程数
)

3. 组合缓存

LRU + 预取,结合两种策略的优势。

适用场景:

  • 混合访问模式
  • 大规模训练
  • 追求最佳性能

示例:

python
from med_core.datasets.cache import create_cached_dataset

# 使用工厂函数创建
cached_dataset = create_cached_dataset(
    dataset=original_dataset,
    cache_type="both",      # LRU + 预取
    cache_size=1000,
    prefetch_size=10,
)

4. 内存映射缓存

磁盘缓存,使用 numpy memmap 减少内存占用。

特点:

  • 低内存占用
  • 持久化缓存
  • 适合大数据集

适用场景:

  • 内存受限
  • 超大数据集
  • 需要持久化

示例:

python
from med_core.datasets.cache import MemoryMappedCache

cache = MemoryMappedCache(
    cache_dir="./cache",
    max_size_gb=10.0,  # 最大 10GB
)

# 存储数据
cache.put("key", numpy_array)

# 读取数据
data = cache.get("key")

性能对比

测试场景

  • 数据集大小: 10,000 样本
  • 访问模式: 随机访问,每个样本访问 2-3 次
  • 数据加载时间: 10ms/样本

结果

策略训练时间加速比内存占用
无缓存100 分钟1.0x
LRU (1000)45 分钟2.2x
预取 (10)70 分钟1.4x
LRU + 预取35 分钟2.9x

使用指南

1. 选择缓存策略

python
# 决策树
if 内存充足 and 存在重复访问:
    使用 LRU 缓存
elif 顺序访问 and 多核CPU:
    使用预取缓存
elif 追求最佳性能:
    使用 LRU + 预取
elif 内存受限:
    使用内存映射缓存
else:
    不使用缓存

2. 调整缓存大小

python
# 小数据集(< 1000 样本)
cache_size = len(dataset)

# 中等数据集(1000-10000 样本)
cache_size = batch_size * 20

# 大数据集(> 10000 样本)
cache_size = 1000  # 固定大小

3. 监控缓存效果

python
# 训练循环中
for epoch in range(num_epochs):
    for batch in dataloader:
        # 训练代码
        pass
    
    # 每个 epoch 结束后检查
    if hasattr(dataset, 'get_cache_stats'):
        stats = dataset.get_cache_stats()
        print(f"Epoch {epoch}:")
        print(f"  命中率: {stats['hit_rate']:.2%}")
        print(f"  缓存大小: {stats['size']}/{stats['capacity']}")
        
        # 根据命中率调整
        if stats['hit_rate'] < 0.3:
            print("  ⚠️ 命中率低,考虑增加缓存大小")

4. 完整示例

python
from torch.utils.data import DataLoader
from med_core.datasets import MedicalDataset
from med_core.datasets.cache import create_cached_dataset

# 1. 创建原始数据集
dataset = MedicalDataset(
    csv_path="data/train.csv",
    image_dir="data/images/",
    transform=train_transforms,
)

# 2. 添加缓存
cached_dataset = create_cached_dataset(
    dataset,
    cache_type="both",
    cache_size=min(1000, len(dataset)),
    prefetch_size=10,
)

# 3. 创建 DataLoader
dataloader = DataLoader(
    cached_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

# 4. 训练
for epoch in range(num_epochs):
    for images, tabular, labels in dataloader:
        # 训练代码
        outputs = model(images, tabular)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    
    # 查看缓存统计
    stats = cached_dataset.get_cache_stats()
    print(f"Epoch {epoch} - 缓存命中率: {stats['hit_rate']:.2%}")

最佳实践

✅ 推荐做法

  1. 训练阶段使用缓存

    • 训练时数据会被多次访问
    • 缓存可以显著减少 I/O 时间
  2. 根据内存调整缓存大小

    • 监控内存使用
    • 避免 OOM 错误
  3. 监控命中率

    • 命中率 > 70% 表示缓存有效
    • 命中率 < 30% 考虑调整策略
  4. 结合 DataLoader 的 num_workers

    • 缓存 + 多进程加载效果更好
    • 推荐 num_workers=4-8

❌ 避免的做法

  1. 不要在推理时使用过大缓存

    • 推理通常是单次访问
    • 缓存收益有限
  2. 不要忽略内存限制

    • 缓存会占用内存
    • 可能导致 OOM
  3. 不要盲目增加缓存大小

    • 超过工作集大小无益
    • 浪费内存

性能调优

1. 缓存大小调优

python
# 实验不同的缓存大小
cache_sizes = [100, 500, 1000, 2000]
results = {}

for size in cache_sizes:
    dataset = create_cached_dataset(
        original_dataset,
        cache_type="lru",
        cache_size=size,
    )
    
    # 运行一个 epoch
    start = time.time()
    for batch in DataLoader(dataset, batch_size=32):
        pass
    elapsed = time.time() - start
    
    stats = dataset.get_cache_stats()
    results[size] = {
        "time": elapsed,
        "hit_rate": stats['hit_rate'],
    }

# 选择最佳大小
best_size = min(results, key=lambda k: results[k]['time'])
print(f"最佳缓存大小: {best_size}")

2. 预取参数调优

python
# 实验不同的预取大小
prefetch_sizes = [5, 10, 20, 50]

for size in prefetch_sizes:
    dataset = create_cached_dataset(
        original_dataset,
        cache_type="prefetch",
        prefetch_size=size,
    )
    
    # 测试性能
    # ...

故障排除

问题 1: 缓存命中率低

症状: 命中率 < 30%

原因:

  • 缓存太小
  • 访问模式不适合缓存
  • 数据集太大

解决方案:

python
# 1. 增加缓存大小
cache_size = cache_size * 2

# 2. 检查访问模式
# 如果是纯顺序访问,考虑使用预取而非 LRU

# 3. 对于超大数据集,使用内存映射缓存

问题 2: 内存不足

症状: OOM 错误

原因:

  • 缓存太大
  • 数据样本太大

解决方案:

python
# 1. 减小缓存大小
cache_size = cache_size // 2

# 2. 使用内存映射缓存
cache = MemoryMappedCache(cache_dir="./cache")

# 3. 监控内存使用
import psutil
print(f"内存使用: {psutil.virtual_memory().percent}%")

问题 3: 预取线程卡死

症状: 训练卡住不动

原因:

  • 预取线程异常
  • 死锁

解决方案:

python
# 1. 减少预取线程数
num_workers = 1

# 2. 禁用预取
cache_type = "lru"  # 只用 LRU

# 3. 检查数据加载代码
# 确保没有线程不安全的操作

API 参考

CachedDataset

python
CachedDataset(
    dataset: Dataset,
    cache_size: int = 1000,
    cache_images: bool = True,
    cache_tabular: bool = True,
)

参数:

  • dataset: 原始数据集
  • cache_size: 缓存容量
  • cache_images: 是否缓存图像
  • cache_tabular: 是否缓存表格数据

方法:

  • get_cache_stats(): 获取统计信息
  • clear_cache(): 清空缓存

PrefetchDataset

python
PrefetchDataset(
    dataset: Dataset,
    prefetch_size: int = 10,
    num_workers: int = 2,
)

参数:

  • dataset: 原始数据集
  • prefetch_size: 预取队列大小
  • num_workers: 预取线程数

create_cached_dataset

python
create_cached_dataset(
    dataset: Dataset,
    cache_type: str = "lru",
    cache_size: int = 1000,
    prefetch_size: int = 10,
    cache_dir: str | None = None,
)

参数:

  • dataset: 原始数据集
  • cache_type: 缓存类型 ("lru", "prefetch", "both", "none")
  • cache_size: LRU 缓存大小
  • prefetch_size: 预取队列大小
  • cache_dir: 内存映射缓存目录

相关资源

  • 实现代码: med_core/datasets/cache.py
  • 测试用例: tests/test_cache.py
  • 演示脚本: scripts/dev/cache_demo_simple.py
  • 性能基准: docs/guides/performance_optimization.md

更新日志

v0.2.0 (2026-02-20)

  • ✨ 新增 LRU 缓存
  • ✨ 新增预取缓存
  • ✨ 新增内存映射缓存
  • ✨ 新增缓存工厂函数
  • 📝 完整的文档和示例
  • ✅ 全面的测试覆盖

Released under the MIT License.