在本博客中,我们介绍了 PyTorch 中大型语言模型的端到端量化感知训练 (QAT) 流程。我们演示了 PyTorch 中的 QAT 如何恢复高达 96% 的 hellaswag 准确率下降以及 Llama3 在 wikitext 上 68% 的困惑度下降(与训练后量化 (PTQ) 相比)。我们介绍了 torchao 中的 QAT API,并展示了用户如何利用它们在 torchtune 中进行微调。
图 1: 使用 int8 每 token 动态激活 + int4 分组每通道权重,在 A100 GPU 上评估的、在 C4 数据集(en 子集)上使用和不使用 QAT 微调的 Llama3-8B。请注意 wikitext 的对数刻度(越低越好)。
为了演示 QAT 在端到端流程中的有效性,我们进一步将量化模型降低到 XNNPACK(一个针对包括 iOS 和 Android 在内的后端的高度优化神经网络库),通过 executorch。降低到 XNNPACK 后,QAT 模型看到的困惑度比 PTQ 模型低 16.8%,同时保持相同的模型大小和设备端推理和生成速度。
降低的模型指标 | PTQ | QAT |
Wikitext 单词困惑度 (↓) | 23.316 | 19.403 |
Wikitext 字节困惑度 (↓) | 1.850 | 1.785 |
Wikitext 每字节比特数 (↓) | 0.887 | 0.836 |
模型大小 | 3.881 GB | 3.881 GB |
设备端推理速度 | 5.065 tok/s | 5.265 tok/s |
设备端生成速度 | 8.369 tok/s | 8.701 tok/s |
表 1: 在降低到 XNNPACK 的 Llama3-8B 模型上,QAT 实现了低 16.8% 的困惑度,并且模型大小以及设备端推理和生成速度保持不变。线性层使用 int8 每 token 动态激活 + int4 分组每通道权重进行量化,并且嵌入层额外使用组大小为 32 的 int4 进行量化(QAT 仅应用于线性层)。Wikitext 评估在服务器 CPU 上使用 5 个样本和最大序列长度 127 执行,因为设备上无法进行评估(所有 wikitext 结果越低越好)。设备端推理和生成在三星 Galaxy S22 智能手机上进行基准测试。
QAT API
我们很高兴用户尝试我们在 torchao 中的 QAT API,该 API 可用于训练和微调。此 API 涉及两个步骤:准备和转换:准备对模型中的线性层应用转换,以模拟训练期间的量化数值,而转换实际上在训练后将这些层量化为较低的位宽。然后,转换后的模型可以像 PTQ 模型一样使用
import torch
from torchtune.models.llama3 import llama3
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
# Smaller version of llama3 to fit in a single GPU
model = llama3(
vocab_size=4096,
num_layers=16,
num_heads=16,
num_kv_heads=4,
embed_dim=2048,
max_seq_len=2048,
).cuda()
# Quantizer for int8 dynamic per token activations +
# int4 grouped per channel weights, only for linear layers
qat_quantizer = Int8DynActInt4WeightQATQuantizer()
# Insert "fake quantize" operations into linear layers.
# These operations simulate quantization numerics during
# training without performing any dtype casting
model = qat_quantizer.prepare(model)
# Standard training loop
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
loss_fn = torch.nn.CrossEntropyLoss()
for i in range(10):
example = torch.randint(0, 4096, (2, 16)).cuda()
target = torch.randn((2, 16, 4096)).cuda()
output = model(example)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
optimizer.zero_grad()
# Convert fake quantize to actual quantize operations
# The quantized model has the exact same structure as the
# quantized model produced in the corresponding PTQ flow
# through `Int8DynActInt4WeightQuantizer`
model = qat_quantizer.convert(model)
# inference or generate
使用 torchtune 进行微调
我们还将此 QAT 流程集成到 torchtune 中,并提供了 配方,以便在分布式环境中运行此流程,类似于现有的完整微调分布式配方。用户还可以通过运行以下命令在 LLM 微调期间应用 QAT。有关更多详细信息,请参阅 此 README。
tune run --nproc_per_node 8 qat_distributed --config llama3/8B_qat_full
什么是量化感知训练?
量化感知训练 (QAT) 是一种常见的量化技术,用于减轻量化引起的模型准确率/困惑度下降。这是通过在训练期间模拟量化数值,同时保持原始数据类型(通常为浮点型)的权重和/或激活来实现的,有效地“伪量化”值,而不是实际将它们转换为较低的位宽
# PTQ: x_q is quantized and cast to int8
# scale and zero point (zp) refer to parameters used to quantize x_float
# qmin and qmax refer to the range of quantized values
x_q = (x_float / scale + zp).round().clamp(qmin, qmax).cast(int8)
# QAT: x_fq is still in float
# Fake quantize simulates the numerics of quantize + dequantize
x_fq = (x_float / scale + zp).round().clamp(qmin, qmax)
x_fq = (x_fq - zp) * scale
由于量化涉及舍入等不可微分操作,因此 QAT 反向传播通常使用 直通估计器 (STE),这是一种估计流经非平滑函数的梯度的机制,以确保传递给原始权重的梯度仍然有意义。通过这种方式,梯度是在知道权重最终将在训练后量化的情况下计算的,从而有效地允许模型在训练过程中调整量化噪声。请注意,QAT 的替代方案是量化训练,它实际上在训练期间将值转换为较低的位宽数据类型,但 先前的努力仅在 8 位时取得成功,而 QAT 即使在较低的位宽下也有效。
PyTorch 中的 QAT
我们在 torchao 的 prototype 下 此处 添加了初始 QAT 流程。目前,我们支持线性层的 int8 动态每 token 激活 + int4 分组每通道权重(缩写为 8da4w)。这些设置的动机是 边缘后端上的内核可用性 和 先前关于 LLM 量化的研究 的结合,该研究发现,与其他量化方案相比,每 token 激活和每组权重量化实现了 LLM 的最佳模型质量。
图 2: torchao QAT 流程。此流程涉及两个步骤:(1) 准备,将伪量化操作插入到模型的线性层中;(2) 转换,在训练后使用实际的量化和反量化操作转换这些伪量化操作。
此流程生成的量化模型与使用相同量化设置的 PTQ 流程完全相同(通过 Int8DynActInt4WeightQuantizer),但量化权重实现了卓越的准确率和困惑度。因此,我们可以使用从 QAT 流程转换的模型作为 PTQ 模型的直接替代品,并重用所有后端委托逻辑和底层内核。
实验结果
本博客文章中的所有实验均使用上述 torchtune QAT 集成执行。我们使用 6-8 个 A100 GPU(每个 GPU 80 GB)在 C4 数据集(en 子集)上微调 Llama2-7B 和 Llama3-8B 5000 步。对于所有实验,我们使用批处理大小 = 2,学习率 = 2e-5,Llama2 的最大序列长度 = 4096,Llama3 的最大序列长度 = 8192,完全分片数据并行 (FSDP) 作为我们的分布式策略,以及激活检查点以减少内存占用。对于 8da4w 实验,我们使用权重组大小为 256。
由于预训练数据集不易访问,我们在微调过程中执行 QAT。凭经验,我们发现,在前 N 步禁用伪量化会产生更好的结果,据推测,这是因为这样做允许权重在我们开始向微调过程引入量化噪声之前稳定下来。对于我们所有的实验,我们在前 1000 步禁用了伪量化。
我们使用 torchtune 中的 lm-evaluation-harness 集成评估我们的量化模型。我们报告了来自各种常用于评估 LLM 的任务的评估结果,包括 hellaswag(常识句子完成任务)、wikitext(下一个 token/字节预测任务)以及一些问答任务,例如 arc、openbookqa 和 piqa。对于 wikitext,困惑度是指模型预测下一个单词或字节的能力的倒数(越低越好),bits_per_byte
是指预测下一个字节所需的位数(此处也是越低越好)。对于所有其他任务,acc_norm
是指按目标字符串的字节长度标准化的准确率。
Int8 动态激活 + Int4 权重化 (8da4w)
从 Llama2 8da4w 量化开始,我们看到 QAT 能够恢复 62% 的 hellaswag 标准化准确率下降(与 PTQ 相比),以及 wikitext 上 58% 和 57% 的单词和字节困惑度下降(分别)。我们在大多数其他任务中都看到了类似的改进。
图 3a: 使用和不使用 QAT 的 Llama2-7B 8da4w 量化
图 3b: 在 wikitext 上评估的、使用和不使用 QAT 的 Llama2-7B 8da4w 量化(越低越好)
Llama3 8da4w 量化在 QAT 中看到了更明显的改进。在 hellaswag 评估任务中,我们能够恢复 96% 的 hellaswag 标准化准确率下降(与 PTQ 相比),与非量化准确率相比,总体下降幅度极小(<1%)。在 wikitext 评估任务中,QAT 分别恢复了 68% 和 65% 的单词和字节困惑度下降。即使对于 Llama2 QAT 难以处理的 arc_challenge,我们也能够恢复 51% 的标准化准确率下降。
图 4a: 使用和不使用 QAT 的 Llama3-8B 8da4w 量化
图 4b: 在 wikitext 上评估的、使用和不使用 QAT 的 Llama3-8B 8da4w 量化(越低越好)
更低位宽的仅权重量化
我们进一步扩展了 torchao QAT 流程,以支持 2 位和 3 位仅权重量化,并对 Llama3-8B 重复了相同的实验。在较低的位宽下,量化下降更为严重,因此我们对所有实验使用组大小为 32,以实现更精细的量化。
但是,这对于 2 位 PTQ 仍然不够,它看到的 wikitext 困惑度激增。为了缓解这个问题,我们利用先前敏感性分析的知识,即 Llama3 模型的前 3 层和后 2 层最敏感,并跳过量化这些层,以换取量化模型大小的适度增加(2 位为 1.78 GB,3 位为 1.65 GB)。这使 wikitext 单词困惑度从 603336 降至 6766,这很显着,但仍然远未达到可接受的水平。为了进一步改进量化模型,我们转向 QAT。
图 5a: 在 wikitext 上评估的、使用和不使用 QAT 的 Llama3-8B 2 位仅权重量化(越低越好)。带有“skip”的条形图表示跳过模型的前 3 层和后 2 层的量化,这些层对量化更敏感。请注意对数刻度。
我们观察到,在跳过前 3 层和后 2 层的量化的同时应用 QAT,进一步将单词困惑度降至更合理的 30(从 6766)。更一般而言,QAT 能够恢复 53% 的 hellaswag 标准化准确率下降(与 PTQ 相比),以及 wikitext 上 99% 和 89% 的单词和字节困惑度下降(分别)。然而,如果不跳过敏感层,QAT 在减轻量化模型质量下降方面的效果远不如前者。
图 5b: 使用和不使用 QAT 的 Llama3-8B 2 位仅权重量化。带有“skip”的条形图表示跳过模型的前 3 层和后 2 层的量化,这些层对量化更敏感。
对于 3 位仅权重量化,即使不跳过前 3 层和后 2 层,QAT 也有效,尽管跳过这些层仍然为 PTQ 和 QAT 带来了更好的结果。在跳过的情况下,QAT 能够恢复 63% 的 hellaswag 标准化准确率下降(与 PTQ 相比),以及 wikitext 上 72% 和 65% 的单词和字节困惑度下降(分别)。
图 6a: 使用和不使用 QAT 的 Llama3-8B 3 位仅权重量化。带有“skip”的条形图表示跳过模型的前 3 层和后 2 层的量化,这些层对量化更敏感。
图 6b: 在 wikitext 上评估的、使用和不使用 QAT 的 Llama3-8B 3 位仅权重量化(越低越好)。带有“skip”的条形图表示跳过模型的前 3 层和后 2 层的量化,这些层对量化更敏感。请注意对数刻度。
QAT 开销
QAT 在整个模型中插入了许多伪量化操作,从而显着增加了微调速度和内存使用量方面的开销。例如,对于像 Llama3-8B 这样的模型,我们有 (32 * 7) + 1 = 225 个线性层,每个线性层至少有 1 个权重的伪量化,并且可能还有 1 个输入激活的伪量化。内存占用增加也很明显,因为我们无法就地改变权重,因此我们需要在应用伪量化之前克隆它们,尽管这种开销可以通过启用激活检查点在很大程度上缓解。
在我们的微基准测试中,我们发现 8da4w QAT 微调比常规完整微调慢约 34%。通过激活检查点,每个 GPU 的内存增加约为 2.35 GB。大多数这些开销都是 QAT 工作原理的基本要素,尽管我们将来或许能够使用 torch.compile 加速计算。
每个 GPU 统计信息 | 完整微调 | QAT 微调 |
每秒中位数 token 数 | 546.314 tok/s | 359.637 tok/s |
中位数峰值内存 | 67.501 GB | 69.850 GB |
表 2: 在 6 个 A100 GPU(每个 GPU 80GB 内存)上,用于 int8 每 token 动态激活 + int4 分组每通道权重的 Llama3 QAT 微调开销。
展望未来
在本博客中,我们介绍了通过 torchao 的 LLM QAT 流程,将此流程与 torchtune 中的微调 API 集成,并演示了其潜力,与 PTQ 相比,它可以恢复大部分量化下降,并在某些任务上与非量化性能相匹配。未来还有许多探索方向
- 超参数调整。 大量的超参数调整可能会进一步改善微调和 QAT 的结果。除了学习率、批处理大小、数据集大小和微调步数等通用超参数外,我们还应调整 QAT 特定的超参数,例如何时开始/停止伪量化、伪量化多少步以及伪量化值的正则化参数。
- 异常值减少技术。 在我们的实验中,我们发现 PTQ 和 QAT 都容易受到异常值的影响。除了微调期间的简单裁剪和正则化外,我们还可以探索允许网络学习如何控制这些异常值的技术(例如 学习的量化范围、裁剪后的 softmax 和 门控注意力),或者甚至可能借鉴训练后设置中的异常值抑制技术(例如 SpinQuant、SmoothQuant),并在整个微调过程中谨慎应用它们。
- 混合精度和更复杂的数据类型。 特别是在较低位宽的情况下,我们看到跳过某些敏感层的量化对于 PTQ 和 QAT 都是有效的。我们需要完全跳过量化这些层吗?或者我们仍然可以量化它们,只是量化到更低的位宽?探索 QAT 环境中的混合精度量化将很有趣。使用更新的数据类型(如 MX4)进行训练是另一个有希望的方向,尤其是在即将推出的 Blackwell GPU 将 不再支持 int4 tensor core 的情况下。
- 与 LoRA 和 QLoRA 的可组合性。 我们的 torchtune 中的 QAT 集成目前仅支持完整微调工作流程。但是,许多用户希望使用低秩适配器微调其模型,以大幅减少内存占用。将 QAT 与 LoRA/QLoRA 等技术结合使用将使用户能够获得这些方法的内存和性能优势,同时生成最终将以最小模型质量下降进行量化的模型。
- 与 torch.compile 的可组合性。 这是另一种可能显着加速 QAT 中伪量化计算,同时减少内存占用的方法。torch.compile 目前与 torchtune 中完整分布式微调配方(无论是否使用 QAT)中使用的分布式策略不兼容,但将在不久的将来添加支持。
- 量化其他层。 在这项工作中,我们仅探索了量化线性层。但是,在长序列长度的上下文中,KV 缓存通常会成为吞吐量瓶颈,并且可能达到数十 GB,因此 LLM-QAT 探索了在量化激活和权重的同时量化 KV 缓存。先前的研究还在其他基于 Transformer 的模型中成功地将嵌入层量化到 2 位。
- 在高性能 cuda 内核上进行端到端评估。 这项工作的自然延伸是在高性能 cuda 内核上提供端到端 QAT 流程评估,类似于通过 executorch 降低到 XNNPACK 内核的现有 8da4w QAT 流程。对于 int4 仅权重量化,我们可以利用高效的 带有位压缩的 int4 权重 mm 内核 进行量化,并且正在进行工作以添加对此内核的 QAT 支持:https://github.com/pytorch/ao/pull/383。对于 8da4w 量化,混合 4 位/8 位 GEMM 也正在 cutlass 中添加。这将是构建高效 8da4w cuda 内核所必需的。
QAT 代码可以在此处找到。请参阅此 torchtune 教程以开始使用。如果您有任何其他问题,请随时在 torchao github 上提出 issue,或联系 andrewor@meta.com。我们欢迎您的反馈和贡献!