Skip to content

案例研究 2:乳腺癌分类

预计时间:75 分钟

1. 医学背景

乳腺癌是全球女性最常见的恶性肿瘤,早期诊断和精准分型对治疗方案选择至关重要。现代诊断通常结合多种模态:

  • 病理切片(WSI):显微镜下的组织形态学特征
  • 影像学(MRI/超声):肿瘤的空间分布和血流特征
  • 临床数据:年龄、家族史、激素受体状态等

分子分型:

  • Luminal A(ER+/PR+/HER2-):预后较好
  • Luminal B(ER+/PR+/HER2+):中等预后
  • HER2 富集型(ER-/PR-/HER2+):需靶向治疗
  • 三阴性(ER-/PR-/HER2-):预后较差

2. 数据集介绍

本案例使用 TCGA-BRCA(The Cancer Genome Atlas - Breast Cancer)数据集的子集,包含:

  • 病理切片:H&E 染色的全切片图像(WSI),分辨率 40× 或 20×
  • MRI 影像:T1/T2 加权序列,动态对比增强(DCE-MRI)
  • 临床数据:年龄、TNM 分期、分子分型、生存信息

数据规模:

  • 1,100+ 患者
  • 每个患者 1-5 张 WSI(平均 2.3 张)
  • MRI 序列:3D 体积数据(256×256×N 切片)

3. 环境准备

bash
# 安装依赖(包含 WSI 处理库)
uv sync --extra pathology

# 生成合成数据
uv run python scripts/generate_synthetic_breast_data.py \
    --output-dir data/breast_cancer \
    --num-samples 500

4. 数据预处理

4.1 病理切片预处理

WSI 文件通常非常大(10-50GB),需要分块处理:

python
# scripts/preprocess_wsi.py
import openslide
import numpy as np
from pathlib import Path
from PIL import Image
import torch
from torchvision import transforms

class WSIPreprocessor:
    """全切片图像预处理器"""

    def __init__(
        self,
        patch_size: int = 256,
        level: int = 0,
        tissue_threshold: float = 0.5
    ):
        self.patch_size = patch_size
        self.level = level
        self.tissue_threshold = tissue_threshold

    def extract_patches(self, wsi_path: Path, output_dir: Path):
        """从 WSI 提取组织区域的 patches"""
        slide = openslide.OpenSlide(str(wsi_path))

        # 获取缩略图用于组织检测
        thumbnail = slide.get_thumbnail((2000, 2000))
        tissue_mask = self._detect_tissue(thumbnail)

        # 计算 patch 坐标
        patch_coords = self._get_patch_coordinates(
            slide, tissue_mask, self.level
        )

        # 提取并保存 patches
        output_dir.mkdir(parents=True, exist_ok=True)
        patches = []

        for i, (x, y) in enumerate(patch_coords):
            patch = slide.read_region(
                (x, y), self.level,
                (self.patch_size, self.patch_size)
            ).convert('RGB')

            patch_path = output_dir / f"patch_{i:04d}.png"
            patch.save(patch_path)
            patches.append({
                'path': patch_path,
                'coords': (x, y),
                'level': self.level
            })

        slide.close()
        return patches

    def _detect_tissue(self, thumbnail: Image.Image) -> np.ndarray:
        """检测组织区域(排除背景)"""
        img_array = np.array(thumbnail)

        # 转换到 HSV 空间
        from skimage import color
        hsv = color.rgb2hsv(img_array)

        # 基于饱和度阈值分割
        saturation = hsv[:, :, 1]
        tissue_mask = saturation > 0.1

        # 形态学操作去噪
        from skimage.morphology import remove_small_objects, binary_closing
        tissue_mask = remove_small_objects(tissue_mask, min_size=500)
        tissue_mask = binary_closing(tissue_mask, footprint=np.ones((5, 5)))

        return tissue_mask

    def _get_patch_coordinates(
        self,
        slide: openslide.OpenSlide,
        tissue_mask: np.ndarray,
        level: int
    ) -> list[tuple[int, int]]:
        """计算需要提取的 patch 坐标"""
        level_dims = slide.level_dimensions[level]
        downsample = slide.level_downsamples[level]

        # 调整 mask 尺寸
        from skimage.transform import resize
        mask_resized = resize(
            tissue_mask,
            (level_dims[1] // self.patch_size, level_dims[0] // self.patch_size),
            order=0
        )

        # 找到组织区域的 patch
        coords = []
        for i in range(mask_resized.shape[0]):
            for j in range(mask_resized.shape[1]):
                if mask_resized[i, j] > self.tissue_threshold:
                    x = int(j * self.patch_size * downsample)
                    y = int(i * self.patch_size * downsample)
                    coords.append((x, y))

        return coords

# 批量处理
def preprocess_all_wsi(input_dir: Path, output_dir: Path):
    preprocessor = WSIPreprocessor(patch_size=256, level=0)

    for wsi_path in input_dir.glob("*.svs"):
        print(f"Processing {wsi_path.name}...")
        patient_id = wsi_path.stem

        patches = preprocessor.extract_patches(
            wsi_path,
            output_dir / patient_id
        )

        # 保存 patch 信息
        import json
        with open(output_dir / patient_id / "patches.json", 'w') as f:
            json.dump(patches, f, indent=2, default=str)

if __name__ == "__main__":
    preprocess_all_wsi(
        input_dir=Path("data/breast_cancer/wsi_raw"),
        output_dir=Path("data/breast_cancer/wsi_patches")
    )

4.2 MRI 预处理

python
# scripts/preprocess_breast_mri.py
import numpy as np
import nibabel as nib
from pathlib import Path
from med_core.preprocessing import MRIPreprocessor

def preprocess_breast_mri(input_dir: Path, output_dir: Path):
    """预处理乳腺 MRI 数据"""
    preprocessor = MRIPreprocessor(
        target_spacing=(1.0, 1.0, 3.0),  # 各向异性重采样
        normalize_method='z_score',
        crop_to_roi=True
    )

    output_dir.mkdir(parents=True, exist_ok=True)

    for patient_dir in input_dir.glob("*"):
        if not patient_dir.is_dir():
            continue

        print(f"Processing {patient_dir.name}...")

        # 加载 T1 和 T2 序列
        t1_path = patient_dir / "T1.nii.gz"
        t2_path = patient_dir / "T2.nii.gz"

        if t1_path.exists():
            t1_volume = nib.load(t1_path).get_fdata()
            t1_processed = preprocessor(t1_volume)
            nib.save(
                nib.Nifti1Image(t1_processed, np.eye(4)),
                output_dir / f"{patient_dir.name}_T1.nii.gz"
            )

        if t2_path.exists():
            t2_volume = nib.load(t2_path).get_fdata()
            t2_processed = preprocessor(t2_volume)
            nib.save(
                nib.Nifti1Image(t2_processed, np.eye(4)),
                output_dir / f"{patient_dir.name}_T2.nii.gz"
            )

if __name__ == "__main__":
    preprocess_breast_mri(
        input_dir=Path("data/breast_cancer/mri_raw"),
        output_dir=Path("data/breast_cancer/mri_processed")
    )

5. 多模态数据集

python
# examples/breast_cancer_dataset.py
import torch
from torch.utils.data import Dataset
import pandas as pd
from pathlib import Path
import json
from PIL import Image
import nibabel as nib
import numpy as np

class BreastCancerMultiModalDataset(Dataset):
    """乳腺癌多模态数据集"""

    def __init__(
        self,
        data_dir: Path,
        split: str = "train",
        use_pathology: bool = True,
        use_mri: bool = True,
        use_clinical: bool = True,
        max_patches: int = 50
    ):
        self.data_dir = Path(data_dir)
        self.use_pathology = use_pathology
        self.use_mri = use_mri
        self.use_clinical = use_clinical
        self.max_patches = max_patches

        # 加载临床数据
        self.clinical_df = pd.read_csv(self.data_dir / f"{split}_clinical.csv")
        self.patient_ids = self.clinical_df['patient_id'].tolist()

        # 特征列
        self.clinical_features = [
            'age', 'tumor_size', 'lymph_nodes',
            'er_status', 'pr_status', 'her2_status'
        ]

    def __len__(self):
        return len(self.patient_ids)

    def __getitem__(self, idx):
        patient_id = self.patient_ids[idx]
        sample = {'patient_id': patient_id}

        # 1. 病理数据(多个 patches)
        if self.use_pathology:
            patches = self._load_pathology_patches(patient_id)
            sample['pathology'] = patches

        # 2. MRI 数据
        if self.use_mri:
            mri_volume = self._load_mri(patient_id)
            sample['mri'] = mri_volume

        # 3. 临床数据
        if self.use_clinical:
            clinical_features = self._load_clinical(patient_id)
            sample['clinical'] = clinical_features

        # 标签(分子分型)
        label = self.clinical_df.loc[
            self.clinical_df['patient_id'] == patient_id,
            'molecular_subtype'
        ].values[0]
        sample['label'] = torch.tensor(label, dtype=torch.long)

        return sample

    def _load_pathology_patches(self, patient_id: str) -> torch.Tensor:
        """加载病理 patches"""
        patch_dir = self.data_dir / "wsi_patches" / patient_id

        # 加载 patch 信息
        with open(patch_dir / "patches.json") as f:
            patch_info = json.load(f)

        # 随机采样 patches
        import random
        if len(patch_info) > self.max_patches:
            patch_info = random.sample(patch_info, self.max_patches)

        patches = []
        for info in patch_info:
            img = Image.open(info['path']).convert('RGB')
            img_tensor = torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.0
            patches.append(img_tensor)

        # 填充到固定数量
        while len(patches) < self.max_patches:
            patches.append(torch.zeros_like(patches[0]))

        return torch.stack(patches)  # (N, 3, H, W)

    def _load_mri(self, patient_id: str) -> torch.Tensor:
        """加载 MRI 数据"""
        mri_path = self.data_dir / "mri_processed" / f"{patient_id}_T1.nii.gz"

        if not mri_path.exists():
            # 返回零张量
            return torch.zeros(1, 64, 64, 32)

        volume = nib.load(mri_path).get_fdata()

        # 中心裁剪
        from med_core.preprocessing import center_crop_3d
        volume = center_crop_3d(volume, (64, 64, 32))

        return torch.from_numpy(volume).unsqueeze(0).float()  # (1, D, H, W)

    def _load_clinical(self, patient_id: str) -> torch.Tensor:
        """加载临床特征"""
        row = self.clinical_df[self.clinical_df['patient_id'] == patient_id]
        features = row[self.clinical_features].values[0]
        return torch.from_numpy(features).float()

6. 多模态融合模型

6.1 使用 MedFusion 构建

python
# train_breast_cancer.py(自定义训练脚本示意)
from med_core.models import MultiModalModelBuilder

# 构建多模态模型
builder = MultiModalModelBuilder(num_classes=4)  # 4 种分子分型

# 病理模态(使用 MIL 聚合多个 patches)
builder.add_modality(
    name="pathology",
    backbone="resnet50",
    input_channels=3,
    pretrained=True,
    aggregator="attention_mil"  # 注意力 MIL 聚合
)

# MRI 模态
builder.add_modality(
    name="mri",
    backbone="resnet3d_18",
    input_channels=1,
    pretrained=False
)

# 临床数据模态
builder.add_modality(
    name="clinical",
    backbone="mlp",
    input_dim=6,
    hidden_dims=[64, 128, 256]
)

# 融合策略
builder.set_fusion(
    fusion_type="attention",  # 注意力融合
    hidden_dim=512
)

# 分类头
builder.set_head("classification", dropout=0.5)

model = builder.build()
print(model)

6.2 配置文件

创建 configs/breast_cancer_config.yaml

yaml
data:
  dataset_type: "breast_cancer_multimodal"
  data_dir: "data/breast_cancer"
  batch_size: 4
  num_workers: 4

  modalities:
    pathology:
      enabled: true
      max_patches: 50
    mri:
      enabled: true
    clinical:
      enabled: true

  augmentation:
    pathology:
      random_flip: true
      random_rotation: 90
      color_jitter: true
      stain_normalization: "macenko"
    mri:
      random_flip: true
      random_noise: 0.01

model:
  num_classes: 4

  modalities:
    pathology:
      backbone: "resnet50"
      input_channels: 3
      pretrained: true
      aggregator:
        type: "attention_mil"
        hidden_dim: 256

    mri:
      backbone: "resnet3d_18"
      input_channels: 1
      pretrained: false

    clinical:
      backbone: "mlp"
      input_dim: 6
      hidden_dims: [64, 128, 256]

  fusion:
    type: "attention"
    hidden_dim: 512
    num_heads: 8

  head:
    type: "classification"
    dropout: 0.5

training:
  epochs: 100
  optimizer:
    type: "adamw"
    lr: 0.0001
    weight_decay: 0.01

  scheduler:
    type: "cosine"
    T_max: 100
    warmup_epochs: 10

  loss:
    type: "cross_entropy"
    label_smoothing: 0.1

  mixed_precision: true

  early_stopping:
    patience: 15
    metric: "val_f1_macro"
    mode: "max"

logging:
  output_dir: "outputs/breast_cancer"
  tensorboard: true
  wandb:
    enabled: true
    project: "breast-cancer-classification"

7. 训练模型

python
# train_breast_cancer.py(完整版示意)
import torch
from torch.utils.data import DataLoader
from med_core.models import build_model_from_config
from med_core.trainers import MultimodalTrainer
from med_core.configs import load_config
from breast_cancer_dataset import BreastCancerMultiModalDataset

def main():
    config = load_config("configs/breast_cancer_config.yaml")

    # 创建数据集
    train_dataset = BreastCancerMultiModalDataset(
        data_dir=config['data']['data_dir'],
        split='train',
        use_pathology=config['data']['modalities']['pathology']['enabled'],
        use_mri=config['data']['modalities']['mri']['enabled'],
        use_clinical=config['data']['modalities']['clinical']['enabled']
    )

    val_dataset = BreastCancerMultiModalDataset(
        data_dir=config['data']['data_dir'],
        split='val',
        use_pathology=config['data']['modalities']['pathology']['enabled'],
        use_mri=config['data']['modalities']['mri']['enabled'],
        use_clinical=config['data']['modalities']['clinical']['enabled']
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=config['data']['batch_size'],
        shuffle=True,
        num_workers=config['data']['num_workers']
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config['data']['batch_size'],
        shuffle=False,
        num_workers=config['data']['num_workers']
    )

    # 构建模型
    model = build_model_from_config(config['model'])

    # 训练
    trainer = MultimodalTrainer(
        model=model,
        config=config,
        train_loader=train_loader,
        val_loader=val_loader
    )

    trainer.train()

if __name__ == "__main__":
    main()

运行训练:

bash
uv run python train_breast_cancer.py

8. 融合策略对比

8.1 对比不同融合方法

python
# compare_fusion_strategies.py(自定义实验脚本示意)
fusion_strategies = [
    "concatenate",  # 简单拼接
    "gated",        # 门控融合
    "attention",    # 注意力融合
    "cross_attention",  # 交叉注意力
    "bilinear"      # 双线性融合
]

results = {}

for fusion_type in fusion_strategies:
    print(f"\n训练 {fusion_type} 融合策略...")

    # 修改配置
    config['model']['fusion']['type'] = fusion_type

    # 构建模型
    model = build_model_from_config(config['model'])

    # 训练
    trainer = MultimodalTrainer(model, config, train_loader, val_loader)
    metrics = trainer.train()

    results[fusion_type] = metrics

# 可视化对比
import matplotlib.pyplot as plt

fusion_names = list(results.keys())
f1_scores = [results[f]['best_val_f1'] for f in fusion_names]

plt.figure(figsize=(10, 6))
plt.bar(fusion_names, f1_scores)
plt.ylabel('F1 Score (Macro)')
plt.title('不同融合策略性能对比')
plt.xticks(rotation=45)
plt.tight_layout()
plt.savefig('outputs/breast_cancer/fusion_comparison.png', dpi=300)

9. 注意力可视化

9.1 MIL 注意力权重

python
# visualize_mil_attention.py(自定义可视化脚本示意)
import torch
import matplotlib.pyplot as plt
import numpy as np

def visualize_mil_attention(model, sample, device='cuda'):
    """可视化 MIL 注意力权重"""
    model.eval()
    model.to(device)

    pathology_patches = sample['pathology'].unsqueeze(0).to(device)  # (1, N, 3, H, W)

    with torch.no_grad():
        # 前向传播并获取注意力权重
        output, aux = model({
            'pathology': pathology_patches,
            'mri': sample['mri'].unsqueeze(0).to(device),
            'clinical': sample['clinical'].unsqueeze(0).to(device)
        })

        attention_weights = aux['pathology_attention']  # (1, N)

    # 可视化
    attention_weights = attention_weights[0].cpu().numpy()
    num_patches = len(attention_weights)

    fig, axes = plt.subplots(2, 5, figsize=(20, 8))
    axes = axes.flatten()

    # 显示注意力权重最高的 10 个 patches
    top_indices = np.argsort(attention_weights)[-10:][::-1]

    for i, idx in enumerate(top_indices):
        patch = pathology_patches[0, idx].cpu().permute(1, 2, 0).numpy()
        weight = attention_weights[idx]

        axes[i].imshow(patch)
        axes[i].set_title(f'Patch {idx}\nAttention: {weight:.4f}')
        axes[i].axis('off')

    plt.tight_layout()
    plt.savefig(f'outputs/breast_cancer/mil_attention_{sample["patient_id"]}.png',
                dpi=300, bbox_inches='tight')

# 可视化测试样本
test_dataset = BreastCancerMultiModalDataset(data_dir='data/breast_cancer', split='test')
for i in range(5):
    sample = test_dataset[i]
    visualize_mil_attention(model, sample)

9.2 模态融合注意力

python
# visualize_fusion_attention.py(自定义可视化脚本示意)
def visualize_fusion_attention(model, sample, device='cuda'):
    """可视化模态间注意力"""
    model.eval()
    model.to(device)

    with torch.no_grad():
        output, aux = model({
            'pathology': sample['pathology'].unsqueeze(0).to(device),
            'mri': sample['mri'].unsqueeze(0).to(device),
            'clinical': sample['clinical'].unsqueeze(0).to(device)
        })

        # 获取模态注意力权重
        modality_attention = aux['fusion_attention']  # (1, 3)

    modality_names = ['病理', 'MRI', '临床']
    weights = modality_attention[0].cpu().numpy()

    # 绘制条形图
    plt.figure(figsize=(8, 6))
    colors = ['#FF6B6B', '#4ECDC4', '#45B7D1']
    bars = plt.bar(modality_names, weights, color=colors)

    plt.ylabel('注意力权重')
    plt.title(f'模态融合注意力分布\n患者: {sample["patient_id"]}')
    plt.ylim(0, 1)

    # 添加数值标签
    for bar, weight in zip(bars, weights):
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height,
                f'{weight:.3f}',
                ha='center', va='bottom')

    plt.tight_layout()
    plt.savefig(f'outputs/breast_cancer/fusion_attention_{sample["patient_id"]}.png',
                dpi=300, bbox_inches='tight')

10. 预期结果

在 TCGA-BRCA 数据集上,多模态融合模型预期性能:

模态组合F1 ScoreAUC
仅病理0.720.85
仅 MRI0.650.78
仅临床0.580.70
病理 + 临床0.760.88
病理 + MRI0.780.90
三模态融合0.820.92

融合策略对比:

  • Concatenate: F1 = 0.78
  • Gated: F1 = 0.80
  • Attention: F1 = 0.82 (最佳)
  • Cross-Attention: F1 = 0.81
  • Bilinear: F1 = 0.79

11. 消融实验

python
# ablation_study.py(自定义实验脚本示意)
ablation_configs = [
    {"pathology": True, "mri": False, "clinical": False},
    {"pathology": False, "mri": True, "clinical": False},
    {"pathology": False, "mri": False, "clinical": True},
    {"pathology": True, "mri": True, "clinical": False},
    {"pathology": True, "mri": False, "clinical": True},
    {"pathology": False, "mri": True, "clinical": True},
    {"pathology": True, "mri": True, "clinical": True},
]

results = []

for cfg in ablation_configs:
    modalities = [k for k, v in cfg.items() if v]
    print(f"\n训练模态: {', '.join(modalities)}")

    # 创建数据集
    train_dataset = BreastCancerMultiModalDataset(
        data_dir='data/breast_cancer',
        split='train',
        use_pathology=cfg['pathology'],
        use_mri=cfg['mri'],
        use_clinical=cfg['clinical']
    )

    # 训练并记录结果
    # ... (训练代码)

    results.append({
        'modalities': '+'.join(modalities),
        'f1': metrics['best_val_f1'],
        'auc': metrics['best_val_auc']
    })

# 可视化
import pandas as pd
df = pd.DataFrame(results)
print(df)

12. 常见问题

Q: WSI 文件太大,内存不足?

  • 使用 patch 级别处理,不要一次加载整个 WSI
  • 减少 max_patches 数量
  • 使用更小的 patch_size(如 224×224)

Q: 如何处理缺失模态?

  • 使用零填充或平均特征填充
  • 训练时随机丢弃模态(模态 dropout)
  • 使用模态特定的缺失指示符

Q: 不同模态特征尺度差异大?

  • 对每个模态分别归一化
  • 使用 LayerNorm 或 BatchNorm
  • 调整融合层的权重初始化

13. 总结

本案例展示了多模态医学数据融合的完整流程:

  1. 病理 WSI 的 patch 提取和 MIL 聚合
  2. MRI 3D 体积数据处理
  3. 临床表格数据整合
  4. 多种融合策略对比
  5. 注意力机制可视化

关键收获:

  • 多模态融合显著提升性能(+10% F1)
  • 注意力融合优于简单拼接
  • MIL 聚合有效处理 WSI 数据
  • 模态间互补信息很重要

下一步: 尝试 案例 3:生存预测,学习时间序列分析和生存模型。

Released under the MIT License.