快捷方式

内存优化概述

作者: Salman Mohammadi

torchtune 提供了一系列即插即用的内存优化组件,为您提供大量灵活地 tune 我们的食谱以适应您的硬件。本页简要介绍了这些组件以及如何使用它们。为了方便起见,我们已在以下表格中总结了这些组件

内存优化组件

组件

何时使用?

模型精度

您通常希望将其保留为默认的 bfloat16。如果您由于精度问题而遇到训练稳定性或准确性问题,fp32 可能有所帮助,但会显着增加内存使用量并降低训练速度。

激活检查点

当您内存受限且需要处理更大的批次大小或更长的上下文长度时使用。请注意,它可能会降低训练速度。

梯度累积

内存受限时很有用,可以模拟更大的批次大小。通常优于激活检查点,可提高训练速度。

低精度优化器

当您需要通过降低优化器状态的精度来进一步减少内存使用量(超出了使用 bf16 的范围)时。请注意,低精度优化器可能会降低训练稳定性/准确性。

将优化器步骤融合到反向传递中

使用有状态优化器时有助于减少内存使用量,尤其是在使用高梯度内存使用量的完整微调大型模型时。这与 gradient_accumulation_steps 不兼容,因此由于模型吞吐量降低,训练可能会变慢。

低秩自适应 (LoRA)

当您希望显着减少可训练参数的数量,在训练期间节省梯度和优化器内存,并显着加快训练速度时。

量化低秩自适应 (QLoRA)

当您需要比 LoRA 更多的内存节省时,可能会以一些训练速度为代价。对于非常大的模型或有限的硬件很有用。

注意

在当前状态下,本教程侧重于单设备优化。请尽快查看更新后的页面,以获取有关分布式微调的最新内存优化功能。

模型精度

这里发生了什么?

我们使用“精度”一词来指代用于表示模型和优化器参数的底层数据类型。我们在 torchtune 中支持两种数据类型

注意

我们建议深入了解 Sebastian Raschka 的 关于混合精度技术的博客文章,以更深入地了解精度和数据格式的相关概念。

  • fp32,通常称为“全精度”,每个模型和优化器参数使用 4 个字节。

  • bfloat16,称为“半精度”,每个模型和优化器参数使用 2 个字节 - 实际上是 fp32 内存的一半,并且还提高了训练速度。通常,如果您的硬件支持使用 bfloat16 进行训练,我们建议您使用它 - 这是我们食谱的默认设置。

注意

另一个常见的范例是“混合精度”训练:模型权重为 bfloat16(或 fp16),优化器状态为 fp32。目前,我们在 torchtune 中不支持混合精度训练。

听起来很棒!我该如何使用它?

只需在我们所有的食谱中使用 dtype 标志或配置条目!例如,要在 bf16 中使用半精度训练,请设置 dtype=bf16

激活检查点

这里发生了什么?

PyTorch 文档 中的相关部分很好地解释了这个概念。引用

激活检查点是一种用计算换取内存的技术。它不会将反向传播所需的张量一直保存到它们在反向传播期间用于梯度计算之前,而是会省略在检查点区域的正向计算中保存用于反向传播的张量,并在反向传播过程中重新计算它们。

此设置对于内存受限的情况很有帮助,尤其是在更大的批次大小或更长的上下文长度的情况下。但是,这些内存节省是以训练速度(即每秒代币数)为代价的,在大多数情况下,由于这种激活重新计算,训练速度可能会大幅下降。

听起来很棒!我该如何使用它?

要启用激活检查点,请在我们任何食谱中使用 enable_activation_checkpointing 配置条目或标志,例如 enable_activation_checkpointing=True

激活卸载

这里发生了什么?

您可能刚刚阅读了有关激活检查点的内容!与检查点类似,卸载是一种内存效率技术,它允许通过将激活临时移动到 CPU 并根据需要在反向传播期间将其带回以节省 GPU VRAM。

请参阅 PyTorch autograd 挂钩教程,详细了解如何通过 saved_tensors_hooks 实现这一点。

此设置对于更大的批次大小或内存受限时的更长的上下文长度特别有用。但是,这些内存节省可能以训练速度(即每秒代币数)为代价,因为将张量从 GPU 移动到 CPU 并返回需要运行时间和资源。torchtune 中的实现具有 offload_with_streams 选项,以便使用多个 CUDA 流以重叠额外的通信和计算以隐藏额外的运行时间。由于通信工作负载因被卸载的张量的数量和大小而异,因此通常不会卸载每个激活。事实上,可以使用卸载与激活检查点结合使用,其中所有激活要么将在后面的反向传播中重新计算,要么从 CPU 返回。

听起来很棒!我该如何使用它?

要启用激活卸载,请在我们 lora 微调单设备食谱中使用 enable_activation_offloading 配置条目或标志,例如 enable_activation_offloading=True。要允许使用流,请确保您使用的是 PyTorch 2.5.0.dev20240907 之后的 torch 版本,并指定 offload_with_streams=True

梯度累积

这里发生了什么?

梯度累积允许您通过在使用优化器更新模型参数之前累积多个批次的梯度来模拟更大的批次大小。具体来说,使用梯度累积时,用于梯度更新的样本总数为

total_batch_size = batch_size * gradient_accumulation_steps

例如:使用 batch_size=1gradient_accumulation_steps=32,我们得到总批量大小为 32。

注意

对于 torchtune 中使用“步数”的其他组件,例如 指标记录,或 学习 调度器,一个“步数”被计为对模型参数的一次更新,而不是对数据的一次模型前向传递。假设 gradient_accumulation_steps = 4 并且 log_every_n_steps = 10。指标将在每 10 个全局步数记录一次,这意味着每 40 个模型前向传递记录一次。因此,在使用梯度累积训练时,指标记录的频率会降低,进度条的更新速度也会变慢。

如果您使用我们的分布式配方之一,只需乘以设备数量

total_batch_size = batch_size * gradient_accumulation_steps * num_devices

梯度累积在内存受限的情况下尤其有用。在这种情况下,累积梯度可能会比启用 激活检查点 提供更快的训练速度,因为激活检查点以重复计算为代价减少了内存消耗。

听起来很棒!我该如何使用它?

我们所有的微调配方都支持通过累积梯度来模拟更大的批量大小。只需设置 gradient_accumulation_steps 标志或配置条目。

注意

将优化器步骤融合到反向传播中 时,梯度累积始终应设置为 1。

低精度优化器

这里发生了什么?

除了在训练期间 降低模型和优化器的精度 之外,我们还可以进一步降低优化器状态的精度。我们所有的单设备微调配方都支持来自 bitsandbytes 库的低精度优化器 - 一个好的起点可能是 AdamW8bitPagedAdamW8bit 优化器,我们已经使用这些优化器测试了我们的配方。

听起来很棒!我该如何使用它?

要在您的配方中使用它,请确保您已安装 bitsandbytes (pip install bitsandbytes)。然后,使用 torchtune CLI 启用低精度优化器

tune run <RECIPE> --config <CONFIG> \
optimizer=bitsandbytes.optim.PagedAdamW

或通过直接 修改配置文件

optimizer:
  _component_: bitsandbytes.optim.PagedAdamW
  lr: 2e-5

将优化器步骤融合到反向传播中

这里发生了什么?

有状态优化器(例如,使用动量的优化器)是现代深度学习中的默认选择,因为它们具有稳定的收敛特性。但是,维护梯度统计信息的状态会带来额外的内存使用成本。一个直接的替代方案可能是转向无状态优化器,例如没有动量的 随机梯度下降,它不需要任何额外的内存使用,但可能会导致训练期间收敛性变差。

我们是否可以找到一个折衷方案?让我们考虑一种技术,它允许使用“有状态”优化器,例如 AdamW,而不会产生梯度统计信息的内存开销,也不会牺牲它们理想的收敛特性。您可能会问,这是如何实现的?通过完全删除优化器在执行 step() 期间存储的梯度缓冲区

为了理解它是如何工作的,我们建议您阅读 PyTorch 关于此概念的相关教程:如何通过将优化器步骤融合到反向传播中来节省内存

听起来很棒!我该如何使用它?

在 torchtune 中,您可以使用 optimizer_in_bwd 标志启用此功能,该功能目前仅在我们的单设备完整微调配方中支持。当梯度内存特别大时,此功能效果最佳;例如,当使用具有大量参数的模型的有状态优化器时,并且您不需要使用 梯度累积

参数高效微调 (PEFT)

低秩自适应 (LoRA)

这里发生了什么?

您可以阅读我们关于 使用 LoRA 微调 Llama2 的教程,以了解 LoRA 的工作原理以及如何使用它。简单地说,LoRA 大大减少了可训练参数的数量,从而在训练期间节省了大量的梯度和优化器内存。

听起来很棒!我该如何使用它?

您可以使用我们的任何配方进行微调,配方名称以 lora_ 为前缀,例如 lora_finetune_single_device。这些配方利用支持所有模型的 LoRA 启用的模型构建器,也使用 lora_ 为前缀,例如,torchtune.models.llama3.llama3() 模型具有相应的 torchtune.models.llama3.lora_llama3()。我们旨在提供一组全面的配置,让您能够快速开始使用 LoRA 进行训练,只需指定任何名称中包含 _lora 的配置,例如

tune run lora_finetune_single_device --config llama3/8B_lora_single_device

有两组参数可以自定义 LoRA 以满足您的需求。首先,控制 LoRA 应该应用于模型中哪些线性层的参数

  • lora_attn_modules: List[str] 接受一个字符串列表,指定要将 LoRA 应用于模型的哪些层

    • q_proj 将 LoRA 应用于查询投影层。

    • k_proj 将 LoRA 应用于键投影层。

    • v_proj 将 LoRA 应用于值投影层。

    • output_proj 将 LoRA 应用于注意力输出投影层。

    虽然添加更多要微调的层可能会提高模型精度,但这会以增加内存使用量和降低训练速度为代价。

  • apply_lora_to_mlp: Bool 将 LoRA 应用于每个 Transformer 层中的 MLP。

  • apply_lora_to_output: Bool 将 LoRA 应用于模型的最终输出投影。这通常是投影到词汇空间(例如在语言模型中),但其他建模任务可能具有不同的投影 - 例如,分类模型将投影到类别数量

注意

在最终输出投影中使用绑定嵌入的模型(例如 Gemma 和 Qwen2 1.5B 和 0.5B)不支持 apply_lora_to_output

这些参数都指定在 model 标志或配置条目下,即

tune run lora_finetune_single_device --config llama3/8B_lora_single_device  \
model.apply_lora_to_mlp=True \
model.lora_attn_modules=["q_proj","k_proj","v_proj"]
model:
  apply_lora_to_mlp: True
  model.lora_attn_modules: ["q_proj", "k_proj", "v_proj"]

其次,控制 LoRA 对模型的影响规模的参数

  • lora_rank: int 影响 LoRA 分解的规模,其中 lora_rank << in_dim 并且 lora_rank << out_dim - 模型中任意线性层的维度。具体而言,lora_rank 将存储在线性层中的梯度数量从 in_dim * out_dim 线性地减少到 lora_rank * (in_dim + out_dim)。通常,我们有 lora_rank in [8, 128]

  • lora_alpha: float 影响 LoRA 更新的幅度。更大的 alpha 会导致对基础模型权重的更大更新,这可能会以训练稳定性为代价,相反,较小的 alpha 可以稳定训练,但会以学习速度变慢为代价。我们为这些参数提供了我们已经过所有模型测试的默认设置,但我们鼓励您根据您的具体用例调整它们。通常,人们会同时改变 lora_ranklora_alpha,其中 lora_alpha ~= 2*lora_rank

  • lora_dropout 在 LoRA 层中引入丢弃,以帮助正则化训练。我们为所有模型默认设置为 0.0。

如上所述,这些参数也指定在 model 标志或配置条目下。

注意

要更深入地了解 LoRA 参数如何影响训练期间的内存使用情况,请参阅 我们 Llama2 LoRA 教程中的相关部分

量化低秩自适应 (QLoRA)

这里发生了什么?

QLoRA 是在 LoRA 之上的增强,它将 LoRA 中的冻结模型参数以 4 位量化精度维护,从而减少了内存使用量。这是通过作者提出的新颖的 4 位 NormalFloat (NF4) 数据类型实现的,该数据类型允许参数内存使用量减少 4-8 倍,同时保持模型精度。您可以阅读我们关于 使用 QLoRA 微调 Llama2 的教程,以更深入地了解它是如何工作的。

在考虑使用 QLoRA 来减少内存使用时,值得注意的是,QLoRA 通过在模型前向传递期间将量化参数向上转换为原始更高精度的数

听起来很棒!我该如何使用它?

您可以使用任何我们的 LoRA 配方来使用 QLoRA 进行微调,即带有 lora_ 前缀的配方,

tune run lora_finetune_single_device --config llama3/8B_qlora_single_device

所有其他 LoRA 参数对于 QLoRA 来说都保持不变 - 请查看上面关于

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源