跳转到主要内容
博客

使用 PyTorch 2.0 开箱即用加速 🤗 解码器模型并节省内存

作者: 2023 年 5 月 22 日2024 年 11 月 14 日暂无评论

作为 PyTorch 2.0 发布的一部分,“Better Transformer”项目(在 PyTorch 中称为 Accelerated Transformers)中的注意力机制的加速实现已原生添加到 PyTorch 中,即 torch.nn.functional.scaled_dot_product_attention。此实现利用了 FlashAttention内存高效注意力 中的融合核,并支持训练和推理。

我们还发布了一个展示此集成示例的笔记本,请点击 此处

在看到 扩散模型在推理时提速 20-30% 后,我们继续通过 🤗 Optimum 库 实现了与 🤗 Transformers 模型的集成。与 之前编码器模型的集成 类似,此集成将 Transformers 中的模块替换为使用 torch.nn.functional.scaled_dot_product_attention 的高效实现。用法如下:

from optimum.bettertransformer import BetterTransformer
from transformers import AutoModelForCausalLM

with torch.device(“cuda”):
model = AutoModelForCausalLM.from_pretrained(“gpt2-large”, torch_dtype=torch.float16)

model = BetterTransformer.transform(model)

# do your inference or training here

# if training and want to save the model
model = BetterTransformer.reverse(model)
model.save_pretrained(“fine_tuned_model”)
model.push_to_hub(“fine_tuned_model”) 

下面总结了我们关于 torch.nn.functional.scaled_dot_product_attention 的发现:

  • 它对于在给定硬件上训练更大的模型、序列长度或批处理大小最有用。
  • 训练期间 GPU 上的内存占用可节省 20% 到 110%+。
  • 训练期间提速 10% 到 70%。
  • 推理期间提速 5% 到 20%。
  • 对于小头维度,scaled_dot_product_attention 可独立提速高达 3 倍,内存节省高达 40 倍(取决于序列长度)。

您可能会对内存节省和提速的巨大范围感到惊讶。在这篇博文中,我们将讨论我们的基准测试、此功能的亮点以及未来 PyTorch 版本中即将推出的改进。

在 Transformer 的下一个版本中,您只需安装合适的 Optimum 版本并运行:

model = model.to_bettertransformer()

使用 BetterTransformer API 转换您的模型。您可以通过从源代码安装 Transformer 来尝试此功能。

与 🤗 Transformers 的基准测试和用法

torch.nn.functional.scaled_dot_product_attention 适用于任何使用标准注意力的架构,并且专门替换了样板代码:

# native scaled_dot_product_attention is equivalent to the following:
def eager_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale):
	scale_factor = 1 / math.sqrt(Q.size(-1)) if scale is None else scale
	attn_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) if is_causal else attn_mask
	attn_mask = attn_mask.masked_fill(not attn_mask, -float('inf')) if attn_mask.dtype==torch.bool else attn_mask
	attn_weight = torch.softmax((Q @ K.transpose(-2, -1) * scale_factor) + attn_mask, dim=-1)
	attn_weight = torch.dropout(attn_weight, dropout_p)
	return attn_weight @ V

在 🤗 Optimum 与 Transformers 模型的集成中,目前支持以下架构:gpt2、gpt-neo、gpt-neox、gptj、t5、bart、codegen、pegasus、opt、LLaMA、blenderbot、m2m100。您可以期待此列表在不久的将来会扩展!

为了验证原生缩放点积注意力带来的好处,我们进行了推理和训练基准测试,结果如下所示。

在单个 A10G GPU、AWS g5.4xlarge 实例上的推理基准测试 在单个 A10G GPU、AWS g5.4xlarge 实例上的推理基准测试

在单个 A10G GPU、AWS g5.4xlarge 实例上的训练基准测试 在单个 A10G GPU、AWS g5.4xlarge 实例上的训练基准测试

在单个 A100-SXM4-80GB、Nvidia DGX 上的训练基准测试 在单个 A100-SXM4-80GB、Nvidia DGX 上的训练基准测试

从这个基准测试中,最有趣的发现是,原生 SDPA 允许使用更长的序列长度和更大的批处理大小,而不会出现内存不足问题。此外,推理期间可提速高达 20%,训练期间提速甚至更大。

正如训练基准测试所示,较小的头维度似乎带来更高的提速和内存节省,我们将在下一节中讨论这一点。

通过将 device_map=”auto” 传递给 from_pretrained 方法,此实现也支持多 GPU 设置,这得益于 🤗 Accelerate 库。以下是两个 A100-SXM4-80GB 上的训练结果。

在两个 A100-SXM4-80GB、Nvidia DGX 上使用 🤗 Accelerate 库进行分布式训练的训练基准测试 在两个 A100-SXM4-80GB、Nvidia DGX 上使用 🤗 Accelerate 库进行分布式训练的训练基准测试

请注意,某些内核仅支持 sm_80 计算能力(即 A100 GPU 的计算能力),这限制了在各种硬件上的可用性,尤其是当头维度不是 2 的幂时。例如,截至 PyTorch 2.0.0 训练期间,opt-2.7b(headdim=80)和 gpt-neox-20b(headdim=96)无法调度到使用 flash attention 的内核,除非在 A100 GPU 上运行。未来可能会开发更好的内核:https://github.com/pytorch/pytorch/issues/98140#issuecomment-1518101895

Flash Attention、内存高效注意力与数学差异

原生 scaled_dot_product_attention 依赖于三种可能的后端实现:flash attention、内存高效注意力,以及所谓的数学实现,它为所有 PyTorch 平台提供硬件无关的回退。

当给定问题规模的融合内核可用时,将使用 flash-attention 或内存高效注意力,从而有效降低内存占用,因为在内存高效注意力情况下,GPU 全局内存中执行 O(N) 内存分配,而不是传统急切注意力实现中的经典 O(N^2)。使用 flash attention,内存访问(读写)次数预计会减少,因此两者都提供了提速和内存节省。

“数学”实现只是一个 使用 PyTorch C++ API 的实现。此实现中值得注意的是,查询和键张量为了数值稳定性而单独缩放,因此启动了两个 aten::div 操作,而不是可能在不包含此数值稳定性优化的急切实现中只启动一个。

头维度对提速和内存节省的影响

基准测试 torch.nn.functional.scaled_dot_product_attention 时,我们注意到随着头维度增加,提速/内存增益会降低。这对于某些架构(如 EleutherAI/gpt-neo-2.7B,其头维度相对较大,为 128;或 EleutherAI/gpt-j-6B(以及派生模型如 PygmalionAI/pygmalion-6b),其头维度为 256)来说是一个问题,这些架构实际上目前无法调度到融合内核,因为头维度太大。

这种趋势可以在下面的图中看到,其中 torch.nn.scaled_dot_production 与上述急切实现进行了独立基准测试。此外,我们使用 torch.backends.cuda.sdp_kernel 上下文管理器来强制分别使用数学、flash attention 和内存高效注意力实现。

使用内存高效注意力 SDP 内核(仅向前),A100 使用内存高效注意力 SDP 内核(仅向前),A100

使用数学(无 dropout),A100 使用数学(无 dropout),A100

使用 flash attention SDP 内核(无 dropout),A100 使用 flash attention SDP 内核(无 dropout),A100

使用内存高效注意力 SDP 内核(无 dropout),A100 使用内存高效注意力 SDP 内核(无 dropout),A100

我们看到,对于相同的问题规模,无论是仅推理还是训练,提速都会随着头维度的增加而降低,例如,使用 flash attention 内核时,headdim=8 为 3.4 倍,而 headdim=128 为 1.01 倍。

随着头维度增加,内存节省减少是预料之中的。回顾标准注意力计算:

Math equation

由于中间计算,此标准分步计算中的全局内存占用为 2 * N * N + N * d。内存高效注意力建议迭代更新 softmax 归一化常数,并将其计算移动到最后,从而仅允许恒定的输出内存分配 N * d。

因此,内存节省比率为 2 * N / d + 1,它随着头维度的增加而减小。

在 Flash Attention 中,权衡在于头维度 d 和 GPU 流式多处理器共享内存大小 M 之间,总内存访问次数为 O(N² * d²/M)。因此,内存访问量与头维度呈平方关系,与标准注意力呈线性关系不同。原因是,在 Flash Attention 中,对于较大的头维度 d,键和值 K、V 需要分成更多块以适应共享内存,反过来,每个块都需要加载完整的查询 Q 和输出 O。

因此,Flash Attention 的最高提速是在 d² / M 比率足够小的区域。

PyTorch 2.0.0 的当前限制

缺少比例参数

截至 PyTorch 2.0.0,torch.nn.functional.scaled_dot_product_attention 没有比例参数,并使用隐藏大小的默认平方根 sqrt(d_k)。

Math equation

然而,某些架构(如 OPT 或 T5)在注意力中不使用缩放,这在 PyTorch 2.0.0 中迫使其在调用 scaled_dot_product_attention 之前进行人工重新缩放。这引入了不必要的开销,因为除了注意力中不必要的除法之外,还需要额外的乘法。

此问题的修复已合并到 PyTorch 仓库 中。

支持带自定义掩码的 Flash Attention / 内存高效注意力

截至 PyTorch 2.0.0,当传递自定义注意力掩码时,无法使用 Flash Attention 和内存高效注意力。在这种情况下,scaled_dot_product_attention 会自动调度到 C++ 实现。

然而,正如我们所见,某些架构需要自定义注意力掩码,例如使用位置偏差的 T5。此外,在批处理大小大于 1 且某些输入可能被填充的情况下,也需要传递自定义注意力掩码。对于后一种情况,替代方法是使用 SDPA 支持的 NestedTensor

因此,对自定义掩码的有限支持限制了 SDPA 在这些特定情况下的优势,尽管我们希望未来能有更广泛的支持 (此处)

请注意,PyTorch 的 SDPA 部分受其启发,目前支持任意注意力掩码:https://github.com/facebookresearch/xformers/blob/658ebab39545f180a6075385b3897921623d6c3b/xformers/ops/fmha/cutlass.py#L147-L156。HazyResearch 的 Flash Attention 实现也支持等效的填充实现,因为使用了累积序列长度数组以及打包的查询/键/值——本质上类似于 NestedTensor。

结论

使用 torch.nn.functional.scaled_dot_product_attention 是一种免费的优化,它使您的代码更具可读性,占用更少的内存,并且在大多数常见情况下速度更快。

尽管 PyTorch 2.0.0 中的实现仍然存在一些小的限制,但在大多数情况下,推理和训练已经从 SDPA 中受益匪浅。我们鼓励您使用此原生实现来训练或部署您的 PyTorch 模型,并将其作为 🤗 Transformers 模型的一行转换!

将来,我们希望调整 API,以便用户也能在基于编码器的模型中使用 SDPA。

感谢 Benjamin Lefaudeux、Daniel Haziza 和 Francisco Massa 对头维度影响的建议,以及 Michael Gschwind、Christian Puhrsch 和 Driss Guessous 对博客文章的反馈!

基准复现

本文中展示的基准测试是使用 torch==2.0.0、transformers==4.27.4、accelerate==0.18.0 和 optimum==1.8.0 完成的。

可以使用 🤗 Transformers 模型的 推理训练 脚本以及 独立 SDPA 轻松复现基准测试。