在过去的一年中,我们已将半结构化 (2:4) 稀疏性支持添加到 PyTorch 中。只需几行代码,我们就能通过用稀疏矩阵乘法替换密集矩阵乘法,在 segment-anything 上实现 10% 的端到端推理加速。
然而,矩阵乘法并非神经网络推理所独有——训练期间也会发生。通过扩展我们之前用于加速推理的核心原语,我们也能够加速模型训练。我们编写了一个替代的 nn.Linear 层,即 SemiSparseLinear
,它能够在 NVIDIA A100 上,对 ViT-L 的 MLP 块中线性层的前向+反向传播实现 1.3 倍的 加速。
端到端来看,对于 DINOv2 ViT-L 训练,我们观察到墙钟时间减少了 6%,并且几乎没有开箱即用的精度下降(ImageNet top-1 精度为 82.8 对比 82.7)。

我们比较了在 4 块 NVIDIA A100 上训练 ViT 模型 12.5 万次迭代的 2 种策略:要么完全密集(蓝色),要么稀疏训练 70% 的时间,然后密集(橙色)。两者在基准测试中都取得了相似的结果,但稀疏变体训练速度快了 6%。对于这两个实验,我们都评估了有稀疏性和没有稀疏性的中间检查点。
据我们所知,这是加速稀疏训练的首次开源实现,我们很高兴能在 torchao 中提供用户 API。您只需几行代码即可尝试加速您自己的训练运行。
# Requires torchao and pytorch nightlies and CUDA compute capability 8.0+
import torch
from torchao.sparsity.training import (
SemiSparseLinear,
swap_linear_with_semi_sparse_linear,
)
model = torch.nn.Sequential(torch.nn.Linear(1024, 4096)).cuda().half()
# Specify the fully-qualified-name of the nn.Linear modules you want to swap
sparse_config = {
"seq.0": SemiSparseLinear
}
# Swap nn.Linear with SemiSparseLinear, you can run your normal training loop after this step
swap_linear_with_semi_sparse_linear(model, sparse_config)
这是如何实现的?
稀疏性背后的核心思想很简单:跳过涉及零值张量元素的计算,以加速矩阵乘法。然而,简单地将权重设置为零是不够的,因为密集张量仍然包含这些修剪过的元素,并且密集矩阵乘法核将继续处理它们,导致相同的延迟和内存开销。为了实现实际的性能提升,我们需要用智能地绕过涉及修剪过的元素计算的稀疏核来替换密集核。
这些核作用于稀疏矩阵,它们移除被剪枝的元素,并以压缩格式存储指定的元素。存在许多不同的稀疏格式,但我们特别关注半结构化稀疏性,也称为2:4 结构化稀疏性或细粒度结构化稀疏性,更一般地称为N:M 结构化稀疏性。

2:4 稀疏压缩表示。原始 来源
2:4 稀疏矩阵是一种矩阵,其中每 4 个元素中最多有 2 个非零元素,如上图所示。半结构化稀疏性之所以吸引人,是因为它在性能和精度之间找到了一个“金发姑娘”点。
- NVIDIA GPU 自 Ampere 架构以来就为此格式提供了硬件加速和库支持(cuSPARSELt),矩阵乘法速度最高可提升 1.6 倍。
- 修剪模型以适应这种稀疏模式,其精度下降不如其他模式严重。NVIDIA 的 白皮书 显示,先修剪后重新训练能够恢复大多数视觉模型的准确性。

NVIDIA GPU 上 2:4 (稀疏) 矩阵乘法的图示。原始 来源
使用半结构化稀疏性加速推理是直接的。由于我们的权重在推理期间是固定的,我们可以提前(离线)修剪和压缩权重,并存储压缩后的稀疏表示而不是我们的密集张量。

然后,我们不再调度到密集矩阵乘法,而是调度到稀疏矩阵乘法,传入压缩的稀疏权重而不是正常的密集权重。有关使用 2:4 稀疏性加速模型推理的更多信息,请参阅我们的教程。
将稀疏推理加速扩展到训练
为了利用稀疏性来减少模型训练时间,我们需要考虑何时计算掩码,因为一旦我们存储了压缩表示,掩码就是固定的。
使用固定掩码应用于现有训练好的密集模型进行训练(也称为剪枝)不会降低精度,但这需要两次训练运行——一次是为了获得密集模型,另一次是为了使其稀疏,这并不能提供加速。
相反,我们希望从头开始训练一个稀疏模型(动态稀疏训练),但从头开始使用固定掩码进行训练将导致评估结果显著下降,因为稀疏性掩码将在初始化时选择,而此时模型权重基本上是随机的。
为了在从头训练时保持模型的准确性,我们在运行时修剪和压缩权重,以便在训练过程的每一步计算最佳掩码。
从概念上讲,您可以将我们的方法视为一种近似矩阵乘法技术,我们 `prune_and_compress`
并分发到 `sparse_GEMM`
,其所需时间少于 `dense_GEMM`
调用所需的时间。这很困难,因为原生剪枝和压缩函数太慢,无法显示加速。
根据我们的 ViT-L 训练矩阵乘法形状(13008x4096x1024),我们测量了密集 GEMM 和稀疏 GEMM 的运行时间分别为 538 微秒和 387 微秒。换句话说,权重矩阵的剪枝和压缩步骤必须在小于 538-387=151 微秒内运行才能获得任何效率增益。不幸的是,cuSPARSELt 中提供的压缩内核已经需要 380 微秒(甚至不考虑剪枝步骤!)。
考虑到 NVIDIA A100 的最大内存 IO (2TB/s),并考虑到剪枝和压缩核将受内存限制,我们理论上可以在 4 微秒内(8MB / 2TB/s!)剪枝和压缩我们的权重 (4096x1024x2 字节=8MB)!事实上,我们能够编写一个核,将矩阵剪枝并压缩成 2:4 稀疏格式,并且在 36 微秒内运行(比 cuSPARSELt 中的压缩核快 10 倍),使得整个 GEMM(包括稀疏化)更快。我们的核可在 PyTorch 中使用。

我们的定制稀疏化内核,包括剪枝 + 压缩,在线性层前向+后向传播中快约 30%。基准测试在 NVIDIA A100-80GB GPU 上运行。
编写高性能运行时稀疏化内核
为了实现高性能运行时稀疏化内核,我们面临了多项挑战,下面将对此进行探讨。
1) 处理反向传播
对于反向传播,我们需要计算 dL/dX 和 dL/dW 以进行梯度更新和后续层,这意味着我们需要分别计算 xWT 和 xTW。

训练加速运行时稀疏化概述 (前向 + 后向传播)
然而,这有问题,因为压缩表示不能转置,因为无法保证张量在两个方向上都是 2:4 稀疏的。

两个矩阵都是有效的 2:4 矩阵。然而,右边的矩阵一旦转置就不再是有效的 2:4 矩阵,因为其中一列包含超过 2 个元素。
因此,我们剪枝一个 4×4 的瓦片,而不是一个 1×4 的条带。我们贪婪地保留最大值,确保每行/每列最多取 2 个值。虽然这种方法不能保证是最优的,因为我们有时只保留 7 个值而不是 8 个,但它有效地计算出一个在行方向和列方向都是 2:4 稀疏的张量。
然后我们压缩打包张量和打包转置张量,存储转置张量用于反向传播。通过同时计算打包张量和打包转置张量,我们避免了反向传播中的二次核调用。

我们的内核在寄存器中剪枝权重矩阵,并将压缩值写入全局内存。它还同时剪枝 W.t,这是反向传播所需的,从而最大限度地减少了内存 IO。
为了处理反向传播,还需要一些额外的转置技巧——底层硬件只支持第一个矩阵是稀疏的操作。对于推理期间的权重稀疏化,当我们需要计算 xWT 时,我们依赖转置属性来交换操作数的顺序。

在推理期间,我们使用 torch.compile
将外部转置融合到后续的逐点操作中,以避免性能损失。
然而在训练的反向传播情况下,我们没有后续的逐点操作可以融合。相反,我们通过利用 cuSPARSELt 指定结果矩阵的行/列布局的能力,将转置融合到我们的矩阵乘法中。
2) 用于高效内存 I/O 的内核分块
为了使我们的内核尽可能高效,我们希望合并读/写操作,因为我们发现内存 IO 是主要的瓶颈。这意味着在 CUDA 线程中,我们希望一次读取/写入 128 字节的块,这样多个并行读/写可以被 GPU 内存控制器合并为单个请求。
因此,我们不再让一个线程处理一个只有 4x4x2 = 32 字节的 4x4 瓦片,而是决定每个线程处理 4 个 4x4 瓦片(即一个 8x8 瓦片),这使我们能够操作 8x8x2 = 128 字节的块。

3) 在 4x4 瓦片中无分支发散排序元素
对于我们线程中的每个独立的 4x4 瓦片,我们计算一个位掩码,指定哪些元素要剪枝,哪些元素要保留。为此,我们对所有 16 个元素进行排序并贪婪地保留元素,只要它们不违反我们的 2:4 行/列约束。这只保留值最大的权重。
至关重要的是,我们观察到我们始终只对固定数量的元素进行排序,因此通过使用无分支的排序网络,我们可以避免分支发散。

为清晰起见,转置打包张量和元数据已省略。排序网络图取自维基百科。
当我们在线程块中进行条件执行时,就会发生线程束发散。在 CUDA 中,同一工作组(线程块)中的工作项在硬件层面以批次(线程束)的形式调度。如果存在条件执行,使得同一批次中的某些工作项运行不同的指令,那么当调度线程束时,它们会被掩蔽,或者按顺序调度。
例如,如果我们有一些代码,如 if (condition) do(A) else do(B)
,其中条件被所有奇数编号的工作项满足,则此条件语句的总运行时间为 do(A) + do(B)
,因为我们将为所有奇数编号的工作项调度 do(A)
,屏蔽偶数编号的工作项,并为所有偶数编号的工作项调度 do(B)
,屏蔽奇数编号的工作项。这个 答案 提供了有关线程束发散的更多信息。
4) 写入压缩矩阵和元数据
一旦计算出位掩码,权重数据必须以压缩格式写入全局内存。这并非易事,因为数据需要保留在寄存器中,并且不可能对寄存器进行索引(例如 C[i++] = a
阻止我们将 C
存储在寄存器中)。此外,我们发现 nvcc
使用的寄存器数量比我们预期的要多得多,这导致寄存器溢出并影响了整体性能。我们将此压缩矩阵以列主格式写入全局内存,以使写入更高效。

我们还需要写入 cuSPARSELt 元数据。此元数据布局与开源 CUTLASS 库的布局非常相似,并经过优化,可通过 GEMM 内核中的共享内存使用 PTX ldmatrix
指令高效加载。
然而,这种布局并未针对高效写入进行优化:元数据张量的前 128 位包含行 0、8、16 和 24 的前 32 列的元数据。请记住,每个线程处理一个 8×8 的瓦片,这意味着这些信息分散在 16 个线程中。
我们依靠一系列线程束混洗操作,分别用于原始和转置表示,以写入元数据。幸运的是,这些数据占总 I/O 的比例不到 10%,因此我们不必完全合并写入。
DINOv2 稀疏训练:实验设置和结果
在我们的实验中,ViT-L 模型使用 DINOv2 方法在 ImageNet 上训练了 125k 步。所有实验均在 4 个 AMD EPYC 7742 64 核 CPU 和 4 个 NVIDIA A100-80GB GPU 上运行。在稀疏训练期间,模型在训练的第一部分启用了 2:4 稀疏性,其中只有一半的权重被启用。权重的稀疏性掩码在每一步都会动态重新计算,因为权重在优化过程中会不断更新。在剩余的步骤中,模型进行密集训练,生成一个最终没有 2:4 稀疏性的模型(除了 100% 稀疏训练设置),然后对其进行评估。
训练设置 | ImageNet 1k 对数回归 |
0% 稀疏(12.5万密集步,基线) | 82.8 |
40% 稀疏(5万稀疏 -> 7.5万密集步) | 82.9 |
60% 稀疏(7.5万稀疏 -> 5万密集步) | 82.8 |
70% 稀疏(8.75万稀疏 -> 3.75万密集步) | 82.7 |
80% 稀疏(10万稀疏 -> 2.5万密集步) | 82.7 |
90% 稀疏(11.25万稀疏 -> 1.25万密集步) | 82.0 |
100% 稀疏(12.5万稀疏步) | 82.3 (2:4 稀疏模型) |

在稀疏训练步骤中,在反向传播过程中,我们获得了稀疏权重的密集梯度。为了使梯度下降合理,我们也应该在将其用于优化器更新权重之前将此梯度稀疏化。我们没有这样做,而是使用完整的密集梯度来更新权重——我们发现这在实践中效果更好:这就是 STE(Straight Through Estimator)策略。换句话说,我们每一步都更新所有参数,即使是那些我们不使用的参数。
结论与未来工作
在这篇博客文章中,我们展示了如何利用半结构化稀疏性加速神经网络训练,并解释了我们面临的一些挑战。我们成功地在 DINOv2 训练中实现了 6% 的端到端加速,同时精度仅下降了 0.1 个百分点。
这项工作有几个扩展领域:
- 扩展到新的稀疏模式:研究人员已经创建了新的稀疏模式,如 V:N:M 稀疏性,它利用底层的半结构化稀疏核以提供更大的灵活性。这对于将稀疏性应用于 LLM 尤其有趣,因为 2:4 稀疏性会过度降低准确性,但我们已经看到了更通用的 N:M 模式的一些积极 结果。
- 稀疏微调的性能优化:本文涵盖了从头开始的稀疏训练,但通常我们希望微调一个基础模型。在这种情况下,静态掩码可能足以保持准确性,这将使我们能够进行额外的性能优化。
- 更多关于剪枝策略的实验:我们会在网络的每一步计算掩码,但每 n 步计算一次掩码可能会产生更好的训练精度。总的来说,找出在使用半结构化稀疏性进行训练时的最佳策略是一个开放的研究领域。
- 与 fp8 的兼容性:硬件也支持 fp8 半结构化稀疏性,原则上这种方法应该与 fp8 类似地工作。实际上,我们需要编写类似的稀疏化内核,并且可能将其与张量的缩放融合。
- 激活稀疏性:高效的稀疏化内核还可以实现训练期间的激活稀疏化。由于稀疏化开销随着稀疏矩阵的大小线性增长,与权重张量相比,具有大激活张量的设置可以从激活稀疏性中获得比权重稀疏性更大的好处。此外,由于 ReLU 或 GELU 激活函数的使用,激活自然是稀疏的,从而降低了精度下降。
如果您对此类问题感兴趣,请随时在 torchao 中提出 issue/PR,这是一个我们正在为量化和稀疏性等架构优化技术构建的社区。此外,如果您对稀疏性有普遍兴趣,请在 CUDA-MODE(#sparsity)中联系我们。