• 文档 >
  • 使用 Cudagraphs 进行 Torch 导出
快捷方式

使用 Cudagraphs 进行 Torch 导出

此交互式脚本旨在概述如何在 ir=”dynamo” 路径中使用 Torch-TensorRT Cudagraphs 集成。此功能在 torch.compile 路径中也类似地工作。

导入和模型定义

import torch
import torchvision.models as models

import torch_tensorrt

使用默认设置使用 torch_tensorrt.compile 进行编译

# We begin by defining and initializing a model
model = models.resnet18(pretrained=True).eval().to("cuda")

# Define sample inputs
inputs = torch.randn((16, 3, 224, 224)).cuda()
# Next, we compile the model using torch_tensorrt.compile
# We use the `ir="dynamo"` flag here, and `ir="torch_compile"` should
# work with cudagraphs as well.
opt = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=torch_tensorrt.Input(
        min_shape=(1, 3, 224, 224),
        opt_shape=(8, 3, 224, 224),
        max_shape=(16, 3, 224, 224),
        dtype=torch.float,
        name="x",
    ),
)

使用 Cudagraphs 集成进行推理

# We can enable the cudagraphs API with a context manager
with torch_tensorrt.runtime.enable_cudagraphs():
    out_trt = opt(inputs)

# Alternatively, we can set the cudagraphs mode for the session
torch_tensorrt.runtime.set_cudagraphs_mode(True)
out_trt = opt(inputs)

# We can also turn off cudagraphs mode and perform inference as normal
torch_tensorrt.runtime.set_cudagraphs_mode(False)
out_trt = opt(inputs)
# If we provide new input shapes, cudagraphs will re-record the graph
inputs_2 = torch.randn((8, 3, 224, 224)).cuda()
inputs_3 = torch.randn((4, 3, 224, 224)).cuda()

with torch_tensorrt.runtime.enable_cudagraphs():
    out_trt_2 = opt(inputs_2)
    out_trt_3 = opt(inputs_3)

脚本的总运行时间:(0 分 0.000 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源