概要: Transformer 在自然语言处理(NLP)领域取得了最先进的性能,并正成为其他众多任务的热门选择。它们计算成本高昂,这阻碍了其广泛的生产化。PyTorch 1.12 发布了 BetterTransformer,它为 Transformer 编码器推理实现了一个向后兼容的 torch.nn.TransformerEncoder 快速路径,并且不需要模型作者修改他们的模型。BetterTransformer 的改进在许多常见的执行场景中可以将速度和吞吐量提高两倍以上。要使用 BetterTransformer,请安装 PyTorch 1.12,并立即开始使用 PyTorch API 的高质量、高性能 Transformer 模型。

Transformer 编码器架构图(来自“Attention Is All You Need”)。在推理过程中,整个模块将作为单个 PyTorch 原生函数执行。
在这篇博文中,我们将分享以下主题——性能改进、向后兼容性和利用快速路径。请在下面了解更多这些主题。
性能改进
BetterTransformer 推出了针对 CPU 和 GPU 的 MultiHeadAttention 和 TransformerEncoderLayer 的加速原生实现。这些快速路径已集成到标准的 PyTorch Transformer API 中,并将加速 TransformerEncoder、TransformerEncoderLayer 和 MultiHeadAttention nn.module。这些新模块实现了两种类型的优化:(1) 融合内核将通常用于实现 Transformer 的多个独立运算符组合在一起,以提供更高效的实现;(2) 利用输入中的稀疏性,避免对填充令牌执行不必要的操作。在许多用于自然语言处理的 Transformer 模型中,填充令牌经常占输入批次很大一部分。
向后兼容性
值得庆幸的是,无需修改模型即可受益于 BetterTransformer 带来的性能提升。要受益于快速路径执行,输入和操作条件必须满足一些访问条件(见下文)。虽然 Transformer API 的内部实现已更改,但 PyTorch 1.12 保持与以前版本中发布的 Transformer 模块的严格兼容性,使 PyTorch 用户可以使用使用以前 PyTorch 版本创建和训练的模型,同时受益于 BetterTransformer 的改进。
除了启用 PyTorch nn.Modules,BetterTransformer 还为 PyTorch 库提供了改进。性能优势将通过两种不同的启用路径实现:
- 透明加速: MultiHeadAttention 等 PyTorch nn.Modules 以及更高级别的 Transformer 组件的现有用户将自动受益于新 nn.Modules 的改进性能。一个例子是 torchvision 库中使用的视觉 Transformer (ViT) 实现(代码链接)。
- Torchtext 库加速:作为该项目的一部分,我们优化了 Torchtext,使其基于 PyTorch 核心 API 构建,以受益于 BetterTransformer 的增强,同时保持与以前的库版本以及使用以前 Torchtext 版本训练的模型严格和透明的兼容性。在 Torchtext 中使用 PyTorch Transformer 还确保 Torchtext 将受益于 PyTorch Transformer 实现未来预期的增强。
利用快速路径
BetterTransformer 是 PyTorch Transformer API 的快速路径。快速路径是用于 CPU 和 GPU 的关键 Transformer 函数的本机专用实现,适用于常见的 Transformer 用例。
为了利用输入稀疏性(即填充)来加速模型(参见图 2),在实例化 TransformerEncoder 时将关键字参数 enable_nested_tensor=True 设置为 true,并在推理期间传入 src_key_padding_mask 参数(表示填充令牌)。这要求填充掩码是连续的,这是典型情况。
目前,BetterTransformer 的加速仅适用于推理中使用的 Transformer 编码器模型。为了受益于快速路径执行,模型必须由以下任何组件组成:TransformerEncoder、TransformerEncoderLayer 或 MultiheadAttention (MHA)。快速路径执行还受某些条件限制。最重要的是,模型必须在推理模式下执行,并且在不收集梯度带信息(例如,使用 torch.no_grad 运行)的输入张量上操作。条件的完整列表可以在 nn.MultiHeadAttention 和 nn.TransformerEncoder 的这些链接中找到。如果不满足条件,控制将流向旧的 PyTorch 1.11 Transformer 实现,该实现具有相同的 API,但缺少快速路径性能提升。
其他使用 PyTorch MultiheadAttention 模块的 Transformer 模型(例如解码器模型)将受益于 BetterTransformer 快速路径。未来计划的工作是将端到端 BetterTransformer 快速路径扩展到基于 TransformerDecoder 的模型,以支持流行的 seq2seq 和仅解码器(例如,OPT)模型架构,以及用于训练。
加速
以下图表显示了 BERT-base 模型在小规模和大规模输入下实现的性能

图 1:PyTorch 1.12 与 BetterTransformer 快速路径执行的改进

图 2:PyTorch 1.12 与 BetterTransformer 快速路径执行的改进
通过 enable_nested_tensor=True 启用稀疏性优化
BetterTransformer 包含两种类型的优化:(1) 融合内核,在一个内核中更有效地实现多项操作;(2) 通过避免对填充令牌进行不必要的处理来利用稀疏性。小输入尺寸的性能增强主要受益于融合内核实现,并且无论填充量如何,都显示出持续的性能改进。虽然大输入仍然受益于融合内核,但计算密集型处理限制了融合内核可能获得的收益,因为基线性能已经接近理论峰值。然而,随着填充量的增加,性能显著提高,因为通过利用 NLP 工作负载中填充引入的稀疏性,可以避免越来越多的计算。
未来工作
作为我们正在进行的 PyTorch BetterTransformer 工作的一部分,我们正在努力将 BetterTransformer 的改进扩展到 Transformer 解码器。我们旨在将范围从推理扩展到训练。
我们正在合作在 FairSeq、MetaSeq 和 HuggingFace 等其他库上启用 BetterTransformer,以惠及所有基于 Transformer 的 PyTorch 模型。作为本博客系列的一部分,我们将提供有关 BetterTransformer 加速在更广泛的 PyTorch 生态系统中进展的未来更新。
致谢:作者衷心感谢 Lin Qiao、Ajit Mathews、Andrew Tulloch、Dmytro Dzhulgakov、Natalia Gimelshein、Emad El-Haraty、Mark Saroufim、Adnan Aziz、Geeta Chauhan 和 Hamid Shojanazeri 在本项目期间以及本博客的准备过程中给予的支持、贡献和许多有益的建议。