• 教程 >
  • 使用区域编译减少 torch.compile 冷启动编译时间
快捷键

使用区域编译减少 torch.compile 冷启动编译时间

作者: Animesh Jain

随着深度学习模型越来越大,这些模型的编译时间也会增加。这种延长编译时间会导致推理服务启动时间过长,或者在大型训练中浪费资源。此食谱展示了如何通过选择编译模型的重复区域而不是整个模型来缩短冷启动编译时间。

先决条件

  • Pytorch 2.5 或更高版本

设置

在开始之前,我们需要安装 torch(如果尚未安装)。

pip install torch

注意

此功能从 2.5 版本开始可用。如果你使用的是 2.4 版本,可以启用配置标志 torch._dynamo.config.inline_inbuilt_nn_modules=True 来防止区域编译期间重新编译。在 2.5 版本中,此标志默认情况下处于启用状态。

from time import perf_counter

步骤

在此食谱中,我们将按照以下步骤进行操作

  1. 导入所有必要的库。

  2. 定义并初始化一个具有重复区域的神经网络。

  3. 了解完整模型与区域编译之间的区别。

  4. 测量完整模型和区域编译的编译时间。

首先,让我们导入加载数据的必要库

import torch
import torch.nn as nn

接下来,让我们定义并初始化一个具有重复区域的神经网络。

通常,神经网络由重复的层组成。例如,大型语言模型由许多 Transformer 块组成。在此食谱中,我们将使用 nn.Module 类创建一个 Layer 作为重复区域的代理。然后我们将创建一个 Model,它由 64 个此 Layer 类的实例组成。

class Layer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.relu1 = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(10, 10)
        self.relu2 = torch.nn.ReLU()

    def forward(self, x):
        a = self.linear1(x)
        a = self.relu1(a)
        a = torch.sigmoid(a)
        b = self.linear2(a)
        b = self.relu2(b)
        return b


class Model(torch.nn.Module):
    def __init__(self, apply_regional_compilation):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)
        # Apply compile only to the repeated layers.
        if apply_regional_compilation:
            self.layers = torch.nn.ModuleList(
                [torch.compile(Layer()) for _ in range(64)]
            )
        else:
            self.layers = torch.nn.ModuleList([Layer() for _ in range(64)])

    def forward(self, x):
        # In regional compilation, the self.linear is outside of the scope of `torch.compile`.
        x = self.linear(x)
        for layer in self.layers:
            x = layer(x)
        return x

接下来,我们来回顾一下完整模型编译和区域编译之间的区别。

在完整模型编译中,整个模型作为一个整体进行编译。这是大多数用户使用torch.compile的常见方法。在这个例子中,我们将torch.compile应用于Model对象。这将有效地内联 64 层,生成一个大型图进行编译。您可以通过使用TORCH_LOGS=graph_code运行此配方来查看完整图。

model = Model(apply_regional_compilation=False).cuda()
full_compiled_model = torch.compile(model)

另一方面,区域编译会编译模型的一个区域。通过策略性地选择编译模型的重复区域,我们可以编译一个更小的图,然后将编译后的图重复使用到所有区域。在示例中,torch.compile仅应用于layers,而不是整个模型。

regional_compiled_model = Model(apply_regional_compilation=True).cuda()

将编译应用于重复区域而不是完整模型,可以节省大量的编译时间。在这里,我们将只编译一个层实例,然后在Model对象中重复使用它 64 次。

请注意,对于重复区域,模型的某些部分可能不会被编译。例如,Model中的self.linear位于区域编译范围之外。

另外请注意,性能加速和编译时间之间存在权衡。完整模型编译涉及一个更大的图,理论上提供了更多优化的空间。但是,从实际目的出发,并且取决于模型,我们观察到完整模型编译和区域编译之间的速度差异很小。

接下来,我们来测量完整模型和区域编译的编译时间。

torch.compile是一个 JIT 编译器,这意味着它在第一次调用时进行编译。在下面的代码中,我们测量了第一次调用中花费的总时间。虽然这种方法不精确,但它提供了一个很好的估计,因为大部分时间都花在了编译上。

def measure_latency(fn, input):
    # Reset the compiler caches to ensure no reuse between different runs
    torch.compiler.reset()
    with torch._inductor.utils.fresh_inductor_cache():
        start = perf_counter()
        fn(input)
        torch.cuda.synchronize()
        end = perf_counter()
        return end - start


input = torch.randn(10, 10, device="cuda")
full_model_compilation_latency = measure_latency(full_compiled_model, input)
print(f"Full model compilation time = {full_model_compilation_latency:.2f} seconds")

regional_compilation_latency = measure_latency(regional_compiled_model, input)
print(f"Regional compilation time = {regional_compilation_latency:.2f} seconds")

assert regional_compilation_latency < full_model_compilation_latency
Full model compilation time = 29.65 seconds
Regional compilation time = 1.76 seconds

结论

这个配方展示了如何在模型具有重复区域的情况下控制冷启动编译时间。这种方法需要用户修改,将torch.compile应用于重复区域,而不是更常用的完整模型编译。我们一直在努力减少冷启动编译时间。

脚本总运行时间: ( 0 分 31.488 秒)

Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源