理解基于 TorchDynamo 的 ONNX 导出器内存使用情况¶
先前的基于 TorchScript 的 ONNX 导出器会执行一次模型以跟踪其执行过程,如果模型的内存需求超过可用的 GPU 内存,则可能导致 GPU 内存耗尽。新的基于 TorchDynamo 的 ONNX 导出器已解决此问题。
基于 TorchDynamo 的 ONNX 导出器利用 FakeTensorMode 来避免在导出过程中执行实际的张量计算。与基于 TorchScript 的 ONNX 导出器相比,这种方法可显著降低内存使用量。
以下示例演示了基于 TorchScript 和基于 TorchDynamo 的 ONNX 导出器之间的内存使用差异。在此示例中,我们使用了 MONAI 中的 HighResNet 模型。在继续之前,请从 PyPI 安装它
pip install monai
PyTorch 提供了一个用于捕获和可视化内存使用情况跟踪的工具。我们将使用此工具记录两个导出器在导出过程中的内存使用情况,并比较结果。您可以在 理解 CUDA 内存使用情况 上找到有关此工具的更多详细信息。
基于 TorchScript 的导出器¶
可以运行以下代码来生成快照文件,该文件记录导出过程中分配的 CUDA 内存的状态。
import torch
from torch.onnx.utils import export
from monai.networks.nets import (
HighResNet,
)
torch.cuda.memory._record_memory_history()
model = HighResNet(
spatial_dims=3, in_channels=1, out_channels=3, norm_type="batch"
).eval()
model = model.to("cuda")
data = torch.randn(30, 1, 48, 48, 48, dtype=torch.float32).to("cuda")
with torch.no_grad():
export(
model,
data,
"torchscript_exporter_highresnet.onnx",
)
snapshot_name = f"torchscript_exporter_example.pickle"
print(f"generate {snapshot_name}")
torch.cuda.memory._dump_snapshot(snapshot_name)
print(f"Export is done.")
打开 pytorch.org/memory_viz 并将生成的 pickle 快照文件拖放到可视化工具中。内存使用情况描述如下

通过此图,我们可以看到内存使用峰值高于 2.8GB。
基于 TorchDynamo 的导出器¶
可以运行以下代码来生成快照文件,该文件记录导出过程中分配的 CUDA 内存的状态。
import torch
from monai.networks.nets import (
HighResNet,
)
torch.cuda.memory._record_memory_history()
model = HighResNet(
spatial_dims=3, in_channels=1, out_channels=3, norm_type="batch"
).eval()
model = model.to("cuda")
data = torch.randn(30, 1, 48, 48, 48, dtype=torch.float32).to("cuda")
with torch.no_grad():
onnx_program = torch.onnx.export(
model,
data,
"test_faketensor.onnx",
dynamo=True,
)
snapshot_name = f"torchdynamo_exporter_example.pickle"
print(f"generate {snapshot_name}")
torch.cuda.memory._dump_snapshot(snapshot_name)
print(f"Export is done.")
打开 pytorch.org/memory_viz 并将生成的 pickle 快照文件拖放到可视化工具中。内存使用情况描述如下

通过此图,我们可以看到内存使用峰值仅为 45MB 左右。与基于 TorchScript 的导出器的内存使用峰值相比,内存使用量减少了 98%。