作为 PyTorch 2.0 发布的一部分,“Better Transformer”项目(在 PyTorch 中称为 Accelerated Transformers)中的注意力机制的加速实现已原生添加到 PyTorch 中,即 torch.nn.functional.scaled_dot_product_attention
。此实现利用了 FlashAttention 和 内存高效注意力 中的融合核,并支持训练和推理。
我们还发布了一个展示此集成示例的笔记本,请点击 此处。
在看到 扩散模型在推理时提速 20-30% 后,我们继续通过 🤗 Optimum 库 实现了与 🤗 Transformers 模型的集成。与 之前编码器模型的集成 类似,此集成将 Transformers 中的模块替换为使用 torch.nn.functional.scaled_dot_product_attention
的高效实现。用法如下:
from optimum.bettertransformer import BetterTransformer
from transformers import AutoModelForCausalLM
with torch.device(“cuda”):
model = AutoModelForCausalLM.from_pretrained(“gpt2-large”, torch_dtype=torch.float16)
model = BetterTransformer.transform(model)
# do your inference or training here
# if training and want to save the model
model = BetterTransformer.reverse(model)
model.save_pretrained(“fine_tuned_model”)
model.push_to_hub(“fine_tuned_model”)
下面总结了我们关于 torch.nn.functional.scaled_dot_product_attention
的发现:
- 它对于在给定硬件上训练更大的模型、序列长度或批处理大小最有用。
- 训练期间 GPU 上的内存占用可节省 20% 到 110%+。
- 训练期间提速 10% 到 70%。
- 推理期间提速 5% 到 20%。
- 对于小头维度,
scaled_dot_product_attention
可独立提速高达 3 倍,内存节省高达 40 倍(取决于序列长度)。
您可能会对内存节省和提速的巨大范围感到惊讶。在这篇博文中,我们将讨论我们的基准测试、此功能的亮点以及未来 PyTorch 版本中即将推出的改进。
在 Transformer 的下一个版本中,您只需安装合适的 Optimum 版本并运行:
model = model.to_bettertransformer()
使用 BetterTransformer API 转换您的模型。您可以通过从源代码安装 Transformer 来尝试此功能。
与 🤗 Transformers 的基准测试和用法
torch.nn.functional.scaled_dot_product_attention
适用于任何使用标准注意力的架构,并且专门替换了样板代码:
# native scaled_dot_product_attention is equivalent to the following:
def eager_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale):
scale_factor = 1 / math.sqrt(Q.size(-1)) if scale is None else scale
attn_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) if is_causal else attn_mask
attn_mask = attn_mask.masked_fill(not attn_mask, -float('inf')) if attn_mask.dtype==torch.bool else attn_mask
attn_weight = torch.softmax((Q @ K.transpose(-2, -1) * scale_factor) + attn_mask, dim=-1)
attn_weight = torch.dropout(attn_weight, dropout_p)
return attn_weight @ V
在 🤗 Optimum 与 Transformers 模型的集成中,目前支持以下架构:gpt2、gpt-neo、gpt-neox、gptj、t5、bart、codegen、pegasus、opt、LLaMA、blenderbot、m2m100。您可以期待此列表在不久的将来会扩展!
为了验证原生缩放点积注意力带来的好处,我们进行了推理和训练基准测试,结果如下所示。
在单个 A10G GPU、AWS g5.4xlarge 实例上的推理基准测试
在单个 A10G GPU、AWS g5.4xlarge 实例上的训练基准测试
在单个 A100-SXM4-80GB、Nvidia DGX 上的训练基准测试
从这个基准测试中,最有趣的发现是,原生 SDPA 允许使用更长的序列长度和更大的批处理大小,而不会出现内存不足问题。此外,推理期间可提速高达 20%,训练期间提速甚至更大。
正如训练基准测试所示,较小的头维度似乎带来更高的提速和内存节省,我们将在下一节中讨论这一点。
通过将 device_map=”auto”
传递给 from_pretrained
方法,此实现也支持多 GPU 设置,这得益于 🤗 Accelerate 库。以下是两个 A100-SXM4-80GB 上的训练结果。
在两个 A100-SXM4-80GB、Nvidia DGX 上使用 🤗 Accelerate 库进行分布式训练的训练基准测试
请注意,某些内核仅支持 sm_80 计算能力(即 A100 GPU 的计算能力),这限制了在各种硬件上的可用性,尤其是当头维度不是 2 的幂时。例如,截至 PyTorch 2.0.0 训练期间,opt-2.7b(headdim=80)和 gpt-neox-20b(headdim=96)无法调度到使用 flash attention 的内核,除非在 A100 GPU 上运行。未来可能会开发更好的内核:https://github.com/pytorch/pytorch/issues/98140#issuecomment-1518101895
Flash Attention、内存高效注意力与数学差异
原生 scaled_dot_product_attention
依赖于三种可能的后端实现:flash attention、内存高效注意力,以及所谓的数学实现,它为所有 PyTorch 平台提供硬件无关的回退。
当给定问题规模的融合内核可用时,将使用 flash-attention 或内存高效注意力,从而有效降低内存占用,因为在内存高效注意力情况下,GPU 全局内存中执行 O(N) 内存分配,而不是传统急切注意力实现中的经典 O(N^2)。使用 flash attention,内存访问(读写)次数预计会减少,因此两者都提供了提速和内存节省。
“数学”实现只是一个 使用 PyTorch C++ API 的实现。此实现中值得注意的是,查询和键张量为了数值稳定性而单独缩放,因此启动了两个 aten::div 操作,而不是可能在不包含此数值稳定性优化的急切实现中只启动一个。
头维度对提速和内存节省的影响
基准测试 torch.nn.functional.scaled_dot_product_attention
时,我们注意到随着头维度增加,提速/内存增益会降低。这对于某些架构(如 EleutherAI/gpt-neo-2.7B,其头维度相对较大,为 128;或 EleutherAI/gpt-j-6B(以及派生模型如 PygmalionAI/pygmalion-6b),其头维度为 256)来说是一个问题,这些架构实际上目前无法调度到融合内核,因为头维度太大。
这种趋势可以在下面的图中看到,其中 torch.nn.scaled_dot_production
与上述急切实现进行了独立基准测试。此外,我们使用 torch.backends.cuda.sdp_kernel
上下文管理器来强制分别使用数学、flash attention 和内存高效注意力实现。
使用内存高效注意力 SDP 内核(仅向前),A100
使用数学(无 dropout),A100
使用 flash attention SDP 内核(无 dropout),A100
使用内存高效注意力 SDP 内核(无 dropout),A100
我们看到,对于相同的问题规模,无论是仅推理还是训练,提速都会随着头维度的增加而降低,例如,使用 flash attention 内核时,headdim=8 为 3.4 倍,而 headdim=128 为 1.01 倍。
随着头维度增加,内存节省减少是预料之中的。回顾标准注意力计算:

由于中间计算,此标准分步计算中的全局内存占用为 2 * N * N + N * d。内存高效注意力建议迭代更新 softmax 归一化常数,并将其计算移动到最后,从而仅允许恒定的输出内存分配 N * d。
因此,内存节省比率为 2 * N / d + 1,它随着头维度的增加而减小。
在 Flash Attention 中,权衡在于头维度 d 和 GPU 流式多处理器共享内存大小 M 之间,总内存访问次数为 O(N² * d²/M)。因此,内存访问量与头维度呈平方关系,与标准注意力呈线性关系不同。原因是,在 Flash Attention 中,对于较大的头维度 d,键和值 K、V 需要分成更多块以适应共享内存,反过来,每个块都需要加载完整的查询 Q 和输出 O。
因此,Flash Attention 的最高提速是在 d² / M 比率足够小的区域。
PyTorch 2.0.0 的当前限制
缺少比例参数
截至 PyTorch 2.0.0,torch.nn.functional.scaled_dot_product_attention
没有比例参数,并使用隐藏大小的默认平方根 sqrt(d_k)。

然而,某些架构(如 OPT 或 T5)在注意力中不使用缩放,这在 PyTorch 2.0.0 中迫使其在调用 scaled_dot_product_attention
之前进行人工重新缩放。这引入了不必要的开销,因为除了注意力中不必要的除法之外,还需要额外的乘法。
此问题的修复已合并到 PyTorch 仓库 中。
支持带自定义掩码的 Flash Attention / 内存高效注意力
截至 PyTorch 2.0.0,当传递自定义注意力掩码时,无法使用 Flash Attention 和内存高效注意力。在这种情况下,scaled_dot_product_attention
会自动调度到 C++ 实现。
然而,正如我们所见,某些架构需要自定义注意力掩码,例如使用位置偏差的 T5。此外,在批处理大小大于 1 且某些输入可能被填充的情况下,也需要传递自定义注意力掩码。对于后一种情况,替代方法是使用 SDPA 支持的 NestedTensor。
因此,对自定义掩码的有限支持限制了 SDPA 在这些特定情况下的优势,尽管我们希望未来能有更广泛的支持 (此处)。
请注意,PyTorch 的 SDPA 部分受其启发,目前支持任意注意力掩码:https://github.com/facebookresearch/xformers/blob/658ebab39545f180a6075385b3897921623d6c3b/xformers/ops/fmha/cutlass.py#L147-L156。HazyResearch 的 Flash Attention 实现也支持等效的填充实现,因为使用了累积序列长度数组以及打包的查询/键/值——本质上类似于 NestedTensor。
结论
使用 torch.nn.functional.scaled_dot_product_attention
是一种免费的优化,它使您的代码更具可读性,占用更少的内存,并且在大多数常见情况下速度更快。
尽管 PyTorch 2.0.0 中的实现仍然存在一些小的限制,但在大多数情况下,推理和训练已经从 SDPA 中受益匪浅。我们鼓励您使用此原生实现来训练或部署您的 PyTorch 模型,并将其作为 🤗 Transformers 模型的一行转换!
将来,我们希望调整 API,以便用户也能在基于编码器的模型中使用 SDPA。
感谢 Benjamin Lefaudeux、Daniel Haziza 和 Francisco Massa 对头维度影响的建议,以及 Michael Gschwind、Christian Puhrsch 和 Driss Guessous 对博客文章的反馈!
基准复现
本文中展示的基准测试是使用 torch==2.0.0、transformers==4.27.4、accelerate==0.18.0 和 optimum==1.8.0 完成的。