在本文中,我们介绍了 PyTorch 中针对大语言模型的端到端量化感知训练 (QAT) 流程。我们演示了 PyTorch 中的 QAT 如何在 hellaswag 上恢复 Llama3 相较于训练后量化 (PTQ) 高达 96% 的精度下降,并在 wikitext 上恢复 68% 的困惑度下降。我们展示了 torchao 中的 QAT API,并说明了用户如何利用它们在 torchtune 中进行微调。
图 1: 使用 int8 per token dynamic activations + int4 grouped per channel weights 在 C4 数据集(英文子集)上对 Llama3-8B 进行 QAT 或非 QAT 微调,并在 A100 GPU 上通过 hellaswag 和 wikitext 进行评估。请注意 wikitext 使用对数坐标(值越低越好)。
为了证明 QAT 在端到端流程中的有效性,我们通过 executorch 将量化后的模型进一步下沉到 XNNPACK,这是一个针对 iOS 和 Android 等后端的经过高度优化的神经网络库。下沉到 XNNPACK 后,QAT 模型相较于 PTQ 模型困惑度降低了 16.8%,同时保持了相同的模型大小和设备上推理及生成速度。
部署模型指标 | PTQ | QAT |
Wikitext 词困惑度 (↓) | 23.316 | 19.403 |
Wikitext 字节困惑度 (↓) | 1.850 | 1.785 |
Wikitext bits per byte (↓) | 0.887 | 0.836 |
模型大小 | 3.881 GB | 3.881 GB |
设备上推理速度 | 5.065 tok/s | 5.265 tok/s |
设备上生成速度 | 8.369 tok/s | 8.701 tok/s |
表 1: 将 Llama3-8B 模型下沉到 XNNPACK 后,QAT 实现了 16.8% 的困惑度降低,并保持了模型大小和设备上推理及生成速度不变。线性层使用 int8 per token dynamic activations + int4 grouped per channel weights 进行量化,嵌入层另外使用 group size 为 32 量化到 int4(QAT 仅应用于线性层)。Wikitext 评估使用服务器 CPU,通过 5 个样本和最大序列长度 127 进行(因为评估在设备上不可用)。所有 wikitext 结果都是值越低越好。设备上推理和生成在三星 Galaxy S22 智能手机上进行基准测试。
QAT API
我们很高兴用户能尝试 torchao 中的 QAT API,该 API 可用于训练和微调。此 API 包含两个步骤:prepare 和 convert。prepare 对模型中的线性层进行转换,以在训练期间模拟量化数值,而 convert 则在训练后将这些层实际量化为较低的位宽。转换后的模型可以与 PTQ 模型完全一样地使用。
import torch
from torchtune.models.llama3 import llama3
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
# Smaller version of llama3 to fit in a single GPU
model = llama3(
vocab_size=4096,
num_layers=16,
num_heads=16,
num_kv_heads=4,
embed_dim=2048,
max_seq_len=2048,
).cuda()
# Quantizer for int8 dynamic per token activations +
# int4 grouped per channel weights, only for linear layers
qat_quantizer = Int8DynActInt4WeightQATQuantizer()
# Insert "fake quantize" operations into linear layers.
# These operations simulate quantization numerics during
# training without performing any dtype casting
model = qat_quantizer.prepare(model)
# Standard training loop
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
loss_fn = torch.nn.CrossEntropyLoss()
for i in range(10):
example = torch.randint(0, 4096, (2, 16)).cuda()
target = torch.randn((2, 16, 4096)).cuda()
output = model(example)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
optimizer.zero_grad()
# Convert fake quantize to actual quantize operations
# The quantized model has the exact same structure as the
# quantized model produced in the corresponding PTQ flow
# through `Int8DynActInt4WeightQuantizer`
model = qat_quantizer.convert(model)
# inference or generate
使用 torchtune 进行微调
我们还将此 QAT 流程集成到 torchtune 中,并提供了 配置 以在分布式设置中运行,类似于现有的 full fine-tune 分布式配置。用户可以通过运行以下命令在 LLM 微调期间额外应用 QAT。请参阅 此 README 以获取更多详细信息。
tune run --nproc_per_node 8 qat_distributed --config llama3/8B_qat_full
什么是量化感知训练?
量化感知训练 (QAT) 是一种常见的量化技术,用于减轻由量化引起的模型精度/困惑度下降。实现方法是在训练期间模拟量化数值,同时将权重和/或激活保留在原始数据类型(通常是浮点型)中,从而有效地对值进行“伪量化(fake quantizing)”,而不是实际将其转换为较低的位宽。
# PTQ: x_q is quantized and cast to int8
# scale and zero point (zp) refer to parameters used to quantize x_float
# qmin and qmax refer to the range of quantized values
x_q = (x_float / scale + zp).round().clamp(qmin, qmax).cast(int8)
# QAT: x_fq is still in float
# Fake quantize simulates the numerics of quantize + dequantize
x_fq = (x_float / scale + zp).round().clamp(qmin, qmax)
x_fq = (x_fq - zp) * scale
由于量化涉及舍入等不可微分操作,QAT 的反向传播通常使用 straight-through estimators (STE),这是一种用于估计流经非光滑函数梯度的机制,以确保传递给原始权重的梯度仍然有意义。通过这种方式,计算梯度时已知权重最终将在训练后被量化,从而有效地使模型能够在训练过程中适应量化噪声。请注意,QAT 的替代方法是量化训练,它在训练期间实际将值转换为较低位宽的数据类型,但 之前的尝试 仅在 8 位及以下取得了成功,而 QAT 在更低的位宽下仍然有效。
PyTorch 中的 QAT
我们在 torchao 的 prototype 中添加了一个初始 QAT 流程,见此处。目前,我们支持线性层的 int8 dynamic per-token activations + int4 grouped per-channel weights(简称 8da4w)。这些设置的动机结合了边缘后端上的内核可用性和 之前关于 LLM 量化的研究,这些研究发现 per-token 激活和 per-group 权重量化在 LLM 上比其他量化方案能获得最佳模型质量。
图 2: torchao QAT 流程。此流程包括两个步骤:(1) prepare,它将伪量化操作(fake quantization ops)插入模型的线性层中;(2) convert,它在训练后将这些伪量化操作转换为实际的量化(quantize)和反量化(dequantize)操作。
此流程产生的量化模型与使用相同量化设置(通过 Int8DynActInt4WeightQuantizer)的 PTQ 流程产生的模型完全相同,但其量化权重可获得更优越的精度和困惑度。因此,我们可以将 QAT 流程转换后的模型作为 PTQ 模型的直接替代品(drop-in replacement),并重用所有的后端委托逻辑和底层内核。
实验结果
本文中的所有实验均使用上述 torchtune QAT 集成进行。我们使用 6-8 块每块 80 GB 的 A100 GPU 在 C4 数据集(英文子集)上对 Llama2-7B 和 Llama3-8B 进行 5000 步的微调。对于所有实验,我们使用 batch size = 2,学习率 = 2e-5,Llama2 的最大序列长度 = 4096,Llama3 的最大序列长度 = 8192,使用 Fully Sharded Data Parallel (FSDP) 作为分布式策略,并使用激活检查点(activation checkpointing)来减少内存占用。对于 8da4w 实验,我们对权重使用了 group size 为 256。
由于预训练数据集不易获取,我们在微调过程中进行 QAT。凭经验,我们发现前 N 步禁用伪量化会产生更好的结果,这可能是因为这样做可以让权重在开始将量化噪声引入微调过程之前得以稳定。我们在所有实验中都禁用了前 1000 步的伪量化。
我们使用 torchtune 中集成的 lm-evaluation-harness 来评估我们的量化模型。我们报告了用于评估 LLM 的各种常见任务的评估结果,包括 hellaswag(一个常识性句子补全任务)、wikitext(一个下一个 token/字节预测任务)以及一些问答任务,如 arc、openbookqa 和 piqa。对于 wikitext,困惑度是指模型预测下一个词或字节能力的倒数(值越低越好);bits_per_byte
是指预测下一个字节所需的比特数(此处值越低也越好)。对于所有其他任务,acc_norm
指的是按目标字符串字节长度归一化的精度。
Int8 动态激活 + Int4 权重量化 (8da4w)
从 Llama2 8da4w 量化开始,我们看到 QAT 在 hellaswag 上相较于 PTQ 能够恢复 62% 的归一化精度下降,并在 wikitext 上分别恢复 58% 和 57% 的词和字节困惑度下降。我们在大多数其他任务上也看到了类似的改进。
图 3a: Llama2-7B 8da4w 量化,有 QAT 和无 QAT
图 3b: Llama2-7B 8da4w 量化,有 QAT 和无 QAT,在 wikitext 上评估(值越低越好)
Llama3 8da4w 量化在使用 QAT 后看到了更显著的改进。在 hellaswag 评估任务上,我们能够恢复相较于 PTQ 高达 96% 的归一化精度下降,与非量化精度相比,总体下降微乎其微(<1%)。在 wikitext 评估任务上,QAT 分别恢复了词和字节困惑度下降的 68% 和 65%。即使是对于 Llama2 QAT 来说比较困难的 arc_challenge 任务,我们也能够恢复 51% 的归一化精度下降。
图 4a: Llama3-8B 8da4w 量化,有 QAT 和无 QAT
图 4b: Llama3-8B 8da4w 量化,有 QAT 和无 QAT,在 wikitext 上评估(值越低越好)
低位宽仅权重量化
我们将 torchao QAT 流程进一步扩展到 2 位和 3 位仅权重量化,并对 Llama3-8B 重复了相同的实验。量化下降在较低位宽时更严重,因此我们在所有实验中使用了 group size 32 进行更精细的量化。
然而,这对于 2 位 PTQ 来说仍然不够,2 位 PTQ 的 wikitext 困惑度出现了爆炸性增长。为了缓解此问题,我们利用了之前敏感性分析的知识,即 Llama3 模型的前 3 层和最后 2 层最敏感,因此跳过(skip)对这些层进行量化,以换取量化模型大小的适度增加(2 位为 1.78 GB,3 位为 1.65 GB)。这使得 wikitext 词困惑度从 603336 下降到 6766,下降幅度很大,但距离可接受的范围仍有很大差距。为了进一步改进量化模型,我们转向了 QAT。
图 5a: Llama3-8B 2 位仅权重量化,有 QAT 和无 QAT,在 wikitext 上评估(值越低越好)。标有“skip”的条形图表示跳过对模型前 3 层和最后 2 层(对量化更敏感)进行量化。请注意对数坐标。
我们观察到,在跳过(skip)对前 3 层和最后 2 层进行量化的同时应用 QAT,词困惑度进一步下降到了一个更合理的数值 30(从 6766)。更普遍地说,相较于 PTQ,QAT 在 hellaswag 上能够恢复 53% 的归一化精度下降,并在 wikitext 上分别恢复 99% 和 89% 的词和字节困惑度下降。然而,如果不跳过敏感层,QAT 在减轻量化模型质量下降方面的效果会差很多。
图 5b: Llama3-8B 2 位仅权重量化,有 QAT 和无 QAT。标有“skip”的条形图表示跳过对模型前 3 层和最后 2 层(对量化更敏感)进行量化。
对于 3 位仅权重量化,即使不跳过前 3 层和最后 2 层,QAT 仍然有效,尽管跳过这些层对 PTQ 和 QAT 都能带来更好的结果。在跳过(skip)的情况下,相较于 PTQ,QAT 在 hellaswag 上能够恢复 63% 的归一化精度下降,并在 wikitext 上分别恢复 72% 和 65% 的词和字节困惑度下降。
图 6a: Llama3-8B 3 位仅权重量化,有 QAT 和无 QAT。标有“skip”的条形图表示跳过对模型前 3 层和最后 2 层(对量化更敏感)进行量化。
图 6b: Llama3-8B 3 位仅权重量化,有 QAT 和无 QAT,在 wikitext 上评估(值越低越好)。标有“skip”的条形图表示跳过对模型前 3 层和最后 2 层(对量化更敏感)进行量化。请注意对数坐标。
QAT 开销
QAT 在整个模型中插入了许多伪量化操作,给微调速度和内存使用都带来了相当大的开销。例如,对于像 Llama3-8B 这样的模型,我们有 (32 * 7) + 1 = 225 个线性层,每个层至少包含 1 个用于权重的伪量化操作,以及可能 1 个用于输入激活的伪量化操作。内存占用增加也很显著,因为我们不能就地修改权重,因此需要在应用伪量化之前克隆它们,尽管可以通过启用激活检查点(activation checkpointing)来大部分缓解这种开销。
在我们的微基准测试中,我们发现 8da4w QAT 微调比常规 full fine-tuning 慢约 34%。启用激活检查点后,每块 GPU 的内存增加约为 2.35 GB。这些开销大部分是 QAT 工作原理的基础,不过未来我们可能可以通过 torch.compile 加快计算速度。
每块 GPU 统计数据 | Full fine-tuning | QAT fine-tuning |
每秒中位数 token 数 | 546.314 tok/s | 359.637 tok/s |
中位数峰值内存 | 67.501 GB | 69.850 GB |
表 2: 在 6 块 A100 GPU(每块 80GB 内存)上进行 int8 per token dynamic activations + int4 grouped per channel weights 的 Llama3 QAT 微调开销。
展望未来
在本文中,我们通过 torchao 介绍了一种 LLM 的 QAT 流程,将此流程与 torchtune 中的微调 API 集成,并展示了其相较于 PTQ 恢复大部分量化下降并在某些任务上匹配非量化性能的潜力。未来还有许多探索方向
- 超参数调优。广泛的超参数调优很可能进一步改善微调和 QAT 的结果。除了学习率、批量大小、数据集大小和微调步数等一般超参数外,我们还应该调优 QAT 特定的超参数,例如何时开始/停止伪量化、伪量化的步数以及伪量化值的正则化参数等。
- 异常值减少技术。在我们的实验中,我们发现 PTQ 和 QAT 都容易受到异常值的影响。除了微调期间简单的 clamping 和正则化之外,我们可以探索让网络学习如何控制这些异常值的技术(例如:learned quantization ranges, clipped softmax 和 gated attention),或者甚至可以借鉴训练后设置中的异常值抑制技术(例如:SpinQuant, SmoothQuant),并在微调过程中谨慎地应用它们。
- 混合精度和更复杂的数据类型(dtypes)。特别是在低位宽领域,我们看到跳过(skip)对某些敏感层进行量化对 PTQ 和 QAT 都有效。我们是否需要完全跳过对这些层进行量化,还是可以仍然对其进行量化,只是使用更低的位宽?在 QAT 的背景下探索混合精度量化将很有趣。使用 MX4 等更新的数据类型进行训练是另一个有前景的方向,特别是考虑到即将推出的 Blackwell GPU 将 不再支持 int4 tensor cores。
- 与 LoRA 和 QLoRA 的可组合性。torchtune 中的 QAT 集成目前仅支持 full fine-tuning 工作流程。然而,许多用户希望使用低秩适配器(low-ranked adaptors)对模型进行微调,以大幅减少内存占用。将 QAT 与 LoRA / QLoRA 等技术结合,将使用户在获得这些方法的内存和性能优势的同时,生成一个最终量化后模型质量下降最小的模型。
- 与 torch.compile 的可组合性。这是另一种在减少内存占用的同时显着加快 QAT 中伪量化计算速度的潜在方法。torch.compile 目前与 torchtune 中 full distributed fine-tuning 配置中使用的分布式策略(无论是否使用 QAT)不兼容,但未来将很快添加支持。
- 量化其他层。在这项工作中,我们只探索了线性层的量化。然而,在长序列长度的背景下,KV 缓存常常成为吞吐瓶颈,并且可以达到数十 GB,因此 LLM-QAT 探索了与激活和权重一起量化 KV 缓存。 之前的工作 还在其他基于 Transformer 的模型中成功地将嵌入层量化到 2 位。
- 在高性能 CUDA 内核上进行端到端评估。这项工作的一个自然延伸是提供一个在高性能 CUDA 内核上进行评估的端到端 QAT 流程,类似于通过 executorch 下沉到 XNNPACK 内核的现有 8da4w QAT 流程。对于 int4 仅权重量化,我们可以利用高效的带 bitpacking 的 int4 weight mm 内核进行量化,并且正在进行为该内核添加 QAT 支持的工作:https://github.com/pytorch/ao/pull/383。对于 8da4w 量化,cutlass 中也正在添加 混合 4 位/8 位 GEMM。这将需要用于构建高效的 8da4w CUDA 内核。
QAT 代码可在此处找到:此处。请参考 此 torchtune 教程开始入门。如果您有任何进一步的问题,请随时在 torchao 的 github 上提出 issue 或通过 andrewor@meta.com 联系。我们欢迎您的反馈和贡献!