作者:Michael Gschwind, Driss Guessous, Christian Puhrsch

PyTorch 2.0 版本包含一个 PyTorch Transformer API 的新型高性能实现,旨在降低最先进 Transformer 模型的训练和部署成本。继成功发布“fastpath”推理执行(“Better Transformer”)之后,此版本引入了使用缩放点积注意力 (SPDA) 自定义内核架构实现的高性能训练和推理支持。

您可以通过直接调用新的 SDPA 算子(如SDPA 教程中所述),或通过集成到现有的 PyTorch Transformer API 中,透明地利用新的融合 SDPA 内核。PyTorch Transformer API 的所有功能将继续保持兼容,许多功能映射到高性能 SDPA 内核,而其他功能则无法以更高性能支持(例如,如下所述的 need_weights),同时对其他功能的增强高性能支持可能仍在积极开发中。

与“fastpath”架构类似,自定义内核完全集成到 PyTorch Transformer API 中——因此,使用原生的 Transformer 和 MultiHeadAttention API 将使开发者透明地看到显著的速度提升。与“fastpath”架构不同的是,新引入的“自定义内核”支持更多用例,包括使用交叉注意力的模型、Transformer 解码器和模型训练,此外还支持现有的针对固定和可变序列长度 Transformer 编码器及自注意力用例的 fastpath 推理。

为了充分利用不同的硬件模型和 Transformer 用例,PyTorch 支持多种 SDPA 自定义内核,通过自定义内核选择逻辑,为给定模型和硬件类型选择性能最高的内核。具体来说,PyTorch 2.0 版本中包含的首批自定义内核是Flash Attention 内核(sdpa_flash,适用于具有 SM80+ 架构级别的英伟达 GPU 上的 16 位浮点训练和推理)和xFormers 内存高效注意力内核(sdpa_mem_eff,适用于各种英伟达 GPU 上的 16 位和 32 位浮点训练和推理)。当自定义内核不适用时,通用内核 sdpa_math 提供了一种实现。

如前所述,自定义内核提供了更广泛的执行场景支持。为了确保高效执行(例如,使用 GPU Tensor Cores),模型配置需要满足少数要求。这个要求列表将随时间演进,有望放宽限制当前支持的自定义内核使用的约束,或将来提供额外的内核。

有关自定义内核和调度约束的最新列表,您可以参考 sdp_utils.h。截至 PyTorch 2.0,现有融合 SDPA 内核有以下约束:

  • Flash Attention 仅支持 16 位浮点数据类型(float16 和 bfloat16)。
  • 对于 16 位浮点数,头维度必须是 8 的倍数;对于 32 位浮点数,必须是 4 的倍数。目前,Flash Attention 自定义内核支持的最大 head_dim 为 128。
  • CUDA 架构级别对于 mem_efficient 内核必须是 sm5x 或更高,对于 Flash Attention 必须是 sm80。
  • Flash Attention 支持任意 dropout,在 PyTorch 2.0 中,mem_efficient 内核不支持 dropout(即,对于此内核要被选中,dropout 必须设置为零)。
  • 为了支持可变序列长度批次,所有 SDPA 内核都支持使用可变序列长度张量将输入数据和填充信息结合起来进行前向计算的 Nested Tensor 输入。(您可以在Nested Tensor 教程中找到有关 Nested Tensors 的更多信息。)
  • 您可以通过在传递给 SDPA 算子之前组合 key_padding_maskattn_mask 来同时指定它们。特别是,您可以使用 nn.Transformer API 中每个批次元素的 key padding mask 来实现批次中可变序列长度输入的训练。
  • 目前,融合内核实现支持的注意力掩码仅为训练中常用的因果掩码。要在自定义内核中指定因果掩码,必须使用 is_causal 布尔值指定,并且 attn_mask 必须为 None。
  • 对 Nested Tensors 的支持仍在开发中。具体来说,在 PyTorch 2.0 中,只有 sdpa_math 内核支持使用 Nested Tensors 进行训练。此外,PyTorch 2.0 不支持将 Nested Tensors 作为使用 torch.compile() 编译代码的一部分。
  • SDPA 算子不支持返回平均注意力权重,因为计算它们会破坏使融合内核更高效执行的优化。torch.nn.MultiheadAttention 的 forward 函数的参数 need_weights 默认为 True。为了使用融合内核,need_weights 需要设置为 need_weights=False

我们发现注意力掩码在实际应用中很少使用,除了训练期间的因果掩码。因此,我们通过内置将因果掩码用作注意力掩码的选项来降低内核复杂性和计算成本,并使用与新 SDPA 算子一同引入的 is_causal 参数来选择此新功能。

为频繁使用的因果掩码提供 is_causal 布尔标志还避免了昂贵且内存密集型的因果掩码分配,通过允许更多内存用于大批次大小来提高训练内存效率,并通过无需加载注意力掩码张量来减少内存带宽和缓存争用——这在 GPU 加速器中都非常宝贵。

如果任何可用自定义内核的约束条件都不满足,则训练将回退到使用默认的 sdpa_math 内核,该内核使用一系列 PyTorch 算子实现缩放点积注意力的数学方程。这是最通用的“包罗万象”的回退内核,以确保所有模型都能成功训练。

除了现有的 Transformer API 外,模型开发者还可以通过调用新的 scaled_dot_product_attention() 算子直接使用缩放点积注意力内核。如SDPA 教程中所述,此算子可用于结合输入投影和输出投影高效地实现多头注意力。

除了添加自定义内核外,加速的 PyTorch 2 Transformers 还与 PyTorch 2.0 编译集成。为了在使用您的模型时受益于 PT2 编译的额外加速(用于推理或训练),请使用以下代码对模型进行预处理:

model = torch.compile(model)

我们通过结合使用自定义内核和 torch.compile(),使用加速的 PyTorch 2 Transformers 在训练 Transformer 模型特别是大型语言模型方面取得了显著的加速。

Better Transformer chart 图:使用带有自定义内核和 torch.compile 的缩放点积注意力,可以显著加速训练大型语言模型,例如此处所示的 nanoGPT

最后,由于自定义内核的内存效率大大提高,请尝试增加训练批次的大小,以通过更大的批次大小实现更快的训练。

除了自动内核选择外,上下文管理器还允许开发者覆盖内核选择算法——这对于日常操作不是必需的,但使开发者能够调试其代码,并使性能工程师能够覆盖内核选择。SDPA 教程提供了有关使用 SDPA 上下文管理器的更多信息。

除了作为 nn.Transformer API 的一部分可用外,加速的 PyTorch 2 Transformer 自定义内核在 PyTorch 2.0 发布时也与 torchtext、torchvision 和 fairseq 等领域库一同可用。