作者:Felix Marty, Younes Belkada, Hamid Shojanazeri, Driss Guessous

作为 PyTorch 2.0 版本的一部分,“Better Transformer”项目(在 PyTorch 中也称为 Accelerated Transformers)中的注意力机制的加速实现已原生添加到 PyTorch 中,作为 torch.nn.functional.scaled_dot_product_attention。此实现利用了来自 FlashAttentionMemory-efficient attention 的融合内核,并支持训练和推理。

我们还发布了一个 notebook,展示了此集成的示例 此处

在看到 扩散模型在推理时有 20-30% 的加速 后,我们着手实现了与 🤗 Transformers 模型的集成,通过 🤗 Optimum 库。类似于 之前对编码器模型的集成,该集成将 Transformers 中的模块替换为使用 torch.nn.functional.scaled_dot_product_attention 的高效实现。用法如下

from optimum.bettertransformer import BetterTransformer
from transformers import AutoModelForCausalLM

with torch.device(“cuda”):
model = AutoModelForCausalLM.from_pretrained(“gpt2-large”, torch_dtype=torch.float16)

model = BetterTransformer.transform(model)

# do your inference or training here

# if training and want to save the model
model = BetterTransformer.reverse(model)
model.save_pretrained(“fine_tuned_model”)
model.push_to_hub(“fine_tuned_model”) 

下面总结了我们关于 torch.nn.functional.scaled_dot_product_attention 的发现

  • 它在给定硬件上训练大型模型、长序列或大批量时最有用。
  • GPU 在训练期间的内存占用节省从 20% 到 110%+ 不等。
  • 训练期间的加速从 10% 到 70% 不等。
  • 推理期间的加速从 5% 到 20% 不等。
  • 单独来看,对于较小的头维度,scaled_dot_product_attention 加速高达 3 倍,内存节省高达 40 倍(取决于序列长度)。

内存节省和加速的范围如此之广可能会让您感到惊讶。在这篇博客文章中,我们将讨论我们的基准测试,这项功能的亮点以及 PyTorch 未来版本中的改进。

在 Transformers 的下一个版本中,您只需安装适当版本的 optimum 并运行

model = model.to_bettertransformer()

使用 BetterTransformer API 转换您的模型。您现在可以通过从源码安装 transformers 来试用此功能。

基准测试以及与 🤗 Transformers 的用法

torch.nn.functional.scaled_dot_product_attention 适用于任何使用标准注意力机制的架构,特别是取代了样板代码

# native scaled_dot_product_attention is equivalent to the following:
def eager_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale):
	scale_factor = 1 / math.sqrt(Q.size(-1)) if scale is None else scale
	attn_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) if is_causal else attn_mask
	attn_mask = attn_mask.masked_fill(not attn_mask, -float('inf')) if attn_mask.dtype==torch.bool else attn_mask
	attn_weight = torch.softmax((Q @ K.transpose(-2, -1) * scale_factor) + attn_mask, dim=-1)
	attn_weight = torch.dropout(attn_weight, dropout_p)
	return attn_weight @ V

在 🤗 Optimum 与 Transformers 模型的集成中,目前支持以下架构:gpt2, gpt-neo, gpt-neox, gptj, t5, bart, codegen, pegasus, opt, LLaMA, blenderbot, m2m100。预计此列表将在不久的将来扩展!

为了验证原生 scaled dot-product attention 的优势,我们运行了推理和训练基准测试,结果如下所示。

在单个 A10G GPU、AWS g5.4xlarge 实例上的推理基准测试 在单个 A10G GPU、AWS g5.4xlarge 实例上的推理基准测试

在单个 A10G GPU、AWS g5.4xlarge 实例上的训练基准测试 在单个 A10G GPU、AWS g5.4xlarge 实例上的训练基准测试

在单个 A100-SXM4-80GB、Nvidia DGX 上的训练基准测试 在单个 A100-SXM4-80GB、Nvidia DGX 上的训练基准测试

从这个基准测试中,最有趣的发现是,原生 SDPA 允许使用更长的序列长度和更大的批量,而不会出现内存不足问题。此外,推理期间可以看到高达 20% 的加速,训练期间甚至更大。

从训练基准测试中可以看出,头维度越小,带来的加速和内存节省越多,我们将在下一节讨论这一点。

该实现也支持多 GPU 设置,感谢 🤗 Accelerate 库,通过将 device_map=”auto” 传递给 from_pretrained 方法。以下是在两块 A100-SXM4-80GB 上的训练结果。

在两块 A100-SXM4-80GB、Nvidia DGX 上使用 🤗 Accelerate 库进行分布式训练的基准测试 在两块 A100-SXM4-80GB、Nvidia DGX 上使用 🤗 Accelerate 库进行分布式训练的基准测试

请注意,某些内核仅支持 sm_80 计算能力(即 A100 GPU 的计算能力),这限制了在广泛硬件上的可用性,特别是当头维度不是 2 的幂时。例如,截至 PyTorch 2.0.0 训练期间,opt-2.7b (headim=80) 和 gpt-neox-20b (headdim=96) 无法调度到使用 flash attention 的内核,除非在 A100 GPU 上运行。未来可能会开发更好的内核:https://github.com/pytorch/pytorch/issues/98140#issuecomment-1518101895

Flash Attention、Memory-efficient attention & 数学实现差异

原生 scaled_dot_product_attention 依赖于三种可能的后端实现:flash attention、memory-efficient attention 和所谓的数学实现,它为所有 PyTorch 平台提供硬件无关的备用方案。

当对于给定问题大小有融合内核可用时,将使用 flash-attention 或 memory-efficient attention,从而有效降低内存占用,因为在 memory-efficient attention 的情况下,内存分配是在 GPU 全局内存上以 O(N) 进行的,而不是传统 eager attention 实现的经典 O(N^2)。使用 flash attention,预计会减少内存访问次数(读写),因此既提供了加速又节省了内存。

“数学”实现仅仅是一个使用 PyTorch C++ API 的 实现。在此实现中值得注意的是,query 和 key 张量为了数值稳定性而单独进行缩放,因此会启动两个 aten::div 操作,而不是在没有此数值稳定性优化的 eager 实现中可能只有一个。

头维度对加速、内存节省的影响

torch.nn.functional.scaled_dot_product_attention 进行基准测试时,我们注意到随着头维度增加,加速/内存节省效益会下降。这对某些架构来说是个问题,例如 EleutherAI/gpt-neo-2.7B,其头维度相对较大,为 128,或者 EleutherAI/gpt-j-6B(以及衍生的模型如 PygmalionAI/pygmalion-6b),其头维度为 256(实际上目前由于头维度过大,无法调度到融合内核)。

此趋势可以在下面的图中看到,其中对 torch.nn.scaled_dot_production 进行了独立基准测试,与上述 eager 实现进行对比。此外,我们使用 torch.backends.cuda.sdp_kernel 上下文管理器来分别强制使用 math、flash attention 和 memory-efficient attention 实现。

使用 memory-efficient attention SDP 内核(仅前向),A100 使用 memory-efficient attention SDP 内核(仅前向),A100

使用 math 实现(无 dropout),A100 使用 math 实现(无 dropout),A100

使用 flash attention SDP 内核(无 dropout),A100 使用 flash attention SDP 内核(无 dropout),A100

使用 memory-efficient attention SDP 内核(无 dropout),A100 使用 memory-efficient attention SDP 内核(无 dropout),A100

我们看到对于相同的问题大小,无论是仅用于推理还是训练,随着头维度增加,加速效果降低,例如使用 flash attention 内核时,headdim=8 时加速 3.4 倍,而 headdim=128 时加速 1.01 倍。

头维度越大,内存节省越少是可以预期的。回顾标准注意力计算

Math equation

在这种标准的逐步计算中,由于中间计算,全局内存占用为 2 * N * N + N * d。Memory-efficient attention 提出迭代更新 softmax 重归一化常数,并将其计算移至最后,从而只允许常数输出内存分配 N * d。

因此,内存节省比例为 2 * N / d + 1,随着头维度增加而降低。

在 flash attention 中,权衡在于头维度 d 与 GPU 流式多处理器共享内存大小 M 之间,总内存访问次数为 O(N² * d²/M)。因此,内存访问次数在头维度上呈平方增长,与呈线性增长的标准注意力机制相反。原因是,在 flash attention 中,对于较大的头维度 d,key 和 value K, V 需要分成更多块才能放入共享内存,反过来,每个块都需要加载完整的 query Q 和 output O。

因此,flash attention 的最大加速出现在比率 d² / M 足够小的区域。

PyTorch 2.0.0 的当前限制

缺少 scale 参数

截至 PyTorch 2.0.0 版本,torch.nn.functional.scaled_dot_product_attention 没有 scale 参数,并使用默认的隐藏大小的平方根 sqrt(d_k)。

Math equation

然而,OPT 或 T5 等一些架构在注意力机制中不使用缩放,这在 PyTorch 2.0.0 版本中强制在调用 scaled_dot_product_attention 之前进行人工缩放。这引入了不必要的开销,因为需要额外的乘法,再加上注意力机制中不需要的除法。

此问题的修复已合并到 PyTorch 仓库中。

flash attention / memory-efficient attention 支持自定义掩码

截至 PyTorch 2.0.0 版本,当传递自定义注意力掩码时,flash attention 和 memory-efficient attention 无法使用。在这种情况下,scaled_dot_product_attention 自动调度到 C++ 实现。

然而,正如我们所见,一些架构需要自定义注意力掩码,例如使用位置偏差的 T5。此外,在批量大于一且部分输入可能被填充的情况下,也需要传递自定义注意力掩码。对于后一种情况,另一种方法是使用 NestedTensor,SDPA 支持它。

这种对自定义掩码的有限支持因此限制了 SDPA 在这些特定情况下的优势,尽管我们可以期待未来 扩展支持

请注意,PyTorch 的 SDPA 部分借鉴了 xformers,它目前支持任意注意力掩码:https://github.com/facebookresearch/xformers/blob/658ebab39545f180a6075385b3897921623d6c3b/xformers/ops/fmha/cutlass.py#L147-L156。HazyResearch 的 flash attention 实现也支持等效的填充实现,因为使用累积序列长度数组以及 packed query/key/values - 本质上类似于 NestedTensor。

结论

使用 torch.nn.functional.scaled_dot_product_attention 是一种免费的优化,既使您的代码更易读,使用更少的内存,并且在大多数情况下更快。

尽管 PyTorch 2.0.0 中的实现仍有一些小限制,但在大多数情况下,推理和训练已经从 SDPA 中大量受益。我们鼓励您使用此原生实现来训练或部署您的 PyTorch 模型,对于 🤗 Transformers 模型,它只需一行代码转换即可!

未来,我们希望调整 API,以便用户也能在基于编码器的模型中使用 SDPA。

我们感谢 Benjamin Lefaudeux、Daniel Haziza 和 Francisco Massa 在头维度影响方面提供的建议,以及 Michael Gschwind、Christian Puhrsch 和 Driss Guessous 对博客文章提供的反馈!

基准测试复现

本文提出的基准测试是使用 torch==2.0.0, transformers==4.27.4, accelerate==0.18.0 和 optimum==1.8.0 完成的。

可以使用 推理训练 的脚本轻松复现基准测试,适用于 🤗 Transformers 模型,以及 独立的 SDPA