在过去的一年里,我们已在 PyTorch 中添加了对半结构化 (2:4) 稀疏性的支持。只需几行代码,我们就能通过用稀疏矩阵乘法替换密集矩阵乘法,在 segment-anything 上实现 10% 的端到端推理加速。
然而,矩阵乘法并非仅限于神经网络推理,训练过程中也会发生。通过扩展我们之前用于加速推理的核心原语,我们还能加速模型训练。我们编写了一个替代的 nn.Linear 层 SemiSparseLinear
,它能够在 NVIDIA A100 上实现 ViT-L MLP 块中线性层的前向 + 后向传播 1.3 倍的 加速。
总体而言,我们在 DINOv2 ViT-L 训练中观察到实际运行时间减少了 6%,且开箱即用的精度几乎没有下降(ImageNet top-1 精度分别为 82.8 和 82.7)。
我们比较了在 4 个 NVIDIA A100 上训练 ViT 模型 125k 迭代的两种策略:全密集(蓝色)或 70% 的训练过程使用稀疏性,然后转为密集(橙色)。两者在基准测试中都取得了相似的结果,但稀疏变体训练速度快了 6%。对于这两项实验,我们都评估了中间检查点,无论是否应用稀疏性。
据我们所知,这是首个开源实现的加速稀疏训练,我们很高兴在 torchao 中提供用户 API。您只需几行代码即可尝试加速自己的训练。
# Requires torchao and pytorch nightlies and CUDA compute capability 8.0+
import torch
from torchao.sparsity.training import (
SemiSparseLinear,
swap_linear_with_semi_sparse_linear,
)
model = torch.nn.Sequential(torch.nn.Linear(1024, 4096)).cuda().half()
# Specify the fully-qualified-name of the nn.Linear modules you want to swap
sparse_config = {
"seq.0": SemiSparseLinear
}
# Swap nn.Linear with SemiSparseLinear, you can run your normal training loop after this step
swap_linear_with_semi_sparse_linear(model, sparse_config)
工作原理是什么?
稀疏性背后的核心思想很简单:跳过涉及零值张量元素的计算以加速矩阵乘法。然而,仅仅将权重设为零是不够的,因为密集张量仍然包含这些剪枝的元素,并且密集矩阵乘法核将继续处理它们,从而产生相同的延迟和内存开销。为了实现实际的性能提升,我们需要用智能地绕过涉及剪枝元素的计算的稀疏核来替换密集核。
这些核处理稀疏矩阵,它们移除剪枝的元素并将指定的元素存储在压缩格式中。有许多不同的稀疏格式,但我们特别感兴趣的是半结构化稀疏性,也称为 2:4 结构化稀疏性或细粒度结构化稀疏性,更一般的说法是 N:M 结构化稀疏性。
2:4 稀疏压缩表示。原始 来源
2:4 稀疏矩阵是一种矩阵,其中每 4 个元素中最多只有 2 个非零元素,如上图所示。半结构化稀疏性之所以具有吸引力,是因为它在性能和精度之间取得了恰到好处的平衡。
- 自 Ampere 以来,NVIDIA GPU 为此格式提供了硬件加速和库支持 (cuSPARSELt),矩阵乘法速度提升高达 1.6 倍。
- 将模型剪枝以适应这种稀疏模式不会像其他模式那样大幅降低精度。NVIDIA 的 白皮书 显示,先剪枝再重新训练能够恢复大多数视觉模型的精度。
NVIDIA GPU 上 2:4(稀疏)矩阵乘法的图示。原始 来源
使用半结构化稀疏性加速推理非常简单。由于我们的权重在推理期间是固定的,我们可以提前(离线)剪枝和压缩权重,并存储压缩的稀疏表示而不是密集张量。
然后,我们不再调度到密集矩阵乘法,而是调度到稀疏矩阵乘法,传入压缩的稀疏权重而不是正常的密集权重。有关使用 2:4 稀疏性加速模型推理的更多信息,请参阅我们的教程。
将稀疏推理加速扩展到训练
为了使用稀疏性减少模型训练时间,我们需要考虑何时计算掩码,因为一旦我们存储了压缩表示,掩码就固定了。
对现有训练好的密集模型应用固定掩码进行训练(也称为剪枝)不会降低精度,但这需要进行两次训练运行——一次用于获取密集模型,另一次用于使其稀疏,这不会带来加速。
相反,我们希望从零开始训练稀疏模型(动态稀疏训练),但从零开始使用固定掩码进行训练会导致评估结果显著下降,因为稀疏性掩码会在初始化时选择,而此时模型权重基本上是随机的。
为了在从零开始训练时保持模型的精度,我们在运行时剪枝和压缩权重,以便在训练过程的每一步都能计算出最优掩码。
从概念上讲,您可以将我们的方法视为一种近似矩阵乘法技术,其中我们执行 `prune_and_compress`
并调度到 `sparse_GEMM`
,这比调用 `dense_GEMM`
所需的时间要少。这很困难,因为原生的剪枝和压缩函数太慢,无法显示出加速。
考虑到我们的 ViT-L 训练矩阵乘法的形状 (13008x4096x1024),我们测得密集 GEMM 和稀疏 GEMM 的运行时分别为 538us 和 387us。换句话说,权值矩阵的剪枝和压缩步骤必须在少于 538-387=151us 的时间内运行才能获得任何效率提升。不幸的是,cuSPARSELt 中提供的压缩核已经需要 380us(甚至没有考虑剪枝步骤!)。
考虑到 NVIDIA A100 的最大内存 IO (2TB/s),并考虑到剪枝和压缩核将受内存限制,理论上我们可以在 4us 内剪枝和压缩我们的权重 (4096x1024x2 bytes = 8MB) (8MB / 2TB/s)!事实上,我们能够编写一个核,将矩阵剪枝并压缩成 2:4 稀疏格式,并在 36 us 内运行(比 cuSPARSELt 中的压缩核快 10 倍),从而使整个 GEMM(包括稀疏化)更快。我们的核已可用于 PyTorch。
我们的自定义稀疏化核(包括剪枝 + 压缩)在线性层前向+后向传播中提高了约 30% 的速度。基准测试运行在 NVIDIA A100-80GB GPU 上。
编写高性能的运行时稀疏化核
为了实现高性能的运行时稀疏化核,我们面临着多项挑战,下面将进行探讨。
1) 处理后向传播
对于后向传播,我们需要计算 dL/dX 和 dL/dW 以进行梯度更新和后续层计算,这意味着我们需要分别计算 xWT 和 xTW。
用于训练加速的运行时稀疏化概述(前向 + 后向传播)
然而,这有问题,因为压缩表示无法转置,因为无法保证张量在两个方向上都是 2:4 稀疏的。
两个矩阵都是有效的 2:4 矩阵。然而,右边的矩阵一旦转置就不再是有效的 2:4 矩阵,因为有一列包含超过 2 个元素
因此,我们剪枝一个 4x4 的瓦片,而不是 1x4 的条带。我们贪婪地保留最大值,确保每行/每列最多取 2 个值。虽然这种方法不能保证最优,因为有时我们只保留 7 个值而不是 8 个,但它能有效地计算出一个在行和列方向上都为 2:4 稀疏的张量。
然后我们压缩打包张量和打包转置张量,并存储转置张量用于后向传播。通过同时计算打包张量和打包转置张量,我们避免了后向传播中的二次核调用。
我们的核在寄存器中剪枝权重矩阵,并将压缩值写入全局内存。它同时剪枝 W.t,这是后向传播所需的,从而最大程度地减少了内存 IO
处理后向传播还需要一些额外的转置技巧——底层硬件只支持第一个矩阵是稀疏的运算。对于推理期间的权重稀疏化,当我们需要计算 xWT 时,我们依赖转置属性来交换操作数的顺序。
在推理期间,我们使用 torch.compile
将外部转置融合到后续的逐点操作中,以避免性能损失。
然而,在训练的后向传播情况下,我们没有后续的逐点操作可以融合。相反,我们利用 cuSPARSELt 指定结果矩阵的行/列布局的能力,将转置融合到我们的矩阵乘法中。
2) 用于高效内存 IO 的核瓦片化
为了使我们的核尽可能高效,我们希望合并读/写操作,因为我们发现内存 IO 是主要的瓶颈。这意味着在一个 CUDA 线程中,我们希望每次读/写 128 字节的块,以便多个并行读/写可以由 GPU 内存控制器合并成一个请求。
因此,线程不再处理一个仅有 4x4x2 = 32 字节的 4x4 瓦片,我们决定每个线程处理 4 个 4x4 瓦片(即一个 8x8 瓦片),这使我们能够操作 8x8x2 = 128 字节的块。
3) 在 4x4 瓦片中对元素进行排序而不发生 Warp 分歧
对于线程中的每个独立的 4x4 瓦片,我们计算一个位掩码,指定哪些元素剪枝,哪些元素保留。为此,我们对所有 16 个元素进行排序,并贪婪地保留元素,只要它们不违反我们的 2:4 行/列约束。这只保留具有最大值的权重。
关键的是我们观察到我们只对固定数量的元素进行排序,因此通过使用无分支排序网络,我们可以避免 Warp 分歧。
为了清晰起见,省略了转置的打包张量和元数据。排序网络图来自维基百科。
Warp 分歧发生在线程块内存在条件执行时。在 CUDA 中,同一工作组(线程块)中的工作项在硬件层面以批次(warps)调度。如果我们有条件执行,导致同一批次中的某些工作项运行不同的指令,那么它们在 Warp 调度时会被屏蔽,或顺序调度。
例如,如果我们有一些代码,如 if (condition) do(A) else do(B)
,其中条件被所有奇数工作项满足,则此条件语句的总运行时间是 do(A) + do(B)
,因为我们会为所有奇数工作项调度 do(A)
,屏蔽偶数工作项,并为所有偶数工作项调度 do(B)
,屏蔽奇数工作项。这篇回答提供了更多关于 Warp 分歧的信息。
4) 写入压缩矩阵和元数据
计算出位掩码后,必须以压缩格式将权重数据写回全局内存。这并非易事,因为数据需要保留在寄存器中,而且无法对寄存器进行索引(例如 C[i++] = a
会阻止我们将 C
存储在寄存器中)。此外,我们发现 nvcc
使用的寄存器比我们预期的多得多,这导致了寄存器溢出并影响了整体性能。我们将此压缩矩阵以列主序格式写入全局内存,以提高写入效率。
我们还需要写入 cuSPARSELt 元数据。此元数据布局与开源 CUTLASS 库中的布局非常相似,并针对通过 GEMM 核中的共享内存使用 PTX ldmatrix
指令高效加载进行了优化。
然而,这种布局并未优化以实现高效写入:元数据张量的前 128 位包含关于行 0, 8, 16 和 24 的前 32 列的元数据。回想一下,每个线程处理一个 8x8 瓦片,这意味着这些信息分散在 16 个线程中。
我们分别依赖一系列 Warp Shuffle 操作,一次用于原始表示,一次用于转置表示,来写入元数据。幸运的是,这部分数据占总 IO 的比例不到 10%,所以我们可以承受不必完全合并写入。
DINOv2 稀疏训练:实验设置和结果
对于我们的实验,ViT-L 模型使用 DINOv2 方法在 ImageNet 上训练 125k 步。所有实验均在 4 个 AMD EPYC 7742 64 核 CPU 和 4 个 NVIDIA A100-80GB GPU 上运行。在稀疏训练期间,模型在训练的前半部分启用了 2:4 稀疏性,其中只有一半的权重被启用。由于权重在优化过程中不断更新,权重的稀疏性掩码在每一步动态重新计算。在剩余的步骤中,模型以密集方式训练,生成最终模型时不带 2:4 稀疏性(100% 稀疏训练设置除外),然后进行评估。
训练设置 | ImageNet 1k 对数回归 |
0% 稀疏(125k 密集步,基线) | 82.8 |
40% 稀疏(50k 稀疏 -> 75k 密集步) | 82.9 |
60% 稀疏(75k 稀疏 -> 50k 密集步) | 82.8 |
70% 稀疏(87.5k 稀疏 -> 37.5k 密集步) | 82.7 |
80% 稀疏(100k 稀疏 -> 25k 密集步) | 82.7 |
90% 稀疏(112.5k 稀疏 -> 12.5k 密集步) | 82.0 |
100% 稀疏(125k 稀疏步) | 82.3 (2:4 稀疏模型) |
在稀疏训练步骤中,我们在后向传播中获得稀疏权重的密集梯度。为了使梯度下降有效,我们应该在使用此梯度在优化器中更新权重之前也对其进行稀疏化。但我们没有这样做,而是使用完整的密集梯度来更新权重——我们发现这在实践中效果更好:这是 STE(直通估计器)策略。换句话说,我们在每一步都更新所有参数,即使是我们不使用的参数。
结论和未来工作
在这篇博文中,我们展示了如何使用半结构化稀疏性加速神经网络训练,并解释了我们面临的一些挑战。我们在 DINOv2 训练中实现了 6% 的端到端加速,精度略微下降 0.1 pp。
这项工作有几个扩展方向:
- 扩展到新的稀疏模式:研究人员创建了新的稀疏模式,如 V:N:M 稀疏性,它们利用底层的半结构化稀疏核以提供更大的灵活性。这对于将稀疏性应用于 LLM 尤其有趣,因为 2:4 稀疏性会导致精度下降过多,但我们看到了更通用的 N:M 模式取得了一些积极结果。
- 稀疏微调的性能优化:本文涵盖了从零开始的稀疏训练,但通常我们希望微调基础模型。在这种情况下,静态掩码可能足以保持精度,这将使我们能够进行额外的性能优化。
- 更多关于剪枝策略的实验:我们在网络的每一步计算掩码,但每 n 步计算一次掩码可能会获得更好的训练精度。总的来说,弄清在训练期间使用半结构化稀疏性的最佳策略是一个开放的研究领域。
- 与 fp8 的兼容性:硬件也支持 fp8 半结构化稀疏性,原则上这种方法应该与 fp8 类似地工作。实践中,我们需要编写类似的稀疏化核,并可能将它们与张量的缩放融合。
- 激活稀疏性:高效的稀疏化核也使得在训练期间稀疏化激活成为可能。由于稀疏化开销与稀疏化矩阵大小呈线性增长,与权重张量相比激活张量较大的设置可能从激活稀疏性中获益更多于权重稀疏性。此外,由于使用了 ReLU 或 GELU 激活函数,激活自然是稀疏的,从而减少了精度下降。
如果您对这些问题感兴趣,请随时在 torchao 中提交 issue / PR,这是我们正在为量化和稀疏性等架构优化技术建立的社区。此外,如果您对稀疏性有普遍兴趣,请在 CUDA-MODE (#sparsity) 中联系。