快捷方式

引擎缓存

随着模型尺寸的增加,编译成本也会随之增加。使用 torch.dynamo.compile 等 AOT 方法,此成本会在预先支付。但是,如果权重发生变化,会话结束或您使用 torch.compile 等 JIT 方法,由于图会失效而被重新编译,此成本将被重复支付。引擎缓存是一种通过将构建的引擎保存到磁盘并在可能的情况下重新使用它们来降低此成本的方法。本教程演示了如何在 PyTorch 中使用 TensorRT 中的引擎缓存。引擎缓存可以通过重用先前构建的 TensorRT 引擎来显著加快后续模型编译速度。

我们将探讨两种方法

  1. 使用 torch_tensorrt.dynamo.compile

  2. 使用 TensorRT 后端的 torch.compile

此示例使用预训练的 ResNet18 模型,并显示了无缓存编译、启用缓存编译以及重用缓存引擎之间的差异。

import os
from typing import Dict, Optional

import numpy as np
import torch
import torch_tensorrt as torch_trt
import torchvision.models as models
from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache

np.random.seed(0)
torch.manual_seed(0)

model = models.resnet18(pretrained=True).eval().to("cuda")
enabled_precisions = {torch.float}
debug = False
min_block_size = 1
use_python_runtime = False


def remove_timing_cache(path=TIMING_CACHE_PATH):
    if os.path.exists(path):
        os.remove(path)

JIT 编译的引擎缓存

引擎缓存的主要目标是帮助加快 JIT 工作流程。torch.compile 在模型构建中提供了很大的灵活性,这使其成为尝试加快工作流程时的首选工具。但是,从历史上看,编译成本,尤其是重新编译成本,一直是许多用户的进入门槛。如果由于某种原因子图失效,则该子图会在添加引擎缓存之前从头重建。现在,随着引擎的构建,使用 cache_built_engines=True,引擎将保存到磁盘,并与对应 PyTorch 子图的哈希值相关联。如果在随后的编译中,无论是作为此会话的一部分还是新会话的一部分,缓存将提取构建的引擎并**重新拟合**权重,这可以将编译时间缩短几个数量级。因此,为了将新引擎插入缓存(即 cache_built_engines=True),引擎必须是可重新拟合的(make_refittable=True)。有关更多详细信息,请参阅 使用新权重重新拟合 Torch-TensorRT 程序

def torch_compile(iterations=3):
    times = []
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # The 1st iteration is to measure the compilation time without engine caching
    # The 2nd and 3rd iterations are to measure the compilation time with engine caching.
    # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration.
    # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine.
    for i in range(iterations):
        inputs = [torch.rand((100, 3, 224, 224)).to("cuda")]
        # remove timing cache and reset dynamo just for engine caching messurement
        remove_timing_cache()
        torch._dynamo.reset()

        if i == 0:
            cache_built_engines = False
            reuse_cached_engines = False
        else:
            cache_built_engines = True
            reuse_cached_engines = True

        start.record()
        compiled_model = torch.compile(
            model,
            backend="tensorrt",
            options={
                "use_python_runtime": True,
                "enabled_precisions": enabled_precisions,
                "debug": debug,
                "min_block_size": min_block_size,
                "make_refittable": True,
                "cache_built_engines": cache_built_engines,
                "reuse_cached_engines": reuse_cached_engines,
            },
        )
        compiled_model(*inputs)  # trigger the compilation
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))

    print("----------------torch_compile----------------")
    print("disable engine caching, used:", times[0], "ms")
    print("enable engine caching to cache engines, used:", times[1], "ms")
    print("enable engine caching to reuse engines, used:", times[2], "ms")


torch_compile()

AOT 编译的引擎缓存

与 JIT 工作流程类似,AOT 工作流程也可以从引擎缓存中获益。当相同的架构或常见的子图被重新编译时,缓存将提取先前构建的引擎并重新拟合权重。

def dynamo_compile(iterations=3):
    times = []
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
    # Mark the dim0 of inputs as dynamic
    batch = torch.export.Dim("batch", min=1, max=200)
    exp_program = torch.export.export(
        model, args=example_inputs, dynamic_shapes={"x": {0: batch}}
    )

    # The 1st iteration is to measure the compilation time without engine caching
    # The 2nd and 3rd iterations are to measure the compilation time with engine caching.
    # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration.
    # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine.
    for i in range(iterations):
        inputs = [torch.rand((100 + i, 3, 224, 224)).to("cuda")]
        remove_timing_cache()  # remove timing cache just for engine caching messurement
        if i == 0:
            cache_built_engines = False
            reuse_cached_engines = False
        else:
            cache_built_engines = True
            reuse_cached_engines = True

        start.record()
        trt_gm = torch_trt.dynamo.compile(
            exp_program,
            tuple(inputs),
            use_python_runtime=use_python_runtime,
            enabled_precisions=enabled_precisions,
            debug=debug,
            min_block_size=min_block_size,
            make_refittable=True,
            cache_built_engines=cache_built_engines,
            reuse_cached_engines=reuse_cached_engines,
            engine_cache_size=1 << 30,  # 1GB
        )
        # output = trt_gm(*inputs)
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))

    print("----------------dynamo_compile----------------")
    print("disable engine caching, used:", times[0], "ms")
    print("enable engine caching to cache engines, used:", times[1], "ms")
    print("enable engine caching to reuse engines, used:", times[2], "ms")


dynamo_compile()

自定义引擎缓存

默认情况下,引擎缓存存储在系统的临时目录中。可以通过传递 engine_cache_direngine_cache_size 来自定义缓存目录和大小限制。用户还可以通过扩展 BaseEngineCache 类来定义自己的引擎缓存实现。如果需要,这允许远程或共享缓存。

自定义引擎缓存应实现以下方法
  • save:将引擎 Blob 保存到缓存中。

  • load:从缓存中加载引擎 Blob。

缓存系统提供的哈希值是源 PyTorch 子图(降级后)的与权重无关的哈希值。Blob 以 pickle 格式包含序列化引擎、调用规范数据和权重映射信息

以下是一个自定义引擎缓存实现的示例,它实现了 RAMEngineCache

class RAMEngineCache(BaseEngineCache):
    def __init__(
        self,
    ) -> None:
        """
        Constructs a user held engine cache in memory.
        """
        self.engine_cache: Dict[str, bytes] = {}

    def save(
        self,
        hash: str,
        blob: bytes,
    ):
        """
        Insert the engine blob to the cache.

        Args:
            hash (str): The hash key to associate with the engine blob.
            blob (bytes): The engine blob to be saved.

        Returns:
            None
        """
        self.engine_cache[hash] = blob

    def load(self, hash: str) -> Optional[bytes]:
        """
        Load the engine blob from the cache.

        Args:
            hash (str): The hash key of the engine to load.

        Returns:
            Optional[bytes]: The engine blob if found, None otherwise.
        """
        if hash in self.engine_cache:
            return self.engine_cache[hash]
        else:
            return None


def torch_compile_my_cache(iterations=3):
    times = []
    engine_cache = RAMEngineCache()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # The 1st iteration is to measure the compilation time without engine caching
    # The 2nd and 3rd iterations are to measure the compilation time with engine caching.
    # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration.
    # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine.
    for i in range(iterations):
        inputs = [torch.rand((100, 3, 224, 224)).to("cuda")]
        # remove timing cache and reset dynamo just for engine caching messurement
        remove_timing_cache()
        torch._dynamo.reset()

        if i == 0:
            cache_built_engines = False
            reuse_cached_engines = False
        else:
            cache_built_engines = True
            reuse_cached_engines = True

        start.record()
        compiled_model = torch.compile(
            model,
            backend="tensorrt",
            options={
                "use_python_runtime": True,
                "enabled_precisions": enabled_precisions,
                "debug": debug,
                "min_block_size": min_block_size,
                "make_refittable": True,
                "cache_built_engines": cache_built_engines,
                "reuse_cached_engines": reuse_cached_engines,
                "custom_engine_cache": engine_cache,
            },
        )
        compiled_model(*inputs)  # trigger the compilation
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))

    print("----------------torch_compile----------------")
    print("disable engine caching, used:", times[0], "ms")
    print("enable engine caching to cache engines, used:", times[1], "ms")
    print("enable engine caching to reuse engines, used:", times[2], "ms")


torch_compile_my_cache()

脚本的总运行时间:(0 分 0.000 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源