作者:Michael Gschwind, Eric Han, Scott Wolchok, Rui Zhu, Christian Puhrsch

太长不看: Transformer 在自然语言处理 (NLP) 领域取得了最先进的性能,并且在许多其他任务中也越来越受欢迎。它们计算成本高昂,这阻碍了其大规模生产化应用。随 PyTorch 1.12 一同发布的 BetterTransformer,为 Transformer Encoder 推理实现了向后兼容的快速路径(Fast Path),该快速路径基于 torch.nn.TransformerEncoder,并且无需模型作者修改其模型。在许多常见执行场景中,BetterTransformer 的改进可以将速度和吞吐量提升两倍以上。要使用 BetterTransformer,请安装 PyTorch 1.12,立即开始使用 PyTorch API 来运行高质量、高性能的 Transformer 模型。

Transformer Encoder 架构图(摘自“Attention Is All You Need”)。在推理过程中,整个模块将作为一个单独的 PyTorch 原生函数执行。

在这篇博文中,我们将分享以下主题:性能改进、向后兼容性以及如何利用快速路径。您可以在下方了解更多这些主题的内容。

性能改进

BetterTransformer 提供了 MultiHeadAttention 和 TransformerEncoderLayer 在 CPU 和 GPU 上的加速原生实现。这些快速路径已集成到标准的 PyTorch Transformer API 中,并将加速 TransformerEncoderTransformerEncoderLayerMultiHeadAttention nn.modules。这些新模块实现了两种类型的优化:(1)融合内核(fused kernels)将通常用于实现 Transformer 的多个独立算子合并,提供更高效的实现;(2)利用输入中的稀疏性(sparsity),避免对填充标记(padding tokens)执行不必要的操作。在许多用于自然语言处理的 Transformer 模型中,填充标记经常占输入批次很大一部分。

向后兼容性

有利的一点是,无需对模型进行任何修改即可从 BetterTransformer 提供的性能提升中受益。要从快速路径执行中受益,输入和操作条件必须满足一些访问条件(见下文)。尽管 Transformer API 的内部实现已更改,但 PyTorch 1.12 严格兼容先前版本中提供的 Transformer 模块,这使得 PyTorch 用户可以使用在先前 PyTorch 版本中创建和训练的模型,同时也能从 BetterTransformer 的改进中受益。

除了启用 PyTorch nn.Modules 外,BetterTransformer 还为 PyTorch 库提供了改进。性能优势将通过两种不同的启用途径获得

  1. 透明加速:MultiHeadAttention 等 PyTorch nn.Modules 以及更高级别 Transformer 组件的现有用户将自动受益于新 nn.Modules 提升的性能。一个例子是 torchvision 库中使用的视觉 Transformer (ViT) 实现(代码链接)。

  2. Torchtext 库加速:作为本项目的一部分,我们优化了 Torchtext,使其基于 PyTorch 核心 API 构建,从而受益于 BetterTransformer 的增强功能,同时与先前的库版本以及使用先前 Torchtext 版本训练的模型保持严格和透明的兼容性。在 Torchtext 中使用 PyTorch Transformer 也确保了 Torchtext 将受益于 PyTorch Transformer 实现未来预期的增强功能。

利用快速路径

BetterTransformer 是 PyTorch Transformer API 的快速路径(Fast Path)。快速路径是针对 CPU 和 GPU 的关键 Transformer 函数的本地、专门实现,适用于常见的 Transformer 用例。

为了利用输入稀疏性(即填充)来加速您的模型(参见图2),在实例化 TransformerEncoder 时,设置关键字参数 enable_nested_tensor=True,并在推理期间传入 src_key_padding_mask 参数(表示填充标记)。这要求填充掩码是连续的,这是典型情况。

目前,BetterTransformer 的加速仅适用于用于推理的 Transformer Encoder 模型。要从快速路径执行中受益,模型必须由以下任一组件构成:TransformerEncoderTransformerEncoderLayerMultiheadAttention (MHA)。快速路径执行也受限于一些条件。最重要的是,模型必须在推理模式下执行,并且在不收集梯度信息(例如,使用 torch.no_grad 运行)的输入张量上操作。完整的条件列表可在以下链接中找到,分别对应 nn.MultiHeadAttentionnn.TransformerEncoder。如果不满足条件,控制流将回到传统的 PyTorch 1.11 Transformer 实现,该实现具有相同的 API,但缺乏快速路径的性能提升。

使用 PyTorch MultiheadAttention 模块的其他 Transformer 模型(例如 Decoder 模型)也将受益于 BetterTransformer 的快速路径。未来的计划工作是将端到端 BetterTransformer 快速路径扩展到基于 TransformerDecoder 的模型,以支持流行的 seq2seq 和仅 Decoder(例如 OPT)模型架构,并扩展到训练过程。

加速效果

以下图表展示了 BERT-base 模型在小型和大规模输入下的性能表现

图 1:PyTorch 1.12 通过 BetterTransformer 快速路径执行实现的改进

图 2:PyTorch 1.12 通过 BetterTransformer 快速路径执行实现的改进
启用 enable_nested_tensor=True 后实现的稀疏性优化

BetterTransformer 包含两种类型的优化:(1)融合内核,在一个内核中更有效地实现多个操作;(2)通过避免对填充标记执行不必要的处理来利用稀疏性。对于小型输入尺寸,性能提升主要受益于融合内核实现,并且无论填充量如何,性能改进都是恒定的。虽然大型输入仍然受益于融合内核,但计算密集型处理限制了融合内核可以带来的好处,因为基准性能已经接近理论峰值。然而,随着我们增加填充量,性能会显著提升,因为通过利用自然语言处理 (NLP) 工作负载中填充引入的稀疏性,可以避免越来越多的计算量。

未来工作

作为我们在 PyTorch BetterTransformer 上持续工作的一部分,我们正在努力将 BetterTransformer 的改进扩展到 Transformer Decoders。我们的目标是将加速范围从推理扩展到训练。

我们正在合作,以便在 FairSeq、MetaSeq 和 HuggingFace 等其他库上启用 BetterTransformer,从而使所有基于 Transformer 的 PyTorch 模型受益。作为本博客系列的一部分,我们将提供关于 BetterTransformer 加速在更广泛的 PyTorch 生态系统中的进展的未来更新。

致谢:作者谨感谢 Lin Qiao、Ajit Mathews、Andrew Tulloch、Dmytro Dzhulgakov、Natalia Gimelshein、Emad El-Haraty、Mark Saroufim、Adnan Aziz、Geeta Chauhan 和 Hamid Shojanazeri 在本项目进行过程中以及在准备这篇博客时提供的支持、贡献和许多有益的建议。