• 文档 >
  • Torch Export 与 Cudagraphs
快捷方式

Torch Export 与 Cudagraphs

此交互式脚本旨在概述在 ir=”dynamo” 路径中使用 Torch-TensorRT Cudagraphs 集成的过程。该功能在 torch.compile 路径中也类似地工作。

导入和模型定义

import torch
import torch_tensorrt
import torchvision.models as models

使用 torch_tensorrt.compile 和默认设置进行编译

# We begin by defining and initializing a model
model = models.resnet18(pretrained=True).eval().to("cuda")

# Define sample inputs
inputs = torch.randn((16, 3, 224, 224)).cuda()
# Next, we compile the model using torch_tensorrt.compile
# We use the `ir="dynamo"` flag here, and `ir="torch_compile"` should
# work with cudagraphs as well.
opt = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=torch_tensorrt.Input(
        min_shape=(1, 3, 224, 224),
        opt_shape=(8, 3, 224, 224),
        max_shape=(16, 3, 224, 224),
        dtype=torch.float,
        name="x",
    ),
)

使用 Cudagraphs 集成进行推理

# We can enable the cudagraphs API with a context manager
with torch_tensorrt.runtime.enable_cudagraphs(opt) as cudagraphs_module:
    out_trt = cudagraphs_module(inputs)

# Alternatively, we can set the cudagraphs mode for the session
torch_tensorrt.runtime.set_cudagraphs_mode(True)
out_trt = opt(inputs)

# We can also turn off cudagraphs mode and perform inference as normal
torch_tensorrt.runtime.set_cudagraphs_mode(False)
out_trt = opt(inputs)
# If we provide new input shapes, cudagraphs will re-record the graph
inputs_2 = torch.randn((8, 3, 224, 224)).cuda()
inputs_3 = torch.randn((4, 3, 224, 224)).cuda()

with torch_tensorrt.runtime.enable_cudagraphs(opt) as cudagraphs_module:
    out_trt_2 = cudagraphs_module(inputs_2)
    out_trt_3 = cudagraphs_module(inputs_3)

带有包含图形断点的模块的 Cuda 图

当 CUDA 图应用于包含图形断点的 TensorRT 模型时,每个断点都会引入额外的开销。发生这种情况的原因是图形断点阻止整个模型作为一个连续的优化单元执行。因此,CUDA 图通常提供的一些性能优势(例如减少的内核启动开销和提高的执行效率)可能会降低。将包装的运行时模块与 CUDA 图一起使用,允许您将操作序列封装到可以有效执行的图中,即使存在图形断点也是如此。如果 TensorRT 模块有图形断点,CUDA Graph 上下文管理器将返回一个 wrapped_module。此模块捕获整个执行图,通过减少内核启动开销和提高性能,从而实现后续推理期间的有效重放。请注意,使用包装器模块初始化涉及一个预热阶段,在该阶段模块会执行多次。此预热确保内存分配和初始化不会记录在 CUDA 图中,这有助于保持一致的执行路径并优化性能。

class SampleModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu((x + 2) * 0.5)


model = SampleModel().eval().cuda()
input = torch.randn((1, 3, 224, 224)).to("cuda")

# The 'torch_executed_ops' compiler option is used in this example to intentionally introduce graph breaks within the module.
# Note: The Dynamo backend is required for the CUDA Graph context manager to handle modules in an Ahead-Of-Time (AOT) manner.
opt_with_graph_break = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=[input],
    min_block_size=1,
    pass_through_build_failures=True,
    torch_executed_ops={"torch.ops.aten.mul.Tensor"},
)

如果模块有图形断点,则整个子模块将被 CUDA 图记录和重放

with torch_tensorrt.runtime.enable_cudagraphs(
    opt_with_graph_break
) as cudagraphs_module:
    cudagraphs_module(input)

脚本总运行时间: (0 分钟 0.000 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源