过去一年中,我们为 PyTorch 增加了对半结构化 (2:4) 稀疏性的支持。通过用稀疏矩阵乘法替代稠密矩阵乘法,仅需几行代码,我们就让 segment-anything 的端到端推理速度提升了 10%。
然而,矩阵乘法并非神经网络推理所独有,它们在训练过程中同样频繁发生。通过扩展我们之前用于加速推理的核心原语,我们也能够加速模型训练。我们编写了一个替代 nn.Linear 的层——SemiSparseLinear,它在 NVIDIA A100 上运行 ViT-L 的 MLP 块的线性层的前向和反向传播时,能够实现 1.3 倍的加速。
在端到端层面,DINOv2 ViT-L 的训练耗时减少了 6%,且几乎没有精度损失(ImageNet top-1 准确率从 82.8% 变为 82.7%)。

我们对比了在 4x NVIDIA A100 上将 ViT 模型训练 12.5 万次迭代的两种策略:完全稠密(蓝色),或 70% 的训练时间使用稀疏训练,随后转为稠密(橙色)。两种策略在基准测试中均取得了相似的结果,但稀疏变体训练速度快了 6%。对于这两项实验,我们均评估了使用和不使用稀疏性的中间检查点。
据我们所知,这是首个开源的加速稀疏训练实现,我们非常高兴能在 torchao 中提供用户 API。你只需几行代码即可尝试加速你自己的训练任务。
# Requires torchao and pytorch nightlies and CUDA compute capability 8.0+
import torch
from torchao.sparsity.training import (
SemiSparseLinear,
swap_linear_with_semi_sparse_linear,
)
model = torch.nn.Sequential(torch.nn.Linear(1024, 4096)).cuda().half()
# Specify the fully-qualified-name of the nn.Linear modules you want to swap
sparse_config = {
"seq.0": SemiSparseLinear
}
# Swap nn.Linear with SemiSparseLinear, you can run your normal training loop after this step
swap_linear_with_semi_sparse_linear(model, sparse_config)
这是如何工作的?
稀疏性的基本理念很简单:跳过涉及零值张量元素的计算,以加速矩阵乘法。然而,仅仅将权重置为零是不够的,因为稠密张量仍然包含这些被剪枝的元素,稠密矩阵乘法内核仍会处理它们,从而产生相同的延迟和内存开销。为了实现实际的性能提升,我们需要用稀疏内核替换稠密内核,这些内核能够智能地跳过涉及被剪枝元素的计算。
这些内核作用于稀疏矩阵,通过移除被剪枝的元素并以压缩格式存储指定元素来工作。稀疏格式有很多种,但我们特别关注半结构化稀疏性,也称为 2:4 结构化稀疏性、细粒度结构化稀疏性,或更通用的 N:M 结构化稀疏性。

2:4 稀疏压缩表示。原始来源
2:4 稀疏矩阵是指每 4 个元素中最多有 2 个非零元素的矩阵,如上图所示。半结构化稀疏性之所以具有吸引力,是因为它在性能和准确性之间找到了一个“金发姑娘点”(即最佳平衡点)。
- 自 Ampere 架构以来的 NVIDIA GPU 为这种格式提供了硬件加速和库支持 (cuSPARSELt),矩阵乘法速度提升最高可达 1.6 倍。
- 将模型剪枝以符合这种稀疏模式,其准确率下降程度远小于其他模式。NVIDIA 的白皮书显示,剪枝后重新训练可以恢复大多数视觉模型的准确率。

NVIDIA GPU 上 2:4(稀疏)矩阵乘法演示。原始来源
利用半结构化稀疏性加速推理非常直接。由于权重在推理过程中是固定的,我们可以预先(离线)剪枝并压缩权重,然后存储压缩后的稀疏表示,而不是稠密张量。

然后,我们不再分发给稠密矩阵乘法,而是分发给稀疏矩阵乘法,传入压缩后的稀疏权重而非普通的稠密权重。关于使用 2:4 稀疏性加速模型推理的更多信息,请参考我们的教程。
将稀疏推理加速扩展到训练任务
为了利用稀疏性减少模型训练时间,我们需要考虑掩码(mask)的计算时机,因为一旦存储了压缩表示,掩码就固定了。
将固定掩码应用于现有的已训练稠密模型(也称为剪枝)不会降低精度,但这需要进行两次训练——一次获得稠密模型,另一次使其稀疏,这无法提供任何加速。
相反,我们希望从零开始训练一个稀疏模型(动态稀疏训练),但如果使用固定掩码从头训练,会导致评估效果显著下降,因为稀疏掩码是在初始化时选定的,而此时模型权重基本上是随机的。
为了在从头开始训练时保持模型准确率,我们在运行时剪枝并压缩权重,这样我们就可以在训练过程的每一步计算出最优掩码。
从概念上讲,你可以将我们的方法视为一种近似矩阵乘法技术,即我们在比 dense_GEMM 调用更短的时间内完成 prune_and_compress(剪枝和压缩)并分发给 sparse_GEMM。这很难实现,因为原生的剪枝和压缩函数太慢,无法体现加速效果。
鉴于我们 ViT-L 训练中矩阵乘法的维度(13008x4096x1024),我们测得稠密和稀疏 GEMM 的运行时间分别为 538us 和 387us。换句话说,权重矩阵的剪枝和压缩步骤必须在少于 538-387=151us 的时间内完成才能获得效率提升。不幸的是,cuSPARSELt 提供的压缩内核本身就需要 380us(这还没算上剪枝步骤!)。
考虑到 NVIDIA A100 的最大内存 IO (2TB/s),且剪枝和压缩内核是内存受限的,理论上我们可以在 4us 内完成权重的剪枝和压缩(4096x1024x2 字节=8MB,8MB / 2TB/s = 4us)!事实上,我们编写了一个内核,能以 36us 的速度将矩阵修剪并压缩为 2:4 稀疏格式(比 cuSPARSELt 中的压缩内核快 10 倍),从而使整个 GEMM(包含稀疏化过程)速度更快。我们的内核现已在 PyTorch 中可用。

我们的自定义稀疏化内核(包含剪枝 + 压缩)在线性层前向+反向传播中快了约 30%。基准测试在 NVIDIA A100-80GB GPU 上运行。
编写高性能的运行时稀疏化内核
在实现高性能运行时稀疏化内核的过程中,我们面临多个挑战,下面将逐一探讨。
1) 处理反向传播
在反向传播中,我们需要为梯度更新和后续层计算 dL/dX 和 dL/dW,这意味着我们需要分别计算 xWT 和 xTW。

训练加速的运行时稀疏化概述 (前向 + 反向传播)
然而,这存在问题,因为压缩表示无法转置,因为无法保证张量在两个方向上都保持 2:4 稀疏。

两个矩阵都是有效的 2:4 矩阵。然而,右侧的矩阵在转置后不再是有效的 2:4 矩阵,因为其中一列包含超过 2 个元素。
因此,我们不再剪枝 1×4 的条带,而是剪枝 4×4 的块。我们贪心地保留最大的值,确保每行/每列最多取 2 个值。虽然这种方法不能保证最优(有时我们只保留 7 个值而不是 8 个),但它能高效地计算出一个在行向和列向均为 2:4 稀疏的张量。
然后,我们同时压缩打包后的张量和打包后的转置张量,并存储转置张量供反向传播使用。通过同时计算打包张量和打包后的转置张量,我们避免了反向传播中的二次内核调用。

我们的内核在寄存器中剪枝权重矩阵,并将压缩后的值写入全局内存。它同时还剪枝了反向传播所需的 W.t,最大限度地减少了内存 IO。
处理反向传播还需要一些额外的转置技巧——底层硬件仅支持第一个矩阵为稀疏的操作。对于推理期间的权重稀疏化,当我们计算 xWT 时,我们利用转置属性来交换操作数的顺序。

在推理期间,我们使用 torch.compile 将外部转置与后续的点运算融合,以避免产生性能损失。
然而,在训练的反向传播中,没有后续的点运算可供融合。相反,我们利用 cuSPARSELt 指定结果矩阵行/列布局的能力,将转置过程融合到矩阵乘法中。
2) 用于高效内存 IO 的内核平铺 (Kernel Tiling)
为了使我们的内核尽可能高效,我们需要合并读/写操作,因为我们发现内存 IO 是主要的瓶颈。这意味着在一个 CUDA 线程中,我们希望一次读取/写入 128 字节的数据块,这样多个并行读/写操作就可以由 GPU 内存控制器合并为一个请求。
因此,我们决定每个线程处理 4 个 4×4 块(即 8×8 块),而不是让线程处理单个 4×4 块(仅 4x4x2 = 32 字节),这允许我们操作 8x8x2 = 128 字节的数据块。

3) 在 4×4 块内进行排序且不产生 warp 散度
对于线程内每个单独的 4×4 块,我们计算一个位掩码,指定哪些元素保留,哪些剪枝。为了做到这一点,我们对所有 16 个元素进行排序,并贪心地保留元素,只要它们不破坏 2:4 行/列约束。这样只保留了数值最大的权重。
关键在于我们观察到排序的元素数量是固定的,因此通过使用无分支的排序网络 (sorting network),我们可以避免 warp 散度。

为清晰起见,省略了转置后的打包张量和元数据。排序网络图摘自 Wikipedia。
Warp 散度发生在线程块内部出现条件执行时。在 CUDA 中,同一工作组(线程块)中的任务是在硬件级别以批次(warps)形式分发的。如果存在条件执行,使得同一批次中的部分任务运行不同的指令,则这些任务在分发时会被掩盖或顺序执行。
例如,如果代码类似于 if (condition) do(A) else do(B),并且所有奇数编号的任务都满足条件,则该条件语句的总运行时间为 do(A) + do(B),因为我们需要为所有奇数任务分发 do(A)(掩盖偶数任务),并为所有偶数任务分发 do(B)(掩盖奇数任务)。此解答提供了关于 warp 散度的更多信息。
4) 写入压缩后的矩阵和元数据
位掩码计算完成后,权重数据必须以压缩格式写回全局内存。这并不简单,因为数据需要保留在寄存器中,且寄存器无法索引(例如 C[i++] = a 会阻止我们存储寄存器中的 C)。此外,我们发现 nvcc 使用了比预期更多的寄存器,导致寄存器溢出并影响了全局性能。我们将压缩后的矩阵以列优先 (Column-Major) 格式写入全局内存,以使写入更高效。

我们还需要写入 cuSPARSELt 元数据。该元数据布局与开源 CUTLASS 库非常相似,并针对通过 GEMM 内核中的共享内存使用 PTX ldmatrix 指令进行高效加载进行了优化。
然而,这种布局并未针对高效写入进行优化:元数据张量的前 128 位包含关于第 0、8、16 和 24 行的前 32 列的元数据。回想一下,每个线程处理一个 8×8 块,这意味着这些信息分散在 16 个线程中。
我们依赖一系列 warp-shuffle 操作(分别针对原始和转置表示)来写入元数据。幸运的是,这部分数据占总 IO 的不到 10%,因此我们不需要完全合并这些写入操作。
DINOv2 稀疏训练:实验设置和结果
在我们的实验中,ViT-L 模型使用 DINOv2 方法在 ImageNet 上训练了 12.5 万步。所有实验均在 4x AMD EPYC 7742 64核 CPU 和 4x NVIDIA A100-80GB GPU 上运行。在稀疏训练期间,模型在训练初期启用 2:4 稀疏性,此时仅激活一半权重。该权重稀疏掩码在每一步都会动态重新计算,因为权重在优化过程中不断更新。在剩余步骤中,模型以稠密方式训练,生成最终无 2:4 稀疏性的模型(100% 稀疏训练设置除外),然后进行评估。
| 训练设置 | ImageNet 1k 逻辑回归 |
| 0% 稀疏(12.5万步稠密训练,基准) | 82.8 |
| 40% 稀疏(5万步稀疏 -> 7.5万步稠密) | 82.9 |
| 60% 稀疏(7.5万步稀疏 -> 5万步稠密) | 82.8 |
| 70% 稀疏(8.75万步稀疏 -> 3.75万步稠密) | 82.7 |
| 80% 稀疏(10万步稀疏 -> 2.5万步稠密) | 82.7 |
| 90% 稀疏(11.25万步稀疏 -> 1.25万步稠密) | 82.0 |
| 100% 稀疏(12.5万步稀疏训练) | 82.3 (2:4-稀疏模型) |

在稀疏训练步骤中,我们在反向传播中获得稀疏权重的稠密梯度。为了使梯度下降合理,我们也应该在优化器更新权重之前对梯度进行稀疏化。但我们没有这样做,而是使用完整的稠密梯度来更新权重——我们发现这种方法在实践中效果更好:即 STE (直通估计器) 策略。换句话说,我们在每一步都更新所有参数,即使是那些我们不使用的参数。
结论与未来工作
在本博文中,我们展示了如何利用半结构化稀疏性加速神经网络训练,并解释了我们面临的一些挑战。我们在 DINOv2 训练上实现了 6% 的端到端加速,且仅有 0.1 个百分点的精度下降。
这项工作还有几个扩展方向:
- 扩展到新的稀疏模式: 研究人员创造了如 V:N:M 稀疏等新模式,它们利用底层的半结构化稀疏内核以获得更大的灵活性。这对将稀疏性应用于 LLM 尤其有意义,因为 2:4 稀疏性会导致过多的精度损失,但我们已经看到针对更通用的 N:M 模式的一些积极结果。
- 针对稀疏微调的性能优化: 本文涵盖了从零开始的稀疏训练,但很多时候我们想要微调基础模型。在这种情况下,静态掩码可能足以保持准确性,这将使我们能够进行额外的性能优化。
- 关于剪枝策略的更多实验: 我们在网络的每一步计算掩码,但每 n 步计算一次掩码可能会产生更好的训练准确性。总的来说,寻找训练期间使用半结构化稀疏性的最佳策略是一个开放的研究领域。
- 与 fp8 的兼容性: 硬件也支持 fp8 半结构化稀疏性,这种方法原则上同样适用于 fp8。在实践中,我们需要编写类似的稀疏化内核,并可能将它们与张量缩放融合。
- 激活稀疏性: 高效的稀疏化内核也能够对训练期间的激活值进行稀疏化。由于稀疏化开销随稀疏矩阵大小线性增长,相比权重张量,具有较大激活张量的设置可能从激活稀疏性中受益更多。此外,由于 ReLU 或 GELU 激活函数的使用,激活值天生是稀疏的,从而降低了准确性损失。
如果你对这些问题感兴趣,欢迎在 torchao 中提出 issue 或 PR,这是一个我们正在为量化和稀疏等架构优化技术构建的社区。此外,如果你对稀疏性有广泛兴趣,请加入 CUDA-MODE (#sparsity) 进行交流。