博客

使用 GemLite、TorchAO 和 SGLang 加速 LLM 推理

作者: 2025 年 1 月 21 日2025 年 5 月 5 日无评论

大型语言模型 (LLM) 通常需要大量资源,需要大量的内存、计算能力和电力才能有效运行。量化通过将权重和激活从 16 位浮点数降低到较低的比特率(例如 8 位、4 位、2 位)来提供解决方案,从而实现显著的速度提升和内存节省,并且还支持更大的批次大小。

现有的低精度推理解决方案对于小批量大小效果很好,但存在以下问题:

  • 增加批次大小时性能下降
  • 对量化类型的限制,例如,某些内核仅支持对称量化,这可能会影响模型在较低比特下的准确性
  • 量化、序列化和张量并行 (TP) 之间的相互作用使得加载量化模型变得困难,并且需要修改用户模型

为了应对这些挑战,我们创建了一个端到端、高性能、模块化且可扩展的低精度推理解决方案,整合了以下库:

  • GemLite:一个 Triton 内核库,解决了大批量大小和量化类型限制带来的性能瓶颈
  • TorchAO:一个 PyTorch 原生库,为量化、稀疏性和张量并行(带 DTensor)提供了简化的体验
  • SGLang:一个快速、高效且易于修改的大型语言模型 (LLM) 和视觉语言模型 (VLM) 的服务框架,支持广泛的模型

如果您有兴趣在 SGLang 中尝试此功能,请遵循这些 重现说明。在本文的其余部分,我们将详细介绍 GemLite、TorchAO 和 SGLang 的相关细节,包括库本身的设计以及在解决上述问题方面的集成。最后,我们将展示在 Llama 3.1-8B 模型上针对不同批次大小和张量并行大小的基准测试结果。

1. 结果预告

以下是在 8xH100 机器上对 Llama 3.1-8B 进行解码的结果摘要。对于所有实验,基线是 bfloat16 torch.compiled 模型

  bfloat16 w/ torch.compile int4 仅权重量化,分组大小 64 float8 每行动态量化
批次大小 1,TP 大小 1 131 tokens/秒 255 tokens/秒(1.95 倍加速) 166 tokens/秒(1.27 倍加速)
批次大小 32,TP 大小 1 2799 tokens/秒 3241 tokens/秒(1.16 倍加速) 3586 tokens/秒(1.28 倍加速)
批次大小 32,TP 大小 4 5575 tokens/秒 6334 tokens/秒(1.14 倍加速) 6159 tokens/秒(1.10 倍加速)

我们的解决方案支持 NVIDIA GPU,包括 H100 和 A100,并在不同批次大小和 TP 大小下,对 int4 仅权重(从 1.14 倍到 1.95 倍)和 float8 动态量化(从 1.10 倍到 1.28 倍)实现了相对于编译后的 bfloat16 基线的加速。请注意,量化可能会对准确性产生轻微影响,这超出了本文的范围。我们的 int4 仅权重量化与 HQQ 等保持准确性的技术兼容。有关更多信息,请参阅 TorchAO 的 README此基准测试这篇博文

2. GemLite:内核开发

这些内核是 GemLite 项目的一部分,该项目致力于优化低比特矩阵乘法内核。GemLite 使用 Triton 开发,为各种激活、比特率和硬件提供了高度灵活且高性能的解决方案。简而言之,这些内核提供了:

  • 支持各种激活数据类型:fp16、int8 和 fp8
  • 兼容性:可与非打包格式(例如 int8、fp8)和打包格式(例如 uint4、uint2、uint1)无缝协作
  • 性能优化:包括优化的内核和自动调优工具,以在不同硬件和批次大小上实现高性能
  • 集成:与 torch.compile 和 CUDA 图兼容,确保支持张量并行等高级功能

内核选择

优化大型语言模型 (LLM) 生成的内核选择需要解决不同批次大小的独特需求。LLM 工作负载涉及计算密集型和内存密集型迭代的混合:小批量大小是内存密集型的,而大批量大小则变为计算密集型。GemLite 内核旨在适应这些不同的需求,确保每种情况的最佳执行。

在内存密集型场景中,数据传输是限制因素,处理器通常会等待数据获取,导致计算资源利用不足。对于批次大小 = 1,GEMV 内核表现最佳;对于更大的批次大小,GEMM 内核更有效。对于批次大小在 2 到 64 之间且矩阵“瘦”的情况下,使用 GEMM-SPLITK 内核以实现更好的 GPU 利用率(arXiv)。

GemLite 包含针对每种场景优化的以下内核:

单样本推理

对于单样本推理,我们使用 GEMV 内核。但是,非对称量化方法需要加载每个块的额外元数据,例如尺度和零点。这可能导致内存传输增加,因此仔细处理至关重要。

具体而言,对于打包数据,我们的实验表明,每个连续块仅加载一次尺度和零点可以最大限度地减少冗余操作。由于这些块共享相同的元数据,这种方法实现了:

  • 与默认 GEMV 内核相比,端到端推理速度提升 5-8%
  • 与传统的 Split-K 方法相比,性能提升 30-40%

这种新的内核/算法 GEMV_REVSPLITK 在此处提供。

对于非打包数据,采用了 GEMV_SPLITK 算法。该算法迭代 k 维度以计算点积,而不依赖于 Triton 的 tl.dot。

批量推理

对于中等批次大小,我们使用基于 GEMM 的 Split-K 方法(arXiv),该方法将 k 维度(权重行)拆分为多个作业。通过自动调优 1 到 16 的值来找到最佳 SPLIT_K 参数。将 SPLIT_K=1 设置为启用回退实现到 GEMM 内核,允许使用相同的内核代码来处理从 32 和 64 开始的计算密集型批次大小,具体取决于矩阵形状和设备。

最大化高性能:关键实现见解

必须仔细处理各种实现细节才能实现高性能。以下是我们为确保高性能而重点关注的一些关键方面:

  1. 性能自动调优 自动调优对于实现最佳内核性能至关重要。由于此过程可能耗时,GemLite 提供了用于自动保存和加载所有内核的自动调优结果的工具。这确保了自动调优过程仅在每个 GPU 设备上执行一次,从而最大限度地减少了运行时,减少了重复开销,并保持了跨运行的一致性能。
  2. 确保内核正确性确保不同量化和配置设置下的内核正确性至关重要。Triton 的 早期配置修剪在此过程中发挥着关键作用。例如,在 Split-K 调优期间,仅当 K 可被 BLOCK_SIZE_K × SPLIT_K 整除时才选择配置,并且 BLOCKS_SIZE_K 基于 group-size 值进一步修剪。这种方法确保了内核操作的效率和正确性。
  3. 克服比特解包瓶颈在数据中心级 GPU(如 NVIDIA 的 A100 和 H100)上部署时,观察到与比特解包相关的性能瓶颈。为了缓解这些问题,我们探索了各种比特打包配置,包括按列打包与按行打包以及尝试不同的比特打包宽度(例如 8 位与 32 位)。值得注意的是,从 32 位打包切换到 8 位打包在 A100 上带来了高达 18% 的性能提升,在 H100 上带来了 6% 的性能提升。
  4. torch.compile 兼容性为确保与 PyTorch 的 torch.compile 无缝兼容,内核调用被包装在 custom_op 中。这种集成允许预挂钩和早期配置修剪等高级功能正确工作,在不牺牲性能的情况下提供准确的结果。虽然 PyTorch 中尚未完全支持这些 功能,但 custom_op 实现有效地弥合了差距,确保了平稳集成和高性能。

3. TorchAO

TorchAO 是一个 PyTorch 原生量化和稀疏性库,支持训练和推理,提供简单的用户 API 来训练、量化和部署低精度模型,并与其他 PyTorch 功能(如分布式推理和 torch.compile)可组合。

PyTorch 默认不支持低精度数据类型或不同的打包格式。通过 Tensor Subclass,我们将 PyTorch 原生 Tensor 抽象和模型量化扩展为数据类型转换,而自定义内核的不同打包格式则通过布局来处理。例如,我们支持 int4 权重的量化线性操作,这些权重以对 Tensor Core 友好的布局打包,并使用 tinygemm 或 GemLite 内核实现。更多详细信息可在此处找到 此处

flow diagram

除了为开发人员提供更多 PyTorch 原生抽象之外,我们还想强调此设计对建模用户的两个好处。

  1. 序列化:像浮点模型一样将量化权重保存和加载到 state_dict 中,无需在加载量化权重之前将浮点模型转换为量化模型。这减少了分发和部署量化模型的摩擦。
  2. 可组合性:与张量并行等下游功能无缝集成,允许用户专注于建模,而不必担心与张量并行、torch.compile 和其他 PyTorch 功能的兼容性。由于这些功能是使用 Tensor 级抽象实现的,因此用户在大多数情况下都可以进行量化和分布式推理,而无需更改模型。

GemLite 内核集成

为了实现 GemLite 内核的上述优势,我们将 GemLite 集成到 TorchAO 中。此集成利用了 GemLite 的广泛支持和灵活性,支持 4 位和 8 位权重仅量化,在非对称和对称量化方案下,支持 32 位和 8 位打包大小,以及分组和非分组量化。我们通过 `quantize_` API 实现此集成,该 API 可与 GemLite 构造函数一起使用,如下所示:

quantize_(model, gemlite_uintx_weight_only(group_size, bit_width, packing_bitwidth))

创建此集成的主要难点在于确保 TorchAO 的可组合性保证满足 GemLite 量化内核选项的全部范围。虽然主要集成相对直接,但确保每种不同的量化类型及其相关的内核都能与张量并行良好配合并非易事。

Torch 张量并行

张量并行是加速 LLM 推理的有效方法。TP 将大型线性或嵌入模块矩阵分片到多个设备上,通常以列式或行式进行。随着权重矩阵的分发,计算也随之分解。例如,下面的列式模式允许在四个设备上同时进行矩阵向量乘法。

equation

PyTorch 通过将常规张量(例如矩阵 *A*)转换为 *DTensor* 来实现 TP。

dtensor = _shard_tensor(mA, device_mesh, (Shard(0),))

由于 DTensor 存储有关分片元信息,因此它知道在需要时如何重建完整结果。以 Transformers 的前馈模块为例,由于下投影和上投影分别使用列式和行式分片,DTensor 将在进入下一个操作时自动对 rank 的结果进行 all-reduce。这种自动化允许模型作者专注于计算,而不必担心分布式执行所需的通信。

张量并行和量化顺序

由于 DTensor 和量化都是张量级转换,因此应用顺序对于确保工作流程通常可以在不同设置下正常工作很重要。我们有两个观察结果:(i)检查点通常以量化格式保存,以在每次运行前节省量化开销;(ii)TP 可能运行在不同数量的设备上,具体取决于资源限制或服务协议。因此,我们首先将量化应用于原始张量,并根据是否需要重用将其保存到磁盘。在服务启动时,我们加载量化检查点,并在将张量加载到模型时,将它们动态地分片到 DTensor 中。

TorchAO 中的张量并行支持

由于我们首先量化模型然后分发张量,因此我们将拥有 `DTensor(QuantizedTensor(weight))`,其中 `DTensor` 表示 PyTorchAO 中的分布式张量类,`QuantizedTensor` 表示量化张量类。`QuantizedTensor` 应支持在构造 `DTensor` 时调用的运算符,包括切片和视图操作。为了确保整体执行效率,沿维度 0 和 1 切片的打包权重应与先切片解包权重然后打包的结果匹配(打包和切片操作应可交换),否则打包格式与张量并行不兼容。

4. SGLang

SGLang 是一个用于大型语言模型和视觉语言模型的高速服务框架。它以其几乎 零开销的批次调度器 和快速的 约束解码 而闻名。它主要用 Python 实现,轻量级且易于修改。它也是首批集成 torch.compile 的框架之一。

TorchAO 在 SGLang 中的集成

我们将 `quantize_` API 集成到 SGLang 中,用于将特定类型的量化应用于模型。目前支持 int4 权重仅量化(tinygemm 和 GemLite 版本)、float8 动态量化以及其他几种类型的量化。用户可以通过向基准测试脚本添加 `--torchao-config` 参数来启用量化。当前启用的选项还通过与 DTensor 的组合支持张量并行,DTensor 使用 `--tp-size` 选项启用。

SGLang 中的 Torch 原生张量并行支持

SGLang 中现有的模型定义使用与张量并行风格绑定的特殊线性模块,例如:`MergedColumnParallelLinear`、`QKVParallelLinear` 和 `RowParallelLinear`。为了解耦模型定义和张量并行风格,我们定义了一个 PyTorch 原生模型,它使用 PyTorch 的纯 `nn.Linear` 模块,并依赖 PyTorch 张量并行 API 进行并行化和 torch.compile 进行加速。在相关的模块层次结构中,我们添加了一个字典,描述子模块应如何并行化。例如,在 `class LlamaAttention` 中,我们定义了:

_tp_plan = {
    "qkv_proj": "Colwise_Sharded",
    "o_proj": "Rowwise",
}

其中 `"qkv_proj"` 和 `"o_proj"` 分别是 `wqkv` 和 `wo` 投影的 FQN,值是它们的 TP 样式。

然后我们在 `model_parallel.py` 中定义了一个 TP 引擎。它递归地搜索模型中的 `_tp_plan`,并使用 PyTorch 的 parallelize_module API 将指示的 TP 样式应用于子模块。

5. 结果

评估侧重于 H100 机器上两种流行的量化技术:int4 权重仅量化和 float8 动态量化。选择这些方法是因为它们在 H100 机器上广泛用于优化内存效率和计算性能,使其成为针对各种工作负载进行基准测试的理想选择。

  • int4 权重仅量化:此方法显著减少内存占用并加速内存密集型工作负载的解码,对预填充或大批量大小等计算密集型场景的性能影响最小。我们下面展示了 bf16、GemLite 和 tinygemm 内核在各种批次大小和张量并行配置下的结果。
  • float8 动态量化:虽然内存节省较少,但此方法通常能提供更高的准确性,并为内存密集型和计算密集型任务提供平衡的加速。借助 Hopper 级硬件和原生的 fp8 支持,AO 使用的高效 cutlass/cuBLAS 内核有助于显著加速。

下面的图表显示了不同 tp 大小的解码 tokens/秒,每个图表显示了不同批次大小和不同量化类型的ョ果。

  • BF16 是我们的 bfloat16,torch.compile 编译后的基线。
  • tinygemm-4-64 使用 TorchAO 中的 `int4_weight_only` 量化,它是一种 4 位分组量化,分组大小为 64,使用 tinygemm 内核。
  • gemlite-4-64 使用 TorchAO 中的 `gemlite_uintx_weight_only` 量化,4 表示 4 位,64 也是分组大小,使用 GemLite 内核。
  • fp8dq-per_row 使用 `float8_dynamic_activation_float8_weight` 量化,激活和权重均按行尺度量化。
bar chart
bar chart
bar chart

对于 int4 权重仅量化,在批次大小为 1 时,tinygemm 内核取得了最佳性能。然而,随着批次大小的增加,其效率有所下降。相反,GemLite 有效地弥合了这一差距,在更大的批次大小下提供了卓越的性能。GemLite 在预填充阶段比 tinygemm 加速了 9-10 倍,尽管持续的性能优化受到 Triton 的限制。

float8 动态量化在不同批次大小和张量并行大小为 1 时,相对于 bfloat16 始终显示出 1.3 倍的加速,在更大的张量并行大小时显示出 1.1 倍到 1.2 倍的加速。随着张量并行大小的增加,整体加速会降低,这符合预期,因为矩阵乘法大小减小了。请注意,我们也期望预填充也能获得加速,但由于我们依赖 torch.compile 进行加速,而 SGLang 尚未启用预填充编译,因此我们将这项工作留待未来。

重现说明

我们在 8xH100 机器上使用 GemLite 0.4.1、来自 commit feb2b76 的 SGLang 构建、TorchAO nightly 0.8.0.dev20241223+cu124 和 PyTorch 2.5.1 进行了基准测试。Llama-3.1 Instruct 模型被选为评估架构。

BATCH_SIZE=16
# Note: gemlite is only compatible with float16
# while int4wo-64 (tinygemm-4-64 as shown in the graph) and fp8dq-per_row should use bfloat16
DTYPE=float16
# int4wo-64, fp8dq-per_tensor
TORCHAO_CONFIG=gemlite-4-64
TP_SIZE=2
# Decode performance
python3 -m sglang.bench_offline_throughput --model-path meta-llama/Llama-3.1-8B-Instruct --json-model-override-args '{"architectures": ["TorchNativeLlamaForCausalLM"]}' --dataset-name random --random-input 1024 --random-output 512 --random-range 1 --num-prompts $BATCH_SIZE --enable-torch-compile --dtype $DTYPE --torchao-config $TORCHAO_CONFIG --tp-size $TP_SIZE

# Example output
# Benchmark...
# [2024-12-20 12:42:16 TP0] Prefill batch. #new-seq: 2, #new-token: 2046, #cached-token: 4, cache hit rate: \0.06%, token usage: 0.00, #running-req: 0, #queue-req: 0
# ...
# [2024-12-20 12:45:35 TP0] Decode batch. #running-req: 16, #token: 16763, token usage: 0.01, gen throughput\ (token/s): 2.20, #queue-req: 0
# [2024-12-20 12:45:38 TP0] Decode batch. #running-req: 16, #token: 24443, token usage: 0.02, gen throughput\ (token/s): 2739.89, #queue-req: 0

# We reported the last throughput (token/s) as the performance for decode

结论

通过 GemLite 提供的性能卓越且可扩展的内核,PyTorch 原生架构优化库 TorchAO 以及高性能推理框架 SGLang,我们通过简单且可组合的用户 API 展示了 int4 和 float8 在不同批次大小和张量并行大小下的快速端到端量化推理,以降低 LLM 的资源需求。此次集成是我们满足不同模型、工作负载、精度和硬件的快速推理需求的第一步,我们期待继续推进端到端混合和低精度 LLM 推理的最新技术。

我们近期的工作重点如下:

  • 探索权重和激活量化的各种组合,以在速度和准确性之间取得最佳平衡
  • 扩展对其他 GPU 架构的支持,以扩大可访问性
  • 增强与 MoE 模型兼容性,以满足日益增长的可扩展推理需求
  • 允许轻松集成 TorchAO 中的快速自定义内核,以便 SGLang 和其他推理框架可以轻松利用它们
  • 虽然我们在此博文中没有衡量准确性影响,但我们可以开发 TorchAO 中的自动量化工具,允许用户在性能和准确性之间进行权衡。
  • SGLang 中更好的张量并行集成,以支持运行更大的模型
  • 在 SGLang 中启用预填充阶段的 torch.compile

我们也邀请社区积极测试、提供反馈并为塑造快速高效 LLM 推理的未来做出贡献。