Skip to content

模型导出指南

文档状态:Stable

本指南介绍如何将 MedFusion 模型导出为 ONNX 和 TorchScript 格式,以便在不同平台和环境中部署。

前置条件

  • 模型权重可正常加载
  • 能完成一次基础推理(导出前先确认模型本身可用)
  • 已安装 onnxruntime(如需验证 ONNX 推理)

概述

MedFusion 支持两种主要的模型导出格式:

  1. ONNX (Open Neural Network Exchange): 跨平台、跨框架的模型格式
  2. TorchScript: PyTorch 的序列化格式,保持与 PyTorch 生态的兼容性

快速开始

基本导出

python
from med_core.utils.export import export_model
import torch.nn as nn

# 创建模型
model = MyModel()

# 导出为 ONNX
export_model(
    model=model,
    output_path="model.onnx",
    input_shape=(3, 224, 224),
    format="onnx",
)

# 导出为 TorchScript
export_model(
    model=model,
    output_path="model.pt",
    input_shape=(3, 224, 224),
    format="torchscript",
)

✅ 预期输出

完成导出后,至少应满足:

  • 生成 model.onnx / model.pt 文件
  • ONNX 模型可被 onnxruntime 成功加载并推理
  • TorchScript 模型可被 torch.jit.load 成功加载并推理

ONNX 导出

基本用法

python
from med_core.utils.export import ModelExporter

# 创建导出器
exporter = ModelExporter(
    model=model,
    input_shape=(3, 224, 224),
    device="cpu",
)

# 导出为 ONNX
exporter.export_onnx(
    output_path="model.onnx",
    opset_version=11,
    input_names=["image"],
    output_names=["logits"],
)

# 验证导出的模型
exporter.verify_onnx("model.onnx")

动态轴

支持动态 batch size 和图像尺寸:

python
exporter.export_onnx(
    output_path="model_dynamic.onnx",
    input_names=["image"],
    output_names=["logits"],
    dynamic_axes={
        "image": {
            0: "batch_size",  # 动态 batch
            2: "height",      # 动态高度
            3: "width",       # 动态宽度
        },
        "logits": {0: "batch_size"},
    },
)

ONNX 推理

python
import onnxruntime as ort
import numpy as np

# 加载 ONNX 模型
session = ort.InferenceSession("model.onnx")

# 准备输入
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
input_name = session.get_inputs()[0].name

# 推理
outputs = session.run(None, {input_name: input_data})
print(f"Output shape: {outputs[0].shape}")

ONNX 优势

  • ✅ 跨平台部署(Windows、Linux、macOS、移动端)
  • ✅ 支持多种推理引擎(ONNX Runtime、TensorRT、OpenVINO)
  • ✅ 硬件加速(CPU、GPU、NPU)
  • ✅ 模型优化和量化
  • ✅ 与其他框架互操作

TorchScript 导出

Trace 方法

适用于大多数模型:

python
exporter.export_torchscript(
    output_path="model_trace.pt",
    method="trace",
    optimize=True,
)

Script 方法

适用于包含动态控制流的模型:

python
exporter.export_torchscript(
    output_path="model_script.pt",
    method="script",
    optimize=True,
)

TorchScript 推理

python
import torch

# 加载 TorchScript 模型
model = torch.jit.load("model.pt")
model.eval()

# 推理
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    output = model(x)
print(f"Output shape: {output.shape}")

TorchScript 优势

  • ✅ 完全兼容 PyTorch 生态
  • ✅ 保留模型结构和参数
  • ✅ 支持 C++ 部署
  • ✅ 性能优化
  • ✅ 易于调试

多模态模型导出

导出多输入模型

python
from med_core.utils.export import MultiModalExporter

# 创建多模态导出器
exporter = MultiModalExporter(
    model=multimodal_model,
    input_shapes={
        "image": (3, 224, 224),
        "tabular": (10,),
    },
    device="cpu",
)

# 导出为 ONNX
exporter.export_onnx(
    output_path="multimodal_model.onnx",
    input_names=["image", "tabular"],
    output_names=["logits"],
)

# 导出为 TorchScript
exporter.export_torchscript(
    output_path="multimodal_model.pt",
    method="trace",
)

多模态推理

python
# ONNX 推理
import onnxruntime as ort

session = ort.InferenceSession("multimodal_model.onnx")
outputs = session.run(
    None,
    {
        "image": image_data,
        "tabular": tabular_data,
    }
)

# TorchScript 推理
model = torch.jit.load("multimodal_model.pt")
output = model(image_tensor, tabular_tensor)

常见失败与排查

  • 导出报 shape mismatch:检查 input_shape 与模型实际输入一致
  • ONNX 验证失败:先用最小 batch(如 1)验证,再加动态轴
  • TorchScript script 失败:先尝试 trace,再处理控制流兼容

完整示例

示例 1: 分类模型导出

python
import torch
import torch.nn as nn
from med_core.utils.export import ModelExporter

# 定义模型
class Classifier(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.classifier = nn.Linear(64, num_classes)
    
    def forward(self, x):
        features = self.backbone(x)
        logits = self.classifier(features)
        return logits

# 创建并训练模型
model = Classifier(num_classes=10)
# ... 训练代码 ...

# 导出模型
model.eval()
exporter = ModelExporter(model, input_shape=(3, 224, 224))

# 导出为 ONNX
exporter.export_onnx(
    "classifier.onnx",
    input_names=["image"],
    output_names=["logits"],
    dynamic_axes={
        "image": {0: "batch_size"},
        "logits": {0: "batch_size"},
    },
)

# 验证
exporter.verify_onnx("classifier.onnx")

# 导出为 TorchScript
exporter.export_torchscript("classifier.pt", method="trace")
exporter.verify_torchscript("classifier.pt")

示例 2: 多模态模型导出

python
from med_core.utils.export import MultiModalExporter

# 定义多模态模型
class MultiModalClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.image_encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.tabular_encoder = nn.Sequential(
            nn.Linear(10, 64),
            nn.ReLU(),
        )
        self.fusion = nn.Linear(128, 10)
    
    def forward(self, image, tabular):
        image_feat = self.image_encoder(image)
        tabular_feat = self.tabular_encoder(tabular)
        fused = torch.cat([image_feat, tabular_feat], dim=1)
        output = self.fusion(fused)
        return output

# 创建并训练模型
model = MultiModalClassifier()
# ... 训练代码 ...

# 导出模型
model.eval()
exporter = MultiModalExporter(
    model,
    input_shapes={
        "image": (3, 224, 224),
        "tabular": (10,),
    },
)

# 导出为 ONNX
exporter.export_onnx(
    "multimodal_classifier.onnx",
    input_names=["image", "tabular"],
    output_names=["logits"],
)

# 导出为 TorchScript
exporter.export_torchscript("multimodal_classifier.pt")

示例 3: 生产环境部署

python
# 1. 导出优化的模型
exporter = ModelExporter(model, input_shape=(3, 224, 224))

# ONNX (用于跨平台部署)
exporter.export_onnx(
    "model_production.onnx",
    opset_version=11,
    input_names=["image"],
    output_names=["logits"],
    dynamic_axes={
        "image": {0: "batch_size"},
        "logits": {0: "batch_size"},
    },
)

# TorchScript (用于 PyTorch 服务)
exporter.export_torchscript(
    "model_production.pt",
    method="trace",
    optimize=True,
)

# 2. 验证模型
assert exporter.verify_onnx("model_production.onnx")
assert exporter.verify_torchscript("model_production.pt")

# 3. 部署
# - ONNX: 使用 ONNX Runtime、TensorRT 等
# - TorchScript: 使用 TorchServe、自定义服务等

最佳实践

1. 导出前的准备

python
# 设置为评估模式
model.eval()

# 移除训练相关的操作
for module in model.modules():
    if isinstance(module, nn.Dropout):
        module.p = 0.0
    if isinstance(module, nn.BatchNorm2d):
        module.track_running_stats = False

# 测试模型
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    output = model(x)
print(f"Output shape: {output.shape}")

2. 选择合适的格式

场景推荐格式原因
跨平台部署ONNX支持多种推理引擎
PyTorch 生态TorchScript完全兼容
移动端部署ONNX轻量级
边缘设备ONNX硬件加速
C++ 部署TorchScript易于集成

3. 优化建议

python
# ONNX 优化
exporter.export_onnx(
    "model.onnx",
    opset_version=11,  # 使用较新的 opset
    do_constant_folding=True,  # 常量折叠
)

# TorchScript 优化
exporter.export_torchscript(
    "model.pt",
    method="trace",
    optimize=True,  # 启用优化
)

# 量化(进一步优化)
# 参见 model_compression.md

4. 验证流程

python
# 1. 导出模型
exporter.export_onnx("model.onnx")

# 2. 验证输出一致性
assert exporter.verify_onnx("model.onnx")

# 3. 测试不同输入
test_inputs = [
    torch.randn(1, 3, 224, 224),
    torch.randn(2, 3, 224, 224),
    torch.randn(4, 3, 224, 224),
]

for x in test_inputs:
    # PyTorch
    with torch.no_grad():
        pytorch_out = model(x)
    
    # ONNX
    import onnxruntime as ort
    session = ort.InferenceSession("model.onnx")
    onnx_out = session.run(None, {"image": x.numpy()})[0]
    
    # 比较
    assert np.allclose(pytorch_out.numpy(), onnx_out, rtol=1e-3)

# 4. 性能测试
import time

# PyTorch
start = time.time()
for _ in range(100):
    with torch.no_grad():
        _ = model(x)
pytorch_time = time.time() - start

# ONNX
start = time.time()
for _ in range(100):
    _ = session.run(None, {"image": x.numpy()})
onnx_time = time.time() - start

print(f"PyTorch: {pytorch_time:.3f}s")
print(f"ONNX: {onnx_time:.3f}s")
print(f"Speedup: {pytorch_time / onnx_time:.2f}x")

常见问题

Q1: 导出失败怎么办?

A: 检查以下几点:

  1. 模型是否设置为评估模式 (model.eval())
  2. 是否有不支持的操作
  3. 是否有动态控制流(使用 script 而不是 trace)
  4. PyTorch 和 ONNX 版本是否兼容

Q2: 如何处理自定义算子?

A:

python
# 方法 1: 使用 TorchScript
exporter.export_torchscript("model.pt", method="script")

# 方法 2: 注册 ONNX 算子
from torch.onnx import register_custom_op_symbolic

@register_custom_op_symbolic("custom::my_op", opset_version=11)
def my_op_symbolic(g, input):
    return g.op("custom::MyOp", input)

Q3: 如何优化推理性能?

A:

  1. 使用 ONNX Runtime 的优化选项
  2. 启用硬件加速(GPU、TensorRT)
  3. 使用量化和剪枝
  4. 批处理推理

Q4: 导出的模型太大怎么办?

A:

  1. 使用模型压缩技术(量化、剪枝)
  2. 移除不必要的层
  3. 使用更小的骨干网络
  4. 压缩模型文件(gzip)

Q5: 如何处理版本兼容性?

A:

python
import torch
import onnx

print(f"PyTorch version: {torch.__version__}")
print(f"ONNX version: {onnx.__version__}")

# 使用兼容的 opset 版本
# PyTorch 1.9+: opset 11-13
# PyTorch 1.10+: opset 11-14
exporter.export_onnx("model.onnx", opset_version=11)

性能对比

推理速度

格式CPU (ms)GPU (ms)说明
PyTorch10.52.3基准
TorchScript9.82.11.07x 加速
ONNX Runtime8.21.81.28x 加速
TensorRT-1.21.92x 加速

模型大小

格式大小 (MB)压缩后 (MB)
PyTorch (.pth)102.425.6
TorchScript (.pt)102.825.8
ONNX (.onnx)98.224.5

参考资源

官方文档

代码

  • med_core/utils/export.py - 导出工具实现
  • examples/model_export_demo.py - 使用示例
  • tests/test_export.py - 单元测试

相关指南

更新日志

v0.2.0 (2026-02-20)

  • ✨ 新增 ModelExporter 类
  • ✨ 新增 MultiModalExporter 类
  • ✨ 支持 ONNX 导出
  • ✨ 支持 TorchScript 导出
  • ✨ 支持动态轴
  • ✨ 支持模型验证
  • ✨ 新增便捷函数 export_model
  • 📝 完善文档和示例
  • ✅ 添加完整的单元测试

Released under the MIT License.