LinkedIn:Shivam Sahni, Byron Hsu, Yanning Chen
Meta:Ankith Gunapal, Evan Smothers
本博客探讨了如何将自定义 Triton 内核 Liger Kernel 与 torch.compile 相结合,以提升使用 torchtune 对大语言模型(LLM)进行微调的性能。torchtune 是一个 PyTorch 原生库,提供模块化的构建块和可定制的微调方案,其中包括对各类 LLM 的 torch.compile 支持;而 Liger Kernel 则提供经过优化的 Triton 内核,旨在提高训练效率并降低内存占用。这种集成涉及修改 torchtune 中的 TransformerDecoder 模块以绕过线性层计算,从而允许 Liger 融合线性交叉熵损失(Liger Fused Linear Cross Entropy Loss)来处理前向投影权重。在 NVIDIA A100 实例上进行的实验表明,torch.compile 在吞吐量和内存效率方面均优于 PyTorch Eager 模式,而 Liger Kernel 进一步降低了峰值内存分配,并支持更大的批量大小(batch size)。结果显示,在批量大小为 256 时,峰值内存降低了 47%,且在使用 meta-llama/Llama-3.2-1B 时吞吐量有小幅提升,证实了该集成在不影响损失曲线的情况下具有极佳的效果。
torchtune 简介
torchtune 是一个专为微调 LLM 而设计的 PyTorch 原生库。正如本博客后续将展示的那样,torchtune 提供了可组合且模块化的构建块,以及可根据用户需求轻松定制的微调方案。
torchtune 提供:
- 主流 LLM 模型架构(如 Llama、Gemma、Mistral、Phi 和 Qwen 模型系列)的 PyTorch 实现。
- 用于全量微调、LoRA、QLoRA、DPO、PPO、QAT、知识蒸馏等任务的可修改训练方案。
- 开箱即用的内存效率、性能提升,以及通过最新的 PyTorch API(包括
torch.compile)实现扩展的能力。 - 用于轻松配置训练、评估、量化或推理方案的 YAML 配置文件。
- 内置对多种主流数据集格式和提示词模板的支持。
Liger Kernel 简介
Liger Kernel 是一个开源库,包含了一系列经过优化的 Triton 内核,旨在增强大语言模型(LLM)训练的效率和可扩展性。它专注于算子级优化(如算子融合和输入分块),与 HuggingFace 等现有实现相比,在训练吞吐量和 GPU 内存使用方面取得了显著改善。只需一行代码,Liger Kernel 即可将训练吞吐量提高 20%,内存使用量降低 60%。

图 1:融合线性交叉熵(Fused Linear Cross Entropy)
Liger Kernel 的性能提升主要归功于融合线性交叉熵(FLCE)损失,其核心思想如下:
在 LLM 中,词表大小显著增加,导致交叉熵(CE)损失计算过程中产生巨大的 Logit 张量。该 Logit 张量消耗了过多的内存,成为训练过程中的瓶颈。例如,在使用批量大小为 8、序列长度为 4096、词表大小为 256k 进行训练时,会产生 16.8 GB 的 Logit 张量。FLCE 内核将计算分解为更小的块,从而降低了内存消耗。
其工作原理如下:
- 通过折叠批量大小和序列长度维度,将 3D 隐藏状态展平为 2D 矩阵。
- 在分块的隐藏状态上依次应用线性投影头(Linear Projection Head)。
- 使用 Liger CE 内核计算部分损失并返回分块的 Logit 梯度。
- 导出分块的隐藏状态梯度并累积投影头梯度。
Torchtune 的方案开箱即用地提供了 torch.compile 支持。研究表明,将 torch.compile 与 FLCE 结合使用可使 FLCE 速度提升 2 倍。
集成 Liger Kernel 与 torch.compile 和 torchtune
我们通过运行 meta-llama/Llama-3.2-1B 的全量微调方案,演示了 Liger Kernel 与 torch.compile 及 torchtune 的集成。为了实现这一集成,我们定义了一个自定义的全量微调方案,具体变更细节如下。
CUDA_VISIBLE_DEVICES=0,1,2,3 tune run --nproc_per_node 4 recipes/full_finetune_distributed.py --config llama3_2/1B_full optimizer=torch.optim.AdamW optimizer.fused=True optimizer_in_bwd=False gradient_accumulation_steps=1 dataset.packed=True compile=True enable_activation_checkpointing=True tokenizer.max_seq_len=512 batch_size=128
LCE 内核的输入之一是前向投影权重。torchtune 被设计为一个具有可组合块的模块化库。其中有一个 TransformerDecoder 块,在该块的末尾,我们将最终隐藏状态通过线性层以获得最终输出。由于线性层在 LCE 内核中与 CE 损失相结合,我们为 TransformerDecoder 编写了一个自定义的 forward 函数,并跳过了线性层的计算。
在全量微调方案中,我们使用此自定义方法覆盖了模型的 forward 方法:
import types
from liger_kernel.torchtune.modules.transformers import decoder_forward
self._model.forward = types.MethodType(decoder_forward, self._model)
随后,我们传入模型的前向投影权重,利用 LCE 内核计算损失。
from liger_kernel.transformers.fused_linear_cross_entropy import (
LigerFusedLinearCrossEntropyLoss,
)
# Use LCE loss instead of CE loss
self._loss_fn = LigerFusedLinearCrossEntropyLoss()
# call torch.compile on the loss function
if self._compile:
training.compile_loss(self._loss_fn, verbose=self._is_rank_zero)
# pass the model's forward projection weights for loss computation
current_loss = (
self._loss_fn(
self._model.output.tied_module.weight,
logits,
labels,
)
* current_num_tokens
)
完整代码和说明可在 GitHub 仓库 中找到。
实验与基准测试结果
我们进行了 3 类实验,以证明 Liger Kernel 与 torch.compile 的集成如何增强 torchtune 的性能。我们在运行 NVIDIA A100 的实例上设置了实验。我们使用不同的批量大小微调了一个小型 LLM meta-llama/Llama-3.2-1B。我们记录了以 tokens/秒为单位的吞吐量,并测量了微调期间分配的峰值内存。由于是小型模型,我们在基准测试中仅使用了 4 个 A100 GPU。以下是我们进行的实验:
- 使用 PyTorch Eager 以 2 的幂次增加 batch_size
- 使用 torch.compile 以 2 的幂次增加 batch_size
- 使用 torch.compile 和 Liger 集成以 2 的幂次增加 batch_size
我们注意到,在使用 PyTorch Eager 时,吞吐量随 batch_size 的增加而增加,直到 batch_size 256 时出现 OOM(内存溢出)。使用 torch.compile 时,在每个 batch_size 下的吞吐量都高于 PyTorch Eager。随着 batch_size 的增加,峰值内存分配大幅减少,在 batch_size 128 时峰值内存降低了 50% 以上。这使得 torch.compile 能够支持 batch_size 256,因此使用 torch.compile 的总吞吐量比 PyTorch Eager 高出 36%。将 Liger Kernel 与 torch.compile 集成后,在较低的 batch_size 下吞吐量不会下降;随着 batch_size 的增加,我们注意到 torchtune 比单独使用 torch.compile 消耗的内存更少。在 batch_size 256 时,使用 Liger 内核后峰值内存分配减少了 47%。这使得我们能够在使用 torch.compile 和 Liger 的情况下支持 batch_size 512。我们注意到,与不使用自定义 Triton 内核的 torch.compile 相比,吞吐量有 1-2% 的微小提升。

图 2:每个 Rank 的 tokens/秒与 batch_size 的关系图

图 3:峰值内存分配与 batch_size 的关系图
为了排除我们将 Liger Kernel 集成到 torchtune 中可能存在的任何功能性问题,我们绘制了使用与不使用 Liger 时的损失曲线对比。我们发现损失曲线之间没有明显的差异。

图 4:batch_size=128 时损失与训练步数的关系图
下一步
- 在 torchtune 的方案中启用 Liger 内核,分别用于 DPO 方案中的 DPO 损失 和 知识蒸馏 方案中的 蒸馏损失。
- 在 torchtune 中通过 张量并行训练(tensor parallel training) 支持 Liger 集成。
鸣谢
感谢 Hamid Shojanazeri (Meta)、Less Wright (Meta)、Horace He (Meta) 和 Gregory Chanan (Meta) 对本博客文章提供的反馈与支持。