Datasets API
医学多模态数据集模块,支持单视图和多视图数据加载。
概述
Datasets 模块提供了灵活的数据加载系统,支持:
- 单视图数据集: 每个样本一张图像 + 表格数据
- 多视图数据集: 每个样本多张图像(多角度、多切片、时间序列等)
- 注意力监督: 支持注意力图监督信号
- 数据增强: 医学影像专用的数据增强策略
- 数据清洗: 自动检测和处理异常数据
基础类
BaseMultimodalDataset
所有多模态数据集的抽象基类。
参数:
image_paths(list[str | Path]): 图像文件路径列表tabular_data(np.ndarray | torch.Tensor): 表格特征数组 (N, num_features)labels(np.ndarray | torch.Tensor): 标签数组 (N,)transform(Any): 图像变换,默认 Nonetarget_transform(Any): 标签变换,默认 None
方法:
__len__()- 返回数据集大小__getitem__(idx)- 获取单个样本 (image, tabular, label)load_image(path)- 加载图像(需子类实现)get_tabular_dim()- 返回表格特征维度get_num_classes()- 返回类别数get_class_distribution()- 返回类别分布get_sample_weights()- 计算样本权重(用于平衡采样)subset(indices)- 创建子集get_statistics()- 获取数据集统计信息
示例:
python
from med_core.datasets import BaseMultimodalDataset
# 查看数据集统计
stats = dataset.get_statistics()
print(f"样本数: {stats['num_samples']}")
print(f"类别分布: {stats['class_distribution']}")
# 创建平衡采样器
sample_weights = dataset.get_sample_weights()
sampler = WeightedRandomSampler(sample_weights, len(dataset))MedicalMultimodalDataset
医学多模态数据集的具体实现。
参数:
image_paths(list[str | Path]): 医学图像路径tabular_data(np.ndarray | torch.Tensor): 临床表格数据labels(np.ndarray | torch.Tensor): 诊断标签transform(Any): 图像变换target_transform(Any): 标签变换feature_names(list[str]): 特征名称列表patient_ids(list[str]): 患者 ID 列表
特性:
- 支持常见医学图像格式(JPEG, PNG, TIFF)
- 自动特征预处理(归一化、编码)
- 处理缺失数据
- 支持患者级别的数据划分
示例:
python
from med_core.datasets import MedicalMultimodalDataset
# 从 CSV 创建数据集
dataset = MedicalMultimodalDataset.from_csv(
csv_path="data/patients.csv",
image_dir="data/ct_scans",
image_column="ct_path",
target_column="diagnosis",
numerical_features=["age", "bmi", "blood_pressure"],
categorical_features=["gender", "smoking_status"],
transform=train_transforms
)
# 访问样本
image, tabular, label = dataset[0]
print(f"图像形状: {image.shape}")
print(f"表格特征: {tabular.shape}")
print(f"标签: {label}")多视图数据集
BaseMultiViewDataset
多视图数据集的抽象基类。
特性:
- 支持每个样本多张图像
- 灵活的视图配置
- 自动处理不同数量的视图
MedicalMultiViewDataset
医学多视图数据集实现。
应用场景:
- 多角度 CT 扫描
- 多切片 MRI
- 时间序列影像
- 病理切片的多视野
示例:
python
from med_core.datasets import MedicalMultiViewDataset, MultiViewConfig
# 配置多视图
config = MultiViewConfig(
max_views=8, # 最多 8 个视图
pad_mode="repeat", # 不足时重复填充
view_aggregation="attention" # 使用注意力聚合
)
# 创建多视图数据集
dataset = MedicalMultiViewDataset(
image_paths_list=[
["patient1_view1.jpg", "patient1_view2.jpg"],
["patient2_view1.jpg", "patient2_view2.jpg", "patient2_view3.jpg"],
],
tabular_data=tabular_features,
labels=labels,
config=config,
transform=train_transforms
)
# 获取样本
views, tabular, label = dataset[0]
print(f"视图数量: {len(views)}")
print(f"每个视图形状: {views[0].shape}")MultiViewConfig
多视图配置类。
参数:
max_views(int): 最大视图数量pad_mode(str): 填充模式 ("repeat", "zero", "mean")view_aggregation(str): 聚合策略 ("mean", "max", "attention")
数据变换
get_train_transforms
获取训练集数据增强。
参数:
image_size(int | tuple): 图像大小,默认 224augmentation_level(str): 增强强度 ("light", "medium", "strong")
返回:
torchvision.transforms.Compose: 变换组合
示例:
python
from med_core.datasets import get_train_transforms
train_transforms = get_train_transforms(
image_size=224,
augmentation_level="medium"
)get_val_transforms
获取验证集数据变换(无增强)。
参数:
image_size(int | tuple): 图像大小,默认 224
示例:
python
from med_core.datasets import get_val_transforms
val_transforms = get_val_transforms(image_size=224)get_medical_augmentation
获取医学影像专用增强。
特性:
- 保持医学影像的诊断特征
- 避免过度变形
- 支持 CLAHE 对比度增强
- 支持弹性变形
示例:
python
from med_core.datasets import get_medical_augmentation
medical_aug = get_medical_augmentation(
image_size=512,
use_clahe=True, # 使用 CLAHE 增强对比度
elastic_transform=True # 弹性变形
)数据加载工具
create_dataloaders
创建训练、验证、测试数据加载器。
参数:
dataset: 数据集对象batch_size(int): 批次大小train_split(float): 训练集比例,默认 0.7val_split(float): 验证集比例,默认 0.15test_split(float): 测试集比例,默认 0.15num_workers(int): 数据加载线程数,默认 4pin_memory(bool): 是否固定内存,默认 Truebalanced_sampling(bool): 是否使用平衡采样,默认 False
返回:
tuple[DataLoader, DataLoader, DataLoader]: (train_loader, val_loader, test_loader)
示例:
python
from med_core.datasets import create_dataloaders
train_loader, val_loader, test_loader = create_dataloaders(
dataset=dataset,
batch_size=32,
train_split=0.7,
val_split=0.15,
test_split=0.15,
num_workers=8,
balanced_sampling=True # 处理类别不平衡
)
# 训练循环
for images, tabular, labels in train_loader:
# 训练代码
passsplit_dataset
划分数据集为训练、验证、测试集。
参数:
dataset: 数据集对象train_ratio(float): 训练集比例val_ratio(float): 验证集比例test_ratio(float): 测试集比例stratify(bool): 是否分层划分,默认 Truerandom_state(int): 随机种子
返回:
tuple[Dataset, Dataset, Dataset]: (train_set, val_set, test_set)
示例:
python
from med_core.datasets import split_dataset
train_set, val_set, test_set = split_dataset(
dataset=dataset,
train_ratio=0.7,
val_ratio=0.15,
test_ratio=0.15,
stratify=True, # 保持类别比例
random_state=42
)数据清洗
DataCleaner
数据清洗工具类。
功能:
- 检测异常值
- 处理缺失值
- 移除重复样本
- 特征归一化
示例:
python
from med_core.datasets import DataCleaner
cleaner = DataCleaner()
# 清洗表格数据
cleaned_data = cleaner.clean_tabular(
data=raw_tabular_data,
remove_outliers=True,
fill_missing="median",
normalize=True
)
# 检测图像质量
quality_scores = cleaner.assess_image_quality(image_paths)
valid_indices = quality_scores > 0.5 # 过滤低质量图像配置示例
在 YAML 配置文件中配置数据集:
yaml
data:
csv_path: data/patients.csv
image_dir: data/ct_scans
image_column: ct_path
target_column: diagnosis
numerical_features:
- age
- bmi
- blood_pressure
categorical_features:
- gender
- smoking_status
train_split: 0.7
val_split: 0.15
test_split: 0.15
augmentation:
level: medium
use_clahe: true
dataloader:
batch_size: 32
num_workers: 8
pin_memory: true
balanced_sampling: true多视图配置:
yaml
data:
multiview:
enabled: true
max_views: 8
pad_mode: repeat
view_aggregation: attention最佳实践
数据划分:
- 使用分层划分保持类别比例
- 固定随机种子确保可复现性
- 患者级别划分避免数据泄露
数据增强:
- 训练集使用中等强度增强
- 验证/测试集仅做归一化
- 医学影像避免过度变形
数据加载:
- GPU 训练时设置
pin_memory=True - 根据 CPU 核心数设置
num_workers - 类别不平衡时使用
balanced_sampling
性能优化:
- 使用数据缓存加速重复加载
- 预处理数据保存为 HDF5 格式
- 多视图数据使用延迟加载
参考
完整实现请参考:
med_core/datasets/base.py- 基础数据集类med_core/datasets/medical.py- 医学数据集实现med_core/datasets/medical_multiview.py- 多视图数据集med_core/datasets/transforms.py- 数据变换med_core/datasets/data_cleaner.py- 数据清洗