TLDR: 我们展示了在 A100 GPU 上对浮点32 Vision Transformer 的 MLP 模块权重应用块稀疏性后,获得了有前景的结果,速度提升高达 1.46 倍,准确率下降不到 2%。这种方法有望应用于其他类型的 Transformer,包括大型语言模型。我们的实现和基准测试可在 https://github.com/pytorch-labs/superblock 获取,以重现我们的结果。
引言
PyTorch 在实现块稀疏矩阵乘法的 CUDA 内核方面进行了许多改进。PyTorch 的最新更新可以在高稀疏度的大型矩阵乘法形状上实现相对于密集基线高达 4.8 倍的加速。
在这篇博客中,我们展示了在 Vision Transformer (ViT) 的 MLP(多层感知器)层的线性层权重上应用块稀疏性的有前景的结果,并展示了在 A100 英伟达 GPU 上的端到端模型加速。
回顾一下,块稀疏性是以预定大小的块瓦片而不是单个元素的方式稀疏化权重。这种特殊的稀疏模式很有趣,因为它可以通过快速稀疏内核进行 GPU 加速。有关不同稀疏模式之间差异或稀疏性整体的更多信息,请查阅 torchao。

不同类型稀疏性的插图。
方法
我们的方法可以分为两个不同的步骤:
- 使用块稀疏掩码子网从头开始训练模型。
- 将这些掩码折叠到我们的权重中以加速推理。
我们将在下面解释我们的训练和推理步骤。
训练
从一个未初始化的 Vision Transformer 开始,我们在注意力块的输出投影线性层、MLP(即前馈网络,FFN)内部的两个线性层以及最终的线性分类层的权重上应用具有指定块大小和稀疏度的随机可训练掩码。训练期间的前向传播遵循 Supermask 方法,因为每个掩码都使用根据稀疏性要求调整的阈值转换为二进制映射,例如,如果我们想要 80% 的稀疏度,阈值将自动调整以保留前 20% 的权重。掩码是正方形的 <块大小>x<块大小> 元素,其中 <块大小> 是一个超参数。权重的优先级取决于所训练的掩码值或分数。我们 将每层的二进制掩码与权重相乘 以稀疏化模型。

Supermask 稀疏化方法 的插图。
推理
训练完成后,密集权重可以通过与掩码相乘转换为稀疏权重 并存储以进行推理。在此阶段,尽管权重中包含大量零值,但它们仍以密集格式存储。我们使用 PyTorch 的 to_sparse_bsr() API 将权重转换为 块稀疏表示 (BSR) 格式,该格式仅存储非零值及其块的索引。此步骤只需执行一次,结果即可缓存以供运行时使用。
在运行时,无需更改代码。我们只需将任何输入张量传递给模型,当稀疏线性层的 `forward()` 函数被调用时,PyTorch 会自动调用针对块稀疏权重优化的矩阵乘法。这应该适用于 A100 和 H100 NVIDIA GPU。
结果:微基准测试
为了从性能角度验证块稀疏性的可行性,我们首先使用这个 简单脚本 运行了一系列微基准测试。我们使用 ViT-b 的线性形状,比较了我们的块稀疏内核在单个线性层上随权重矩阵稀疏度级别和块大小变化时的加速效果。
我们使用 PyTorch 2.3.0.dev20240305+cu121 nightly 版本在 NVIDIA A100 上运行,并报告了每种稀疏配置相对于密集基线的加速比。我们观察到,对于 float32,当块大小 >=32 或稀疏度 >= 0.8 时,有正向加速;而对于 bfloat16,我们观察到较小的加速,并且通常在块大小为 64 和更高稀疏度时。因此,为了获得模型的端到端加速,本博客将重点关注 float32,并将 bfloat16 留作未来工作。


ViT-b-16 线性层的微基准测试结果。
结果:视觉Transformer
一旦我们确认能够在线性层上实现加速,我们便专注于在 ViT_B_16 上实现端到端加速。
我们使用标准的 ViT_B_16 配方,从头开始在 ImageNet 数据集上训练这个模型。我们展示了 MLP 模块稀疏化的加速,并将注意力输入和输出投影的权重稀疏化留作未来工作。
我们研究了实际推断速度的提升,重点关注批量大小为 256 的情况。我们发现:
- 对于 90% 的稀疏度,块大小分别为 16、32 和 64 时,我们可以获得 1.24 倍、1.37 倍和 1.65 倍的加速。
- 为了获得加速,块大小为 16、32 和 64 的最小稀疏度分别为 0.86、0.82 和 0.7。因此,正如预期的那样,块大小越大,获得加速所需的稀疏度越小。
我们注意到 `sparse_bsr()` API 的一个限制:层维度需要是块大小的倍数。由于 ViT 中最后一个 FC 分类层的维度不是块大小的倍数,因此在我们的实验中它们没有转换为 BSR 表示。

ViT-b-16 在不同批量稀疏度和块大小下,MLP 模块的批量大小为 256 时的加速效果。
我们还探讨了 90% 稀疏度下不同批量大小的加速效果。我们观察到,从批量大小 16 开始,相对于基线都有加速。虽然在最大批量大小下,更大的块大小具有更大的加速效果,但对于更小的块大小,获得 >1 加速的最小批量大小更小。
我们相信设备硬件可以在批量大小为 1 时获得加速,因为它们——与服务器 GPU 不同——可以在如此小的批量大小下得到充分利用。

ViT-b-16 在 90% 稀疏度下,MLP 模块在不同批次大小和块大小下的加速。
查看不同块大小和稀疏度下稀疏化模型在 ImageNet 模糊测试集上的 Top-1 准确率,我们看到了一些预期的结果:
- 低稀疏度(<=70%)对准确率没有明显影响
- 中等稀疏度(>=80% 到 <90%)对准确率有有限的回归
- 高稀疏度(>=90%)移除了太多权重,导致准确率受到显著影响
可以进行更多研究来提高更高稀疏度和更大块大小的准确性。我们希望 PyTorch 中对块稀疏性的支持以及本博客中展示的加速将鼓励研究人员探索更准确的稀疏化方法。

使用 SuperMask 方法在 ImageNet-blurred 上训练 ViT-b-16 的准确性。
下一步
我们已经展示了在 float32 精度下对 ViT 的 MLP 模块进行块稀疏化,取得了可喜的加速。要在 bfloat16 上观察到加速,还有更多工作要做,我们希望很快能在这方面取得进展。进一步优化 Vision Transformer 和通用 Transformer 的块稀疏性的可能下一步措施:
- 对注意力输入和输出投影执行块稀疏化。
- 在微调而不是从头训练期间执行块稀疏化。
- 对 ViT 线性算子特定形状的矩阵乘法内核进行进一步优化(特别是对于 80% 及更低的稀疏度)。
- 与其他优化结合,例如 int8 和 torch.compile()。
- 探索其他权重稀疏化算法,例如 Spartan,以提高准确性。
- 探索选择要稀疏化的权重(例如,特定的 Transformer 层)。
如果您有疑问或对贡献块稀疏化感兴趣,请联系 melhoushi@meta.com!
此外,如果您对稀疏性广泛感兴趣,请随时联系 @jcaip / jessecai@meta.com,并请查看 torchao,这是一个我们正在构建的用于量化和稀疏性等架构优化技术的社区。