
torch.export AOTInductor Tutorial for Python runtime (Beta)

Authors: Ankith Gunapal, Bin Bao, Angela Yi

Warning

torch._inductor.aot_compile and torch._export.aot_load are in Beta status and are subject to backwards-compatibility-breaking changes. This tutorial provides an example of how to use these APIs for model deployment using the Python runtime.

It was previously shown how to use AOTInductor to do ahead-of-time compilation of a PyTorch exported model by creating a shared library that can be run in a non-Python environment. See here.

In this tutorial, you will work through an end-to-end example of using AOTInductor for the Python runtime. We will look at how to use torch._inductor.aot_compile() along with torch.export.export() to generate a shared library. We will also examine how to execute that shared library in the Python runtime using torch._export.aot_load(). You will learn about the speedup in first-inference time that AOTInductor provides, especially when using max-autotune mode, which can take some time to execute.

Contents

Prerequisites

What you will learn

  • How to use AOTInductor for Python runtime.

  • How to use torch._inductor.aot_compile() along with torch.export.export() to generate a shared library

  • How to run a shared library in the Python runtime using torch._export.aot_load().

  • When to use AOTInductor with a Python runtime

Model Compilation

We will use the TorchVision pretrained ResNet18 model and run TorchInductor on the exported PyTorch program using torch._inductor.aot_compile().

Note

This API also supports torch.compile() options such as mode. This means that, if used on a CUDA-enabled device, you can, for example, specify "max_autotune": True, which leverages Triton-based matrix multiplications and convolutions, and enables CUDA graphs by default.

We also specify dynamic_shapes for the batch dimension. In this example, min=2 is not a bug and is explained in The 0/1 Specialization Problem.

import os
import torch
from torchvision.models import ResNet18_Weights, resnet18

model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.eval()

with torch.inference_mode():

    # Specify the generated shared library path
    aot_compile_options = {
            "aot_inductor.output_path": os.path.join(os.getcwd(), "resnet18_pt2.so"),
    }
    if torch.cuda.is_available():
        device = "cuda"
        aot_compile_options.update({"max_autotune": True})
    else:
        device = "cpu"

    model = model.to(device=device)
    example_inputs = (torch.randn(2, 3, 224, 224, device=device),)

    # min=2 is not a bug and is explained in the 0/1 Specialization Problem
    batch_dim = torch.export.Dim("batch", min=2, max=32)
    exported_program = torch.export.export(
        model,
        example_inputs,
        # Specify the first dimension of the input x as dynamic
        dynamic_shapes={"x": {0: batch_dim}},
    )
    so_path = torch._inductor.aot_compile(
        exported_program.module(),
        example_inputs,
        # Specify the generated shared library path
        options=aot_compile_options
    )
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /var/lib/ci-user/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

 94%|#########4| 42.1M/44.7M [00:00<00:00, 441MB/s]
100%|##########| 44.7M/44.7M [00:00<00:00, 430MB/s]
AUTOTUNE convolution(2x3x224x224, 64x3x7x7)
  convolution 0.0463 ms 100.0%
  triton_convolution2d_0 0.1038 ms 44.6% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_4 0.1069 ms 43.3% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_3 0.1280 ms 36.2% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=128, BLOCK_N=64, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_1 0.1404 ms 33.0% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_5 0.1862 ms 24.9% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_2 0.2195 ms 21.1% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=1, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.8298 seconds and 0.0073 seconds precompiling
AUTOTUNE convolution(2x64x56x56, 64x64x3x3)
  convolution 0.0439 ms 100.0%
  triton_convolution2d_6 0.0742 ms 59.2% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_9 0.0753 ms 58.3% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_12 0.0775 ms 56.6% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_11 0.0837 ms 52.5% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_10 0.0840 ms 52.2% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_8 0.1412 ms 31.1% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=1, num_warps=8
  triton_convolution2d_7 0.1451 ms 30.3% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.9664 seconds and 0.0005 seconds precompiling
AUTOTUNE convolution(2x64x56x56, 128x64x3x3)
  convolution 0.0347 ms 100.0%
  triton_convolution2d_38 0.0628 ms 55.3% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_40 0.0819 ms 42.4% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_34 0.0865 ms 40.1% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_39 0.0913 ms 38.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_37 0.1070 ms 32.5% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_35 0.1627 ms 21.3% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_36 0.3053 ms 11.4% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=1, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.9781 seconds and 0.0007 seconds precompiling
AUTOTUNE convolution(2x64x56x56, 128x64x1x1)
  triton_convolution2d_52 0.0109 ms 100.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
  triton_convolution2d_53 0.0124 ms 87.7% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
  triton_convolution2d_48 0.0125 ms 87.4% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
  convolution 0.0131 ms 83.6%
  triton_convolution2d_54 0.0155 ms 70.5% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
  triton_convolution2d_51 0.0159 ms 68.5% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
  triton_convolution2d_50 0.0464 ms 23.5% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=1, num_warps=8
  triton_convolution2d_49 0.0695 ms 15.7% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.9778 seconds and 0.0006 seconds precompiling
AUTOTUNE convolution(2x128x28x28, 128x128x3x3)
  convolution 0.0432 ms 100.0%
  triton_convolution2d_59 0.1169 ms 37.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_61 0.1361 ms 31.8% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_55 0.1663 ms 26.0% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_60 0.1757 ms 24.6% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_56 0.1899 ms 22.8% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_58 0.1944 ms 22.2% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_57 0.2652 ms 16.3% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=1, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.9627 seconds and 0.0006 seconds precompiling
AUTOTUNE convolution(2x128x28x28, 256x128x3x3)
  convolution 0.0371 ms 100.0%
  triton_convolution2d_73 0.0998 ms 37.1% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_75 0.1594 ms 23.3% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_72 0.2031 ms 18.2% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_70 0.2245 ms 16.5% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_71 0.2649 ms 14.0% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=1, num_warps=8
  triton_convolution2d_74 0.2849 ms 13.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_69 0.3374 ms 11.0% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.9723 seconds and 0.0005 seconds precompiling
AUTOTUNE convolution(2x128x28x28, 256x128x1x1)
  triton_convolution2d_87 0.0124 ms 100.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
  convolution 0.0204 ms 60.7%
  triton_convolution2d_88 0.0216 ms 57.3% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
  triton_convolution2d_89 0.0277 ms 44.6% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
  triton_convolution2d_85 0.0332 ms 37.4% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=1, num_warps=8
  triton_convolution2d_86 0.0455 ms 27.2% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
  triton_convolution2d_83 0.1236 ms 10.0% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
  triton_convolution2d_84 0.1468 ms 8.4% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 1.0137 seconds and 0.0004 seconds precompiling
AUTOTUNE convolution(2x256x14x14, 256x256x3x3)
  convolution 0.0533 ms 100.0%
  triton_convolution2d_94 0.1864 ms 28.6% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_92 0.2607 ms 20.4% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=512, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=1, num_warps=8
  triton_convolution2d_96 0.2625 ms 20.3% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_91 0.3701 ms 14.4% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_93 0.3748 ms 14.2% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_95 0.5471 ms 9.7% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_90 0.6529 ms 8.2% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.9505 seconds and 0.0005 seconds precompiling
AUTOTUNE convolution(2x256x14x14, 512x256x3x3)
  convolution 0.0528 ms 100.0%
  triton_convolution2d_108 0.1927 ms 27.4% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_106 0.2816 ms 18.7% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=512, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=1, num_warps=8
  triton_convolution2d_110 0.2940 ms 18.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_105 0.3824 ms 13.8% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_107 0.3895 ms 13.6% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_109 0.5599 ms 9.4% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_104 0.6858 ms 7.7% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.9598 seconds and 0.0005 seconds precompiling
AUTOTUNE convolution(2x256x14x14, 512x256x1x1)
  triton_convolution2d_122 0.0185 ms 100.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
  convolution 0.0261 ms 71.0%
  triton_convolution2d_120 0.0333 ms 55.7% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=512, BLOCK_N=16, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=1, num_warps=8
  triton_convolution2d_124 0.0882 ms 21.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
  triton_convolution2d_123 0.0996 ms 18.6% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
  triton_convolution2d_121 0.1301 ms 14.2% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
  triton_convolution2d_118 0.2875 ms 6.4% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
  triton_convolution2d_119 0.2958 ms 6.3% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 1.0160 seconds and 0.0005 seconds precompiling
AUTOTUNE convolution(2x512x7x7, 512x512x3x3)
  convolution 0.0857 ms 100.0%
  triton_convolution2d_127 0.2801 ms 30.6% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=128, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=1, num_warps=8
  triton_convolution2d_129 0.3599 ms 23.8% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_131 0.4236 ms 20.2% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_126 0.4836 ms 17.7% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=128, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_128 0.7235 ms 11.8% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_130 1.0996 ms 7.8% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_125 1.4547 ms 5.9% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.9553 seconds and 0.0005 seconds precompiling
AUTOTUNE addmm(2x1000, 2x512, 512x1000)
  addmm 0.0156 ms 100.0%
  triton_mm_142 0.0219 ms 71.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=2
  triton_mm_153 0.0303 ms 51.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4
  triton_mm_152 0.0306 ms 51.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_141 0.0308 ms 50.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
  triton_mm_146 0.0308 ms 50.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_139 0.0342 ms 45.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=1, num_warps=2
  triton_mm_145 0.0373 ms 41.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
  triton_mm_144 0.0454 ms 34.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
  triton_mm_148 0.0501 ms 31.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 1.8298 seconds and 0.0021 seconds precompiling

Model Inference in Python

Typically, the shared object generated above is used in a non-Python environment. In PyTorch 2.3, we added a new API called torch._export.aot_load() to load the shared library in the Python runtime. The API follows a structure similar to the torch.jit.load() API: you specify the path of the shared library and the device on which it should be loaded.

Note

In the example below, we specify batch_size=1 for inference, and it still functions correctly even though we specified min=2 in torch.export.export().

import os
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model_so_path = os.path.join(os.getcwd(), "resnet18_pt2.so")

model = torch._export.aot_load(model_so_path, device)
example_inputs = (torch.randn(1, 3, 224, 224, device=device),)

with torch.inference_mode():
    output = model(example_inputs)
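
Because the batch dimension was exported as dynamic (min=2, max=32), the same shared library also serves other batch sizes in that range without recompilation. The following is a minimal sketch (not part of the original tutorial output), reusing the model and device loaded above:

# Illustrative sanity check: the dynamic batch dimension accepts any size in the exported range
with torch.inference_mode():
    for batch_size in (2, 8, 32):
        batch = (torch.randn(batch_size, 3, 224, 224, device=device),)
        output = model(batch)  # same calling convention as the example above
        print(f"batch_size={batch_size} ran successfully")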

When to use AOTInductor for Python Runtime

One of the requirements for using AOTInductor is that the model should not have any graph breaks. Once this requirement is met, the primary use case for the AOTInductor Python runtime is model deployment using Python. There are two main reasons to use the AOTInductor Python runtime:

  • torch._inductor.aot_compile generates a shared library. This is useful for model versioning in deployments and for tracking model performance over time.

  • Since torch.compile() is a JIT compiler, there is a warm-up cost associated with the first compilation, so your deployment needs to account for the compilation time taken by the first inference. With AOTInductor, compilation is done offline using torch.export.export and torch._inductor.aot_compile; the deployment only loads the shared library with torch._export.aot_load and runs inference, as sketched below.
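
As a compact summary of the two-phase workflow described above, here is a hedged sketch that wraps the earlier code into an offline compile step and a deployment-time load step. The function names and arguments are illustrative, not part of any PyTorch API:

import torch

def compile_offline(model, example_inputs, so_path, device="cuda"):
    # Offline/build step: export the model and AOT-compile it into a shared library.
    # `example_inputs` is a tuple of tensors already placed on `device`.
    model = model.to(device).eval()
    exported_program = torch.export.export(model, example_inputs)
    return torch._inductor.aot_compile(
        exported_program.module(),
        example_inputs,
        options={"aot_inductor.output_path": so_path},
    )

def load_for_serving(so_path, device="cuda"):
    # Deployment step: only load the shared library; no JIT compilation at first inference.
    return torch._export.aot_load(so_path, device)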

The section below shows the speedup achieved with AOTInductor for the first inference.

We define a utility function timed to measure the time taken for inference.

import time
def timed(fn):
    # Returns the result of running `fn()` and the time it took for `fn()` to run,
    # in seconds. We use CUDA events and synchronization for accurate
    # measurement on CUDA enabled devices.
    if torch.cuda.is_available():
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
    else:
        start = time.time()

    result = fn()
    if torch.cuda.is_available():
        end.record()
        torch.cuda.synchronize()
    else:
        end = time.time()

    # Measure time taken to execute the function in milliseconds
    if torch.cuda.is_available():
        duration = start.elapsed_time(end)
    else:
        duration = (end - start) * 1000

    return result, duration

Let's measure the time taken for the first inference using AOTInductor.

torch._dynamo.reset()

model = torch._export.aot_load(model_so_path, device)
example_inputs = (torch.randn(1, 3, 224, 224, device=device),)

with torch.inference_mode():
    _, time_taken = timed(lambda: model(example_inputs))
    print(f"Time taken for first inference for AOTInductor is {time_taken:.2f} ms")
Time taken for first inference for AOTInductor is 2.92 ms

Let's measure the time taken for the first inference using torch.compile.

torch._dynamo.reset()

model = resnet18(weights=ResNet18_Weights.DEFAULT).to(device)
model.eval()

model = torch.compile(model)
example_inputs = torch.randn(1, 3, 224, 224, device=device)

with torch.inference_mode():
    _, time_taken = timed(lambda: model(example_inputs))
    print(f"Time taken for first inference for torch.compile is {time_taken:.2f} ms")
Time taken for first inference for torch.compile is 6831.81 ms

We see a drastic reduction in first-inference time with AOTInductor compared to torch.compile.
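
Most of this gap is torch.compile's one-time JIT warm-up; once compiled, subsequent torch.compile calls no longer pay that cost. If you want to verify this on your own machine, a small, illustrative extension of the timing code above could look like this (no numbers are claimed here):

with torch.inference_mode():
    # The first call includes torch.compile's JIT warm-up; later calls are steady state.
    for i in range(3):
        _, time_taken = timed(lambda: model(example_inputs))
        print(f"Inference {i}: {time_taken:.2f} ms")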

Conclusion

In this tutorial, we learned how to effectively use AOTInductor for the Python runtime by compiling and loading a pretrained ResNet18 model with the torch._inductor.aot_compile and torch._export.aot_load APIs. This process demonstrates the practical workflow of generating a shared library and running it within a Python environment, even with dynamic shapes and device-specific optimizations. We also looked at the advantage of using AOTInductor in model deployments: the speedup in first-inference time.

Total running time of the script: (1 minutes 27.964 seconds)
