案例研究 2:乳腺癌分类
预计时间:75 分钟
1. 医学背景
乳腺癌是全球女性最常见的恶性肿瘤,早期诊断和精准分型对治疗方案选择至关重要。现代诊断通常结合多种模态:
- 病理切片(WSI):显微镜下的组织形态学特征
- 影像学(MRI/超声):肿瘤的空间分布和血流特征
- 临床数据:年龄、家族史、激素受体状态等
分子分型:
- Luminal A(ER+/PR+/HER2-):预后较好
- Luminal B(ER+/PR+/HER2+):中等预后
- HER2 富集型(ER-/PR-/HER2+):需靶向治疗
- 三阴性(ER-/PR-/HER2-):预后较差
2. 数据集介绍
本案例使用 TCGA-BRCA(The Cancer Genome Atlas - Breast Cancer)数据集的子集,包含:
- 病理切片:H&E 染色的全切片图像(WSI),分辨率 40× 或 20×
- MRI 影像:T1/T2 加权序列,动态对比增强(DCE-MRI)
- 临床数据:年龄、TNM 分期、分子分型、生存信息
数据规模:
- 1,100+ 患者
- 每个患者 1-5 张 WSI(平均 2.3 张)
- MRI 序列:3D 体积数据(256×256×N 切片)
3. 环境准备
bash
# 安装依赖(包含 WSI 处理库)
uv sync --extra pathology
# 生成合成数据
uv run python scripts/generate_synthetic_breast_data.py \
--output-dir data/breast_cancer \
--num-samples 5004. 数据预处理
4.1 病理切片预处理
WSI 文件通常非常大(10-50GB),需要分块处理:
python
# scripts/preprocess_wsi.py
import openslide
import numpy as np
from pathlib import Path
from PIL import Image
import torch
from torchvision import transforms
class WSIPreprocessor:
"""全切片图像预处理器"""
def __init__(
self,
patch_size: int = 256,
level: int = 0,
tissue_threshold: float = 0.5
):
self.patch_size = patch_size
self.level = level
self.tissue_threshold = tissue_threshold
def extract_patches(self, wsi_path: Path, output_dir: Path):
"""从 WSI 提取组织区域的 patches"""
slide = openslide.OpenSlide(str(wsi_path))
# 获取缩略图用于组织检测
thumbnail = slide.get_thumbnail((2000, 2000))
tissue_mask = self._detect_tissue(thumbnail)
# 计算 patch 坐标
patch_coords = self._get_patch_coordinates(
slide, tissue_mask, self.level
)
# 提取并保存 patches
output_dir.mkdir(parents=True, exist_ok=True)
patches = []
for i, (x, y) in enumerate(patch_coords):
patch = slide.read_region(
(x, y), self.level,
(self.patch_size, self.patch_size)
).convert('RGB')
patch_path = output_dir / f"patch_{i:04d}.png"
patch.save(patch_path)
patches.append({
'path': patch_path,
'coords': (x, y),
'level': self.level
})
slide.close()
return patches
def _detect_tissue(self, thumbnail: Image.Image) -> np.ndarray:
"""检测组织区域(排除背景)"""
img_array = np.array(thumbnail)
# 转换到 HSV 空间
from skimage import color
hsv = color.rgb2hsv(img_array)
# 基于饱和度阈值分割
saturation = hsv[:, :, 1]
tissue_mask = saturation > 0.1
# 形态学操作去噪
from skimage.morphology import remove_small_objects, binary_closing
tissue_mask = remove_small_objects(tissue_mask, min_size=500)
tissue_mask = binary_closing(tissue_mask, footprint=np.ones((5, 5)))
return tissue_mask
def _get_patch_coordinates(
self,
slide: openslide.OpenSlide,
tissue_mask: np.ndarray,
level: int
) -> list[tuple[int, int]]:
"""计算需要提取的 patch 坐标"""
level_dims = slide.level_dimensions[level]
downsample = slide.level_downsamples[level]
# 调整 mask 尺寸
from skimage.transform import resize
mask_resized = resize(
tissue_mask,
(level_dims[1] // self.patch_size, level_dims[0] // self.patch_size),
order=0
)
# 找到组织区域的 patch
coords = []
for i in range(mask_resized.shape[0]):
for j in range(mask_resized.shape[1]):
if mask_resized[i, j] > self.tissue_threshold:
x = int(j * self.patch_size * downsample)
y = int(i * self.patch_size * downsample)
coords.append((x, y))
return coords
# 批量处理
def preprocess_all_wsi(input_dir: Path, output_dir: Path):
preprocessor = WSIPreprocessor(patch_size=256, level=0)
for wsi_path in input_dir.glob("*.svs"):
print(f"Processing {wsi_path.name}...")
patient_id = wsi_path.stem
patches = preprocessor.extract_patches(
wsi_path,
output_dir / patient_id
)
# 保存 patch 信息
import json
with open(output_dir / patient_id / "patches.json", 'w') as f:
json.dump(patches, f, indent=2, default=str)
if __name__ == "__main__":
preprocess_all_wsi(
input_dir=Path("data/breast_cancer/wsi_raw"),
output_dir=Path("data/breast_cancer/wsi_patches")
)4.2 MRI 预处理
python
# scripts/preprocess_breast_mri.py
import numpy as np
import nibabel as nib
from pathlib import Path
from med_core.preprocessing import MRIPreprocessor
def preprocess_breast_mri(input_dir: Path, output_dir: Path):
"""预处理乳腺 MRI 数据"""
preprocessor = MRIPreprocessor(
target_spacing=(1.0, 1.0, 3.0), # 各向异性重采样
normalize_method='z_score',
crop_to_roi=True
)
output_dir.mkdir(parents=True, exist_ok=True)
for patient_dir in input_dir.glob("*"):
if not patient_dir.is_dir():
continue
print(f"Processing {patient_dir.name}...")
# 加载 T1 和 T2 序列
t1_path = patient_dir / "T1.nii.gz"
t2_path = patient_dir / "T2.nii.gz"
if t1_path.exists():
t1_volume = nib.load(t1_path).get_fdata()
t1_processed = preprocessor(t1_volume)
nib.save(
nib.Nifti1Image(t1_processed, np.eye(4)),
output_dir / f"{patient_dir.name}_T1.nii.gz"
)
if t2_path.exists():
t2_volume = nib.load(t2_path).get_fdata()
t2_processed = preprocessor(t2_volume)
nib.save(
nib.Nifti1Image(t2_processed, np.eye(4)),
output_dir / f"{patient_dir.name}_T2.nii.gz"
)
if __name__ == "__main__":
preprocess_breast_mri(
input_dir=Path("data/breast_cancer/mri_raw"),
output_dir=Path("data/breast_cancer/mri_processed")
)5. 多模态数据集
python
# examples/breast_cancer_dataset.py
import torch
from torch.utils.data import Dataset
import pandas as pd
from pathlib import Path
import json
from PIL import Image
import nibabel as nib
import numpy as np
class BreastCancerMultiModalDataset(Dataset):
"""乳腺癌多模态数据集"""
def __init__(
self,
data_dir: Path,
split: str = "train",
use_pathology: bool = True,
use_mri: bool = True,
use_clinical: bool = True,
max_patches: int = 50
):
self.data_dir = Path(data_dir)
self.use_pathology = use_pathology
self.use_mri = use_mri
self.use_clinical = use_clinical
self.max_patches = max_patches
# 加载临床数据
self.clinical_df = pd.read_csv(self.data_dir / f"{split}_clinical.csv")
self.patient_ids = self.clinical_df['patient_id'].tolist()
# 特征列
self.clinical_features = [
'age', 'tumor_size', 'lymph_nodes',
'er_status', 'pr_status', 'her2_status'
]
def __len__(self):
return len(self.patient_ids)
def __getitem__(self, idx):
patient_id = self.patient_ids[idx]
sample = {'patient_id': patient_id}
# 1. 病理数据(多个 patches)
if self.use_pathology:
patches = self._load_pathology_patches(patient_id)
sample['pathology'] = patches
# 2. MRI 数据
if self.use_mri:
mri_volume = self._load_mri(patient_id)
sample['mri'] = mri_volume
# 3. 临床数据
if self.use_clinical:
clinical_features = self._load_clinical(patient_id)
sample['clinical'] = clinical_features
# 标签(分子分型)
label = self.clinical_df.loc[
self.clinical_df['patient_id'] == patient_id,
'molecular_subtype'
].values[0]
sample['label'] = torch.tensor(label, dtype=torch.long)
return sample
def _load_pathology_patches(self, patient_id: str) -> torch.Tensor:
"""加载病理 patches"""
patch_dir = self.data_dir / "wsi_patches" / patient_id
# 加载 patch 信息
with open(patch_dir / "patches.json") as f:
patch_info = json.load(f)
# 随机采样 patches
import random
if len(patch_info) > self.max_patches:
patch_info = random.sample(patch_info, self.max_patches)
patches = []
for info in patch_info:
img = Image.open(info['path']).convert('RGB')
img_tensor = torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.0
patches.append(img_tensor)
# 填充到固定数量
while len(patches) < self.max_patches:
patches.append(torch.zeros_like(patches[0]))
return torch.stack(patches) # (N, 3, H, W)
def _load_mri(self, patient_id: str) -> torch.Tensor:
"""加载 MRI 数据"""
mri_path = self.data_dir / "mri_processed" / f"{patient_id}_T1.nii.gz"
if not mri_path.exists():
# 返回零张量
return torch.zeros(1, 64, 64, 32)
volume = nib.load(mri_path).get_fdata()
# 中心裁剪
from med_core.preprocessing import center_crop_3d
volume = center_crop_3d(volume, (64, 64, 32))
return torch.from_numpy(volume).unsqueeze(0).float() # (1, D, H, W)
def _load_clinical(self, patient_id: str) -> torch.Tensor:
"""加载临床特征"""
row = self.clinical_df[self.clinical_df['patient_id'] == patient_id]
features = row[self.clinical_features].values[0]
return torch.from_numpy(features).float()6. 多模态融合模型
6.1 使用 MedFusion 构建
python
# train_breast_cancer.py(自定义训练脚本示意)
from med_core.models import MultiModalModelBuilder
# 构建多模态模型
builder = MultiModalModelBuilder(num_classes=4) # 4 种分子分型
# 病理模态(使用 MIL 聚合多个 patches)
builder.add_modality(
name="pathology",
backbone="resnet50",
input_channels=3,
pretrained=True,
aggregator="attention_mil" # 注意力 MIL 聚合
)
# MRI 模态
builder.add_modality(
name="mri",
backbone="resnet3d_18",
input_channels=1,
pretrained=False
)
# 临床数据模态
builder.add_modality(
name="clinical",
backbone="mlp",
input_dim=6,
hidden_dims=[64, 128, 256]
)
# 融合策略
builder.set_fusion(
fusion_type="attention", # 注意力融合
hidden_dim=512
)
# 分类头
builder.set_head("classification", dropout=0.5)
model = builder.build()
print(model)6.2 配置文件
创建 configs/breast_cancer_config.yaml:
yaml
data:
dataset_type: "breast_cancer_multimodal"
data_dir: "data/breast_cancer"
batch_size: 4
num_workers: 4
modalities:
pathology:
enabled: true
max_patches: 50
mri:
enabled: true
clinical:
enabled: true
augmentation:
pathology:
random_flip: true
random_rotation: 90
color_jitter: true
stain_normalization: "macenko"
mri:
random_flip: true
random_noise: 0.01
model:
num_classes: 4
modalities:
pathology:
backbone: "resnet50"
input_channels: 3
pretrained: true
aggregator:
type: "attention_mil"
hidden_dim: 256
mri:
backbone: "resnet3d_18"
input_channels: 1
pretrained: false
clinical:
backbone: "mlp"
input_dim: 6
hidden_dims: [64, 128, 256]
fusion:
type: "attention"
hidden_dim: 512
num_heads: 8
head:
type: "classification"
dropout: 0.5
training:
epochs: 100
optimizer:
type: "adamw"
lr: 0.0001
weight_decay: 0.01
scheduler:
type: "cosine"
T_max: 100
warmup_epochs: 10
loss:
type: "cross_entropy"
label_smoothing: 0.1
mixed_precision: true
early_stopping:
patience: 15
metric: "val_f1_macro"
mode: "max"
logging:
output_dir: "outputs/breast_cancer"
tensorboard: true
wandb:
enabled: true
project: "breast-cancer-classification"7. 训练模型
python
# train_breast_cancer.py(完整版示意)
import torch
from torch.utils.data import DataLoader
from med_core.models import build_model_from_config
from med_core.trainers import MultimodalTrainer
from med_core.configs import load_config
from breast_cancer_dataset import BreastCancerMultiModalDataset
def main():
config = load_config("configs/breast_cancer_config.yaml")
# 创建数据集
train_dataset = BreastCancerMultiModalDataset(
data_dir=config['data']['data_dir'],
split='train',
use_pathology=config['data']['modalities']['pathology']['enabled'],
use_mri=config['data']['modalities']['mri']['enabled'],
use_clinical=config['data']['modalities']['clinical']['enabled']
)
val_dataset = BreastCancerMultiModalDataset(
data_dir=config['data']['data_dir'],
split='val',
use_pathology=config['data']['modalities']['pathology']['enabled'],
use_mri=config['data']['modalities']['mri']['enabled'],
use_clinical=config['data']['modalities']['clinical']['enabled']
)
train_loader = DataLoader(
train_dataset,
batch_size=config['data']['batch_size'],
shuffle=True,
num_workers=config['data']['num_workers']
)
val_loader = DataLoader(
val_dataset,
batch_size=config['data']['batch_size'],
shuffle=False,
num_workers=config['data']['num_workers']
)
# 构建模型
model = build_model_from_config(config['model'])
# 训练
trainer = MultimodalTrainer(
model=model,
config=config,
train_loader=train_loader,
val_loader=val_loader
)
trainer.train()
if __name__ == "__main__":
main()运行训练:
bash
uv run python train_breast_cancer.py8. 融合策略对比
8.1 对比不同融合方法
python
# compare_fusion_strategies.py(自定义实验脚本示意)
fusion_strategies = [
"concatenate", # 简单拼接
"gated", # 门控融合
"attention", # 注意力融合
"cross_attention", # 交叉注意力
"bilinear" # 双线性融合
]
results = {}
for fusion_type in fusion_strategies:
print(f"\n训练 {fusion_type} 融合策略...")
# 修改配置
config['model']['fusion']['type'] = fusion_type
# 构建模型
model = build_model_from_config(config['model'])
# 训练
trainer = MultimodalTrainer(model, config, train_loader, val_loader)
metrics = trainer.train()
results[fusion_type] = metrics
# 可视化对比
import matplotlib.pyplot as plt
fusion_names = list(results.keys())
f1_scores = [results[f]['best_val_f1'] for f in fusion_names]
plt.figure(figsize=(10, 6))
plt.bar(fusion_names, f1_scores)
plt.ylabel('F1 Score (Macro)')
plt.title('不同融合策略性能对比')
plt.xticks(rotation=45)
plt.tight_layout()
plt.savefig('outputs/breast_cancer/fusion_comparison.png', dpi=300)9. 注意力可视化
9.1 MIL 注意力权重
python
# visualize_mil_attention.py(自定义可视化脚本示意)
import torch
import matplotlib.pyplot as plt
import numpy as np
def visualize_mil_attention(model, sample, device='cuda'):
"""可视化 MIL 注意力权重"""
model.eval()
model.to(device)
pathology_patches = sample['pathology'].unsqueeze(0).to(device) # (1, N, 3, H, W)
with torch.no_grad():
# 前向传播并获取注意力权重
output, aux = model({
'pathology': pathology_patches,
'mri': sample['mri'].unsqueeze(0).to(device),
'clinical': sample['clinical'].unsqueeze(0).to(device)
})
attention_weights = aux['pathology_attention'] # (1, N)
# 可视化
attention_weights = attention_weights[0].cpu().numpy()
num_patches = len(attention_weights)
fig, axes = plt.subplots(2, 5, figsize=(20, 8))
axes = axes.flatten()
# 显示注意力权重最高的 10 个 patches
top_indices = np.argsort(attention_weights)[-10:][::-1]
for i, idx in enumerate(top_indices):
patch = pathology_patches[0, idx].cpu().permute(1, 2, 0).numpy()
weight = attention_weights[idx]
axes[i].imshow(patch)
axes[i].set_title(f'Patch {idx}\nAttention: {weight:.4f}')
axes[i].axis('off')
plt.tight_layout()
plt.savefig(f'outputs/breast_cancer/mil_attention_{sample["patient_id"]}.png',
dpi=300, bbox_inches='tight')
# 可视化测试样本
test_dataset = BreastCancerMultiModalDataset(data_dir='data/breast_cancer', split='test')
for i in range(5):
sample = test_dataset[i]
visualize_mil_attention(model, sample)9.2 模态融合注意力
python
# visualize_fusion_attention.py(自定义可视化脚本示意)
def visualize_fusion_attention(model, sample, device='cuda'):
"""可视化模态间注意力"""
model.eval()
model.to(device)
with torch.no_grad():
output, aux = model({
'pathology': sample['pathology'].unsqueeze(0).to(device),
'mri': sample['mri'].unsqueeze(0).to(device),
'clinical': sample['clinical'].unsqueeze(0).to(device)
})
# 获取模态注意力权重
modality_attention = aux['fusion_attention'] # (1, 3)
modality_names = ['病理', 'MRI', '临床']
weights = modality_attention[0].cpu().numpy()
# 绘制条形图
plt.figure(figsize=(8, 6))
colors = ['#FF6B6B', '#4ECDC4', '#45B7D1']
bars = plt.bar(modality_names, weights, color=colors)
plt.ylabel('注意力权重')
plt.title(f'模态融合注意力分布\n患者: {sample["patient_id"]}')
plt.ylim(0, 1)
# 添加数值标签
for bar, weight in zip(bars, weights):
height = bar.get_height()
plt.text(bar.get_x() + bar.get_width()/2., height,
f'{weight:.3f}',
ha='center', va='bottom')
plt.tight_layout()
plt.savefig(f'outputs/breast_cancer/fusion_attention_{sample["patient_id"]}.png',
dpi=300, bbox_inches='tight')10. 预期结果
在 TCGA-BRCA 数据集上,多模态融合模型预期性能:
| 模态组合 | F1 Score | AUC |
|---|---|---|
| 仅病理 | 0.72 | 0.85 |
| 仅 MRI | 0.65 | 0.78 |
| 仅临床 | 0.58 | 0.70 |
| 病理 + 临床 | 0.76 | 0.88 |
| 病理 + MRI | 0.78 | 0.90 |
| 三模态融合 | 0.82 | 0.92 |
融合策略对比:
- Concatenate: F1 = 0.78
- Gated: F1 = 0.80
- Attention: F1 = 0.82 (最佳)
- Cross-Attention: F1 = 0.81
- Bilinear: F1 = 0.79
11. 消融实验
python
# ablation_study.py(自定义实验脚本示意)
ablation_configs = [
{"pathology": True, "mri": False, "clinical": False},
{"pathology": False, "mri": True, "clinical": False},
{"pathology": False, "mri": False, "clinical": True},
{"pathology": True, "mri": True, "clinical": False},
{"pathology": True, "mri": False, "clinical": True},
{"pathology": False, "mri": True, "clinical": True},
{"pathology": True, "mri": True, "clinical": True},
]
results = []
for cfg in ablation_configs:
modalities = [k for k, v in cfg.items() if v]
print(f"\n训练模态: {', '.join(modalities)}")
# 创建数据集
train_dataset = BreastCancerMultiModalDataset(
data_dir='data/breast_cancer',
split='train',
use_pathology=cfg['pathology'],
use_mri=cfg['mri'],
use_clinical=cfg['clinical']
)
# 训练并记录结果
# ... (训练代码)
results.append({
'modalities': '+'.join(modalities),
'f1': metrics['best_val_f1'],
'auc': metrics['best_val_auc']
})
# 可视化
import pandas as pd
df = pd.DataFrame(results)
print(df)12. 常见问题
Q: WSI 文件太大,内存不足?
- 使用 patch 级别处理,不要一次加载整个 WSI
- 减少 max_patches 数量
- 使用更小的 patch_size(如 224×224)
Q: 如何处理缺失模态?
- 使用零填充或平均特征填充
- 训练时随机丢弃模态(模态 dropout)
- 使用模态特定的缺失指示符
Q: 不同模态特征尺度差异大?
- 对每个模态分别归一化
- 使用 LayerNorm 或 BatchNorm
- 调整融合层的权重初始化
13. 总结
本案例展示了多模态医学数据融合的完整流程:
- 病理 WSI 的 patch 提取和 MIL 聚合
- MRI 3D 体积数据处理
- 临床表格数据整合
- 多种融合策略对比
- 注意力机制可视化
关键收获:
- 多模态融合显著提升性能(+10% F1)
- 注意力融合优于简单拼接
- MIL 聚合有效处理 WSI 数据
- 模态间互补信息很重要
下一步: 尝试 案例 3:生存预测,学习时间序列分析和生存模型。