• 教程 >
  • 通过区域编译减少 torch.compile 冷启动编译时间
快捷方式

通过区域编译减少 torch.compile 冷启动编译时间

创建于:2024 年 10 月 10 日 | 最后更新:2024 年 10 月 16 日 | 最后验证:2024 年 10 月 10 日

作者: Animesh Jain

随着深度学习模型变得越来越大,这些模型的编译时间也随之增加。延长的编译时间会导致推理服务启动时间过长,或者大规模训练中资源浪费。本食谱展示了一个示例,说明如何通过选择编译模型中重复的区域而不是整个模型来减少冷启动编译时间。

先决条件

  • Pytorch 2.5 或更高版本

设置

在开始之前,如果尚未安装 torch,我们需要安装它。

pip install torch

注意

此功能从 2.5 版本开始提供。如果您使用的是 2.4 版本,可以启用配置标志 torch._dynamo.config.inline_inbuilt_nn_modules=True 以防止区域编译期间的重新编译。在 2.5 版本中,此标志默认启用。

from time import perf_counter

步骤

在本食谱中,我们将遵循以下步骤

  1. 导入所有必要的库。

  2. 定义并初始化具有重复区域的神经网络。

  3. 了解完整模型和区域编译之间的区别。

  4. 测量完整模型和区域编译的编译时间。

首先,让我们导入加载数据所需的库

import torch
import torch.nn as nn

接下来,让我们定义并初始化具有重复区域的神经网络。

通常,神经网络由重复的层组成。例如,大型语言模型由许多 Transformer 块组成。在本食谱中,我们将使用 nn.Module 类创建一个 Layer,作为重复区域的代理。然后,我们将创建一个 Model,它由 64 个此 Layer 类的实例组成。

class Layer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.relu1 = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(10, 10)
        self.relu2 = torch.nn.ReLU()

    def forward(self, x):
        a = self.linear1(x)
        a = self.relu1(a)
        a = torch.sigmoid(a)
        b = self.linear2(a)
        b = self.relu2(b)
        return b


class Model(torch.nn.Module):
    def __init__(self, apply_regional_compilation):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)
        # Apply compile only to the repeated layers.
        if apply_regional_compilation:
            self.layers = torch.nn.ModuleList(
                [torch.compile(Layer()) for _ in range(64)]
            )
        else:
            self.layers = torch.nn.ModuleList([Layer() for _ in range(64)])

    def forward(self, x):
        # In regional compilation, the self.linear is outside of the scope of `torch.compile`.
        x = self.linear(x)
        for layer in self.layers:
            x = layer(x)
        return x

接下来,让我们回顾一下完整模型和区域编译之间的区别。

在完整模型编译中,整个模型作为一个整体进行编译。这是大多数用户使用 torch.compile 采用的常见方法。在此示例中,我们将 torch.compile 应用于 Model 对象。这将有效地内联 64 层,从而生成一个大型图进行编译。您可以通过使用 TORCH_LOGS=graph_code 运行此食谱来查看完整图。

model = Model(apply_regional_compilation=False).cuda()
full_compiled_model = torch.compile(model)

另一方面,区域编译编译模型的某个区域。通过策略性地选择编译模型中重复的区域,我们可以编译一个更小的图,然后将编译后的图重用于所有区域。在该示例中,torch.compile 仅应用于 layers,而不是完整模型。

regional_compiled_model = Model(apply_regional_compilation=True).cuda()

对重复区域而不是完整模型应用编译,可以大大节省编译时间。在这里,我们将只编译一个层实例,然后在 Model 对象中重复使用它 64 次。

请注意,对于重复区域,模型的某些部分可能未编译。例如,Model 中的 self.linear 超出了区域编译的范围。

另请注意,性能加速和编译时间之间存在权衡。完整模型编译涉及更大的图,理论上为优化提供了更多空间。但是,出于实际目的,并且取决于模型,我们观察到完整模型和区域编译之间只有很小的加速差异的许多情况。

接下来,让我们测量完整模型和区域编译的编译时间。

torch.compile 是一个 JIT 编译器,这意味着它在首次调用时进行编译。在下面的代码中,我们测量首次调用中花费的总时间。虽然此方法不精确,但它提供了一个很好的估计,因为大部分时间都花在编译上。

def measure_latency(fn, input):
    # Reset the compiler caches to ensure no reuse between different runs
    torch.compiler.reset()
    with torch._inductor.utils.fresh_inductor_cache():
        start = perf_counter()
        fn(input)
        torch.cuda.synchronize()
        end = perf_counter()
        return end - start


input = torch.randn(10, 10, device="cuda")
full_model_compilation_latency = measure_latency(full_compiled_model, input)
print(f"Full model compilation time = {full_model_compilation_latency:.2f} seconds")

regional_compilation_latency = measure_latency(regional_compiled_model, input)
print(f"Regional compilation time = {regional_compilation_latency:.2f} seconds")

assert regional_compilation_latency < full_model_compilation_latency
Full model compilation time = 24.60 seconds
Regional compilation time = 1.71 seconds

结论

本食谱展示了如果您的模型具有重复区域,如何控制冷启动编译时间。此方法需要用户修改以将 torch.compile 应用于重复区域,而不是更常用的完整模型编译。我们正在不断努力减少冷启动编译时间。

脚本的总运行时间: ( 0 分钟 26.422 秒)

图库由 Sphinx-Gallery 生成


评价本教程

© 版权所有 2024, PyTorch。

使用 Sphinx 构建,主题由 theme 提供,由 Read the Docs 提供支持。

文档

访问 PyTorch 的综合开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源