作者:LinkedIn 和 Meta

LinkedIn:Shivam Sahni, Byron Hsu, Yanning Chen
Meta:Ankith Gunapal, Evan Smothers

本文探讨了如何将自定义 Triton 内核(Liger Kernel)与 torch.compile 集成,以提升使用 torchtune 对大型语言模型 (LLM) 进行微调的性能。torchtune 是一个原生的 PyTorch 库,提供模块化的构建块和可定制的微调攻略,其中包括对各种 LLM 的 torch.compile 支持;而 Liger Kernel 提供优化的 Triton 内核,可提高训练效率并减少内存使用。集成涉及修改 torchtune 中的 TransformerDecoder 模块,绕过线性层计算,让 Liger Fused Linear Cross Entropy Loss 处理前向投影权重。在 NVIDIA A100 实例上进行的实验表明,torch.compile 在吞吐量和内存效率方面优于 PyTorch Eager,而 Liger Kernel 进一步降低了峰值内存分配,并支持更大的批量大小。结果显示,在批量大小为 256 时,峰值内存减少了 47%,并且使用 meta-llama/Llama-3.2-1B 时吞吐量略有增加,证实了集成的有效性,且不影响损失曲线。

torchtune 简介

torchtune 是一个原生的 PyTorch 库,专为微调 LLM 而设计。torchtune 提供可组合的模块化构建块以及可根据您的用例轻松定制的微调攻略,本文将对此进行展示。
torchtune 提供

  • PyTorch 实现的来自 Llama、Gemma、Mistral、Phi 和 Qwen 模型系列的流行 LLM 模型架构
  • 用于全量微调 (full finetuning)、LoRA、QLoRA、DPO、PPO、QAT、知识蒸馏 (knowledge distillation) 等的可定制训练攻略
  • 开箱即用的内存效率、性能改进以及利用最新的 PyTorch API(包括 torch.compile)进行扩展
  • 用于轻松配置训练、评估、量化或推理攻略的 YAML 配置文件
  • 内置支持多种流行的 QxgmS 数据集格式和提示模板

Liger Kernel 简介

Liger Kernel 是一个开源的优化 Triton 内核库,旨在提高大型语言模型 (LLM) 训练的效率和可扩展性。它专注于内核级优化,例如操作融合和输入分块,与 HuggingFace 等现有实现相比,显著提高了训练吞吐量和 GPU 内存使用率。通过使用一行代码,Liger Kernel 可以提高训练吞吐量 20%,并减少内存使用 60%

Fused Linear Cross Entropy

Liger Kernel 的大部分性能提升来自融合线性交叉熵损失 (Fused Linear Cross Entropy Loss, FLCE),其核心思想如下

在 LLM 中,词汇量显著增加,导致在计算交叉熵 (CE) 损失时产生大型 logit 张量。这个 logit 张量消耗过多内存,成为训练的瓶颈。例如,当批量大小为 8、序列长度为 4096 时,256k 的词汇量会产生 16.8 GB 的 logit 张量。FLCE 内核将计算分解为更小的块,从而减少内存消耗。

工作原理如下

  1. 通过折叠批量大小和序列长度维度,将 3D 隐藏状态展平为 2D 矩阵。
  2. 按顺序对分块的隐藏状态应用线性投影头。
  3. 计算部分损失,并使用 Liger CE 内核返回分块的 logits 梯度。
  4. 导出分块的隐藏状态梯度并累积投影头梯度。

Torchtune 的攻略开箱即用地提供了 torch.compile 支持。研究表明,将 torch.compile 与 FLCE 一起使用可使 FLCE 速度提高 2 倍

将 Liger Kernel 与 torch.compile & torchtune 集成

我们通过运行针对 meta-llama/Llama-3.2-1B 的全量微调攻略来演示 Liger Kernel 与 torch.compile 和 torchtune 的集成。为了实现这种集成,我们定义了一个自定义的全量微调攻略,更改的详细信息如下所述。

CUDA_VISIBLE_DEVICES=0,1,2,3 tune run --nproc_per_node 4 recipes/full_finetune_distributed.py --config llama3_2/1B_full optimizer=torch.optim.AdamW optimizer.fused=True optimizer_in_bwd=False gradient_accumulation_steps=1  dataset.packed=True compile=True enable_activation_checkpointing=True tokenizer.max_seq_len=512  batch_size=128

LCE 内核的输入之一是前向投影权重。torchtune 被设计为一个模块化库,具有可组合的块。其中有一个 TransformerDecoder ,在该块的末尾,我们将最终的隐藏状态通过一个线性层来获得最终输出。由于线性层与 CE 损失在 LCE 内核中合并,我们为 TransformerDecoder 编写了一个自定义的 forward 函数,其中我们跳过了通过线性层的计算。

在全量微调攻略中,我们使用这个自定义方法覆盖模型的 forward 方法

import types
from liger_kernel.torchtune.modules.transformers import decoder_forward
self._model.forward = types.MethodType(decoder_forward, self._model)

然后,我们将模型的前向投影权重传递给 LCE 内核计算损失

from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

# Use LCE loss instead of CE loss
self._loss_fn = LigerFusedLinearCrossEntropyLoss()

# call torch.compile on the loss function
if self._compile:
    training.compile_loss(self._loss_fn, verbose=self._is_rank_zero)

# pass the model's forward projection weights for loss computation
current_loss = (
     self._loss_fn(
         self._model.output.tied_module.weight,
         logits,
         labels,
     )
     * current_num_tokens
 )

完整的代码和说明可以在 GitHub 仓库 中找到。

实验 & 基准测试结果

我们进行了 3 种类型的实验,以证明 Liger Kernel 与 torch.compile 的集成如何增强 torchtune 的性能。我们在运行 NVIDIA A100 的实例上设置了实验。我们使用不同的批量大小微调了一个小型 LLM meta-llama/Llama-3.2-1B。我们记录了以 tokens/second 为单位的吞吐量,并测量了微调期间分配的峰值内存。由于这是一个小型模型,我们仅使用 4 个 A100 GPU 进行基准测试。以下是我们进行的实验:

  1. 使用 PyTorch eager 以 2 的幂次增加批量大小
  2. 使用 torch.compile 以 2 的幂次增加批量大小
  3. 使用 torch.compile & Liger 集成以 2 的幂次增加批量大小

我们注意到,使用 PyTorch Eager 时,吞吐量随着批量大小的增加而增加,直到在批量大小 256 时达到 OOM。使用 torch.compile 时,对于每个批量大小,吞吐量都高于 PyTorch Eager。我们看到,随着批量大小的增加,峰值内存分配急剧减少,在批量大小 128 时,峰值内存减少了 50% 以上。这使得 torch.compile 能够支持批量大小 256,因此,使用 torch.compile 的总吞吐量比 PyTorch Eager 高 36%。将 Liger Kernel 与 torch.compile 集成在较低批量大小时不会降低吞吐量,但随着批量大小的增加,我们注意到 torchtune 比 torch.compile 消耗的内存更少。在批量大小 256 时,使用 Liger 内核的峰值内存分配减少了 47%。这使得我们能够使用 torch.compile 和 Liger 的批量大小 512。我们注意到,与没有自定义 Triton 内核的 torch.compile 相比,吞吐量略微增加了 1-2%。

Plot of tokens/sec per rank vs batch_size

图 2:每个 rank 的 tokens/秒 vs 批量大小的图

Peak memory allocated vs batch_size

图 3:分配的峰值内存 vs 批量大小的图

为了排除 Liger Kernel 与 torchtune 集成可能存在的任何功能性问题,我们绘制了使用 & 不使用 Liger 时的损失曲线 vs 训练步骤图。我们看到损失曲线没有明显差异。

Plot of loss vs training steps for batch_size=128

图 4:批量大小=128 时的损失 vs 训练步骤图

后续步骤

致谢

感谢 Hamid Shojanazeri (Meta)、Less Wright (Meta)、Horace He (Meta) & Gregory Chanan (Meta) 为本文提供的反馈和支持。