跳转到主要内容
博客

一个更好的 Transformer,用于快速 Transformer 推理

提要: Transformer 在自然语言处理(NLP)领域取得了最先进的性能,并且在许多其他任务中也越来越受欢迎。但它们计算成本高昂,这阻碍了其广泛的生产化。PyTorch 1.12 发布了 BetterTransformer,它为 Transformer 编码器推理实现了一个向后兼容的 torch.nn.TransformerEncoder 快速路径,并且不需要模型作者修改其模型。BetterTransformer 的改进在许多常见的执行场景中可以将速度和吞吐量提高 2 倍以上。要使用 BetterTransformer,请安装 PyTorch 1.12,并立即开始使用 PyTorch API 构建高质量、高性能的 Transformer 模型。

Transformer 编码器架构图(来自“Attention Is All You Need”)。在推理过程中,整个模块将作为一个 PyTorch 原生函数执行。

在这篇博客文章中,我们将分享以下主题——性能改进、向后兼容性以及利用快速路径。请在下方了解有关这些主题的更多信息。

性能改进

BetterTransformer 推出了针对 CPU 和 GPU 的 MultiHeadAttention 和 TransformerEncoderLayer 的加速原生实现。这些快速路径已集成到标准的 PyTorch Transformer API 中,并将加速 TransformerEncoderTransformerEncoderLayerMultiHeadAttention nn.modules。这些新模块实现了两种类型的优化:(1) 融合内核将通常用于实现 Transformer 的多个独立操作组合起来,以提供更高效的实现;(2) 利用输入中的稀疏性,避免对填充标记执行不必要的操作。在许多用于自然语言处理的 Transformer 模型中,填充标记经常占输入批次的大部分。

向后兼容性

有利的是,无需修改模型即可受益于 BetterTransformer 提供的性能提升。 为了受益于快速路径执行,输入和操作条件必须满足一些访问条件(见下文)。虽然 Transformer API 的内部实现已更改,但 PyTorch 1.12 严格保持与先前版本中提供的 Transformer 模块的兼容性,使 PyTorch 用户能够使用使用先前 PyTorch 版本创建和训练的模型,同时受益于 BetterTransformer 的改进。

除了启用 PyTorch nn.Modules 之外,BetterTransformer 还为 PyTorch 库提供了改进。性能优势将通过两种不同的启用路径实现:

  1. 透明加速: PyTorch nn.Modules(例如 MultiHeadAttention)以及更高级别的 Transformer 组件的当前用户将自动受益于新 nn.Modules 的改进性能。一个例子是 torchvision 库中使用的 视觉 Transformer (ViT) 实现(代码链接)。
  2. Torchtext 库加速: 作为该项目的一部分,我们优化了 Torchtext,使其基于 PyTorch 核心 API 构建,以受益于 BetterTransformer 增强功能,同时保持与先前库版本和使用先前 Torchtext 版本训练的模型的严格透明兼容性。在 Torchtext 中使用 PyTorch Transformer 还确保 Torchtext 将受益于 PyTorch Transformer 实现的预期未来增强功能。

利用快速路径

BetterTransformer 是 PyTorch Transformer API 的快速路径。快速路径是针对 CPU 和 GPU 的关键 Transformer 函数的本机专用实现,适用于常见的 Transformer 用例。

为了利用输入稀疏性(即填充)来加速您的模型(参见图 2),请在实例化 TransformerEncoder 时设置关键字参数 enable_nested_tensor=True,并在推理期间传入 src_key_padding_mask 参数(表示填充标记)。这要求填充掩码是连续的,这是典型情况。

目前,BetterTransformer 的加速仅适用于推理中使用的 Transformer 编码器模型。为了受益于快速路径执行,模型必须由以下任何组件组成:TransformerEncoderTransformerEncoderLayerMultiheadAttention (MHA)。快速路径执行还受某些标准限制。最重要的是,模型必须在推理模式下执行,并对不收集梯度带信息的输入张量进行操作(例如,使用 torch.no_grad 运行)。条件完整列表可在以下链接找到,分别用于 nn.MultiHeadAttentionnn.TransformerEncoder。如果未满足条件,则控制流将转到旧版 PyTorch 1.11 Transformer 实现,该实现具有相同的 API,但缺少快速路径性能提升。

使用 PyTorch MultiheadAttention 模块的其他 Transformer 模型(例如解码器模型)将受益于 BetterTransformer 快速路径。计划中的未来工作是将端到端 BetterTransformer 快速路径扩展到基于 TransformerDecoder 的模型,以支持流行的 seq2seq 和仅解码器(例如,OPT)模型架构,并扩展到训练。

加速

以下图表显示了 BERT-base 模型在小规模和大规模输入下实现的性能。

图 1:PyTorch 1.12 通过 BetterTransformer 快速路径执行实现的改进

图 2:PyTorch 1.12 通过 BetterTransformer 快速路径执行实现的改进
在启用 enable_nested_tensor=True 的情况下进行稀疏性优化

BetterTransformer 包括两种优化类型:(1) 融合内核,在一个内核中更有效地实现多个操作;(2) 通过避免对填充标记进行不必要的处理来利用稀疏性。对于小输入大小,性能提升主要受益于融合内核实现,并且无论填充量如何,都显示出恒定的性能改进。虽然大输入仍然受益于融合内核,但计算密集型处理限制了融合内核可能获得的收益,因为基线性能已经接近理论峰值。然而,随着填充量的增加,性能显著提高,因为通过利用 NLP 工作负载中填充引入的稀疏性,可以避免越来越多的计算。

未来工作

作为 PyTorch BetterTransformer 持续工作的一部分,我们正在努力将 BetterTransformer 的改进扩展到 Transformer 解码器。我们还旨在将范围从推理扩展到训练。

我们正在与 FairSeq、MetaSeq 和 HuggingFace 等其他库合作,以在这些库中启用 BetterTransformer,使所有基于 Transformer 的 PyTorch 模型受益。作为本博客系列的一部分,我们将提供 BetterTransformer 加速在更广泛的 PyTorch 生态系统中的进展的未来更新。

致谢:作者衷心感谢 Lin Qiao、Ajit Mathews、Andrew Tulloch、Dmytro Dzhulgakov、Natalia Gimelshein、Emad El-Haraty、Mark Saroufim、Adnan Aziz、Geeta Chauhan 和 Hamid Shojanazeri 在本项目整个过程中以及在本文撰写过程中提供的支持、贡献和许多有益的建议。