• 教程 >
  • 如何通过将优化器步骤融合到反向传播中来节省内存
快捷方式

如何通过将优化器步骤融合到反向传播中来节省内存

您好!本教程旨在展示一种通过减少梯度占用的内存来减少训练循环内存占用量的方法。假设您有一个模型,并且您有兴趣通过优化内存来避免 Out of Memory (OOM) 错误,或者只是让您的 GPU 更加高效。好吧,您可能会走运 (如果梯度占用您部分内存并且您不需要进行梯度累积)。我们将探讨以下内容

  1. 在训练或微调循环中占用内存的内容,

  2. 如何捕获和可视化内存快照以确定瓶颈,

  3. 新的 Tensor.register_post_accumulate_grad_hook(hook) API,最后,

  4. 10 行代码如何实现内存节省。

要运行本教程,您需要

  • 需要使用 PyTorch 2.1.0 或更高版本,并安装 torchvision 库。

  • 如果想要在本地运行内存可视化,则需要 1 个 CUDA GPU。 否则,这项技术在任何设备上都能带来类似的益处。

首先,让我们导入所需的模块和模型。我们将使用来自 torchvision 的视觉 Transformer 模型,但您可以随意替换成您自己的模型。 我们还将使用 torch.optim.Adam 作为我们的优化器,同样,您可以随意替换成您自己的优化器。

import torch
from torchvision import models
from pickle import dump

model = models.vit_l_16(weights='DEFAULT').cuda()
optimizer = torch.optim.Adam(model.parameters())
Downloading: "https://download.pytorch.org/models/vit_l_16-852ce7e3.pth" to /var/lib/ci-user/.cache/torch/hub/checkpoints/vit_l_16-852ce7e3.pth

  0%|          | 0.00/1.13G [00:00<?, ?B/s]
  1%|          | 10.5M/1.13G [00:00<00:11, 110MB/s]
  2%|1         | 21.0M/1.13G [00:00<00:25, 46.0MB/s]
  3%|2         | 31.2M/1.13G [00:00<00:22, 52.9MB/s]
  3%|3         | 37.5M/1.13G [00:00<00:24, 47.5MB/s]
  4%|4         | 49.2M/1.13G [00:01<00:23, 48.9MB/s]
  6%|5         | 65.6M/1.13G [00:01<00:19, 58.4MB/s]
  7%|6         | 81.2M/1.13G [00:01<00:14, 77.7MB/s]
  8%|7         | 90.2M/1.13G [00:01<00:18, 60.4MB/s]
  8%|8         | 97.5M/1.13G [00:01<00:19, 57.8MB/s]
  9%|8         | 104M/1.13G [00:01<00:21, 50.8MB/s]
 10%|9         | 113M/1.13G [00:02<00:21, 52.1MB/s]
 10%|#         | 119M/1.13G [00:02<00:22, 48.1MB/s]
 11%|#1        | 130M/1.13G [00:02<00:18, 59.3MB/s]
 12%|#1        | 137M/1.13G [00:02<00:29, 36.1MB/s]
 13%|#2        | 146M/1.13G [00:03<00:28, 36.8MB/s]
 13%|#2        | 150M/1.13G [00:03<00:31, 33.4MB/s]
 14%|#4        | 164M/1.13G [00:03<00:23, 44.1MB/s]
 15%|#5        | 179M/1.13G [00:03<00:17, 58.5MB/s]
 16%|#5        | 185M/1.13G [00:03<00:20, 49.5MB/s]
 17%|#6        | 196M/1.13G [00:03<00:16, 59.8MB/s]
 17%|#7        | 203M/1.13G [00:04<00:17, 57.2MB/s]
 18%|#8        | 213M/1.13G [00:04<00:18, 54.9MB/s]
 19%|#9        | 222M/1.13G [00:04<00:17, 56.6MB/s]
 20%|#9        | 229M/1.13G [00:04<00:17, 57.3MB/s]
 21%|##1       | 246M/1.13G [00:04<00:18, 52.6MB/s]
 22%|##2       | 261M/1.13G [00:05<00:20, 45.2MB/s]
 23%|##2       | 266M/1.13G [00:05<00:23, 39.7MB/s]
 24%|##3       | 277M/1.13G [00:05<00:23, 39.0MB/s]
 24%|##4       | 281M/1.13G [00:06<00:26, 34.7MB/s]
 25%|##5       | 294M/1.13G [00:06<00:20, 43.3MB/s]
 26%|##5       | 299M/1.13G [00:06<00:24, 37.5MB/s]
 27%|##6       | 311M/1.13G [00:06<00:18, 47.4MB/s]
 28%|##8       | 326M/1.13G [00:06<00:14, 62.4MB/s]
 29%|##8       | 333M/1.13G [00:06<00:14, 59.9MB/s]
 29%|##9       | 339M/1.13G [00:07<00:17, 48.3MB/s]
 30%|##9       | 345M/1.13G [00:07<00:18, 46.9MB/s]
 31%|###1      | 360M/1.13G [00:07<00:14, 58.2MB/s]
 32%|###2      | 377M/1.13G [00:07<00:12, 67.8MB/s]
 33%|###3      | 384M/1.13G [00:07<00:12, 66.3MB/s]
 34%|###3      | 393M/1.13G [00:07<00:11, 71.2MB/s]
 35%|###5      | 407M/1.13G [00:08<00:09, 87.0MB/s]
 36%|###5      | 416M/1.13G [00:08<00:13, 55.8MB/s]
 37%|###6      | 426M/1.13G [00:08<00:11, 64.3MB/s]
 37%|###7      | 434M/1.13G [00:08<00:16, 45.7MB/s]
 38%|###7      | 440M/1.13G [00:08<00:15, 49.3MB/s]
 38%|###8      | 446M/1.13G [00:09<00:16, 46.0MB/s]
 40%|###9      | 459M/1.13G [00:09<00:14, 50.6MB/s]
 41%|####      | 474M/1.13G [00:09<00:11, 65.1MB/s]
 41%|####1     | 481M/1.13G [00:09<00:12, 57.5MB/s]
 42%|####2     | 490M/1.13G [00:09<00:12, 57.2MB/s]
 43%|####2     | 496M/1.13G [00:09<00:13, 53.5MB/s]
 44%|####3     | 508M/1.13G [00:10<00:13, 51.0MB/s]
 45%|####5     | 523M/1.13G [00:10<00:11, 59.6MB/s]
 46%|####5     | 529M/1.13G [00:10<00:11, 55.8MB/s]
 47%|####6     | 541M/1.13G [00:10<00:11, 57.7MB/s]
 48%|####7     | 556M/1.13G [00:10<00:09, 64.7MB/s]
 48%|####8     | 563M/1.13G [00:11<00:09, 63.0MB/s]
 49%|####9     | 574M/1.13G [00:11<00:09, 68.0MB/s]
 51%|#####     | 587M/1.13G [00:11<00:07, 83.4MB/s]
 51%|#####1    | 596M/1.13G [00:11<00:08, 66.8MB/s]
 52%|#####2    | 606M/1.13G [00:11<00:08, 68.8MB/s]
 54%|#####3    | 623M/1.13G [00:11<00:08, 70.1MB/s]
 55%|#####4    | 638M/1.13G [00:12<00:08, 63.9MB/s]
 55%|#####5    | 644M/1.13G [00:12<00:09, 57.5MB/s]
 56%|#####6    | 655M/1.13G [00:12<00:08, 59.6MB/s]
 58%|#####7    | 672M/1.13G [00:12<00:07, 66.7MB/s]
 59%|#####9    | 688M/1.13G [00:12<00:06, 78.8MB/s]
 60%|#####9    | 696M/1.13G [00:13<00:06, 76.6MB/s]
 61%|######    | 704M/1.13G [00:13<00:07, 65.2MB/s]
 61%|######1   | 710M/1.13G [00:13<00:07, 60.4MB/s]
 62%|######2   | 720M/1.13G [00:13<00:08, 55.8MB/s]
 63%|######2   | 726M/1.13G [00:13<00:08, 50.9MB/s]
 63%|######3   | 737M/1.13G [00:13<00:07, 61.1MB/s]
 64%|######3   | 743M/1.13G [00:14<00:08, 50.3MB/s]
 65%|######4   | 752M/1.13G [00:14<00:07, 59.5MB/s]
 65%|######5   | 759M/1.13G [00:14<00:08, 46.9MB/s]
 66%|######6   | 769M/1.13G [00:14<00:07, 54.0MB/s]
 67%|######6   | 775M/1.13G [00:14<00:08, 47.2MB/s]
 68%|######7   | 786M/1.13G [00:15<00:08, 46.4MB/s]
 69%|######9   | 801M/1.13G [00:15<00:05, 64.1MB/s]
 70%|######9   | 809M/1.13G [00:15<00:07, 50.7MB/s]
 71%|#######   | 819M/1.13G [00:15<00:06, 57.4MB/s]
 72%|#######1  | 836M/1.13G [00:15<00:05, 67.0MB/s]
 73%|#######3  | 851M/1.13G [00:15<00:03, 84.7MB/s]
 74%|#######4  | 861M/1.13G [00:16<00:05, 61.6MB/s]
 75%|#######4  | 868M/1.13G [00:16<00:05, 56.0MB/s]
 76%|#######6  | 885M/1.13G [00:16<00:06, 44.3MB/s]
 77%|#######7  | 900M/1.13G [00:16<00:04, 57.3MB/s]
 78%|#######8  | 907M/1.13G [00:17<00:05, 50.2MB/s]
 79%|#######9  | 918M/1.13G [00:17<00:05, 49.1MB/s]
 80%|########  | 930M/1.13G [00:17<00:03, 62.5MB/s]
 81%|########  | 938M/1.13G [00:17<00:04, 54.1MB/s]
 82%|########1 | 950M/1.13G [00:17<00:03, 57.8MB/s]
 83%|########3 | 965M/1.13G [00:18<00:02, 74.3MB/s]
 84%|########3 | 974M/1.13G [00:18<00:03, 62.7MB/s]
 85%|########4 | 984M/1.13G [00:18<00:02, 70.0MB/s]
 86%|########5 | 998M/1.13G [00:18<00:02, 81.4MB/s]
 87%|########6 | 0.98G/1.13G [00:18<00:02, 62.1MB/s]
 87%|########7 | 0.99G/1.13G [00:19<00:02, 53.4MB/s]
 88%|########8 | 1.00G/1.13G [00:19<00:03, 48.4MB/s]
 89%|########8 | 1.01G/1.13G [00:19<00:02, 53.1MB/s]
 89%|########9 | 1.01G/1.13G [00:19<00:03, 42.5MB/s]
 90%|######### | 1.02G/1.13G [00:19<00:02, 51.2MB/s]
 92%|#########1| 1.04G/1.13G [00:20<00:02, 50.7MB/s]
 92%|#########2| 1.04G/1.13G [00:20<00:02, 46.1MB/s]
 93%|#########2| 1.05G/1.13G [00:20<00:01, 47.3MB/s]
 93%|#########3| 1.06G/1.13G [00:20<00:02, 40.3MB/s]
 94%|#########3| 1.06G/1.13G [00:20<00:02, 36.3MB/s]
 94%|#########4| 1.07G/1.13G [00:20<00:01, 46.2MB/s]
 95%|#########4| 1.08G/1.13G [00:21<00:01, 35.5MB/s]
 96%|#########5| 1.09G/1.13G [00:21<00:01, 46.4MB/s]
 97%|#########7| 1.10G/1.13G [00:21<00:00, 53.7MB/s]
 98%|#########8| 1.12G/1.13G [00:21<00:00, 67.2MB/s]
 99%|#########9| 1.12G/1.13G [00:22<00:00, 42.1MB/s]
100%|##########| 1.13G/1.13G [00:22<00:00, 54.7MB/s]

现在,让我们定义一个典型的训练循环。在训练时,您应该使用真实图像,但为了本教程的目的,我们将在其中传入假输入,无需担心加载任何实际数据。

IMAGE_SIZE = 224

def train(model, optimizer):
  # create our fake image input: tensor shape is batch_size, channels, height, width
  fake_image = torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE).cuda()

  # call our forward and backward
  loss = model.forward(fake_image)
  loss.sum().backward()

  # optimizer update
  optimizer.step()
  optimizer.zero_grad()

训练期间的内存使用情况

我们将要查看一些内存快照,因此我们应该准备好正确地分析它们。 通常,训练内存包含以下部分:

  • 模型参数(大小 P)

  • 为反向传播保存的激活值(大小 A)

  • 梯度,其大小与模型参数相同,因此大小 G = P。

  • 优化器状态,它与参数大小成正比。在本例中,Adam 的状态需要 2 倍的模型参数,因此大小 O = 2P。

  • 中间张量,它们在整个计算过程中被分配。 我们现在不用担心它们,因为它们通常很小且是临时的。

捕获和可视化内存快照

让我们获得一个内存快照! 随着代码的运行,请思考您对 CUDA 内存时间线的预期。

# tell CUDA to start recording memory allocations
torch.cuda.memory._record_memory_history(enabled='all')

# train 3 steps
for _ in range(3):
  train(model, optimizer)

# save a snapshot of the memory allocations
s = torch.cuda.memory._snapshot()
with open(f"snapshot.pickle", "wb") as f:
    dump(s, f)

# tell CUDA to stop recording memory allocations now
torch.cuda.memory._record_memory_history(enabled=None)

现在,通过拖放 snapshot.pickle 文件,在 https://pytorch.ac.cn/memory_viz 中打开 CUDA 内存可视化器。内存时间线是否符合您的预期?

snapshot.png loaded into CUDA Memory Visualizer

模型参数已在训练步骤之前加载到内存中,因此我们看到一开始有一块内存用于权重。 当我们开始正向传播时,会为激活值(即我们为能够在反向传播中计算梯度而保存的张量)逐渐分配内存。 一旦我们开始反向传播,激活值就会逐渐释放,而梯度的内存开始累积。

最后,随着优化器的启动,它的状态将被延迟初始化,因此我们应该看到优化器状态内存仅在第一个训练循环的优化器步骤期间逐渐增加。 在未来的循环中,优化器内存将保持不变并在内存中更新。 梯度的内存将在每次训练循环结束时,当调用 zero_grad 时相应地释放。

在这个训练循环中,内存瓶颈在哪里? 或者换句话说,峰值内存在哪里?

峰值内存使用发生在优化器步骤期间! 请注意,此时内存包含约 1.2GB 的参数、约 1.2GB 的梯度以及约 2.4GB=2*1.2GB 的优化器状态,正如预期的那样。 最后的约 1.2GB 来自 Adam 优化器需要为中间值分配内存,总共约 6GB 的峰值内存。 技术上来说,如果您设置 Adam(model.parameters(), foreach=False),可以消除对最后 1.2GB 优化器中间值的需要,这将以运行时间为代价来节省内存。 如果关闭 foreach 运行时优化足以节省内存,那就太好了,但如果您想知道本教程如何帮助您做得更好,请继续阅读! 通过我们很快将要介绍的技术,我们将通过消除对约 1.2GB 的 **梯度内存** 以及 **优化器中间值内存** 的需求来减少峰值内存。 现在,您预期新的峰值内存是多少? 答案将在下一张快照中揭晓。

免责声明:这项技术并非适用于所有情况

在我们过于兴奋之前,我们必须考虑这项技术是否适用于您的用例。 这不是万能药! 将优化器步骤融合到反向传播中的技术仅针对减少 *梯度* 内存(并且作为副作用,也减少优化器中间值内存)。 因此,梯度占用的内存越大,内存减少就越显著。 在我们上面的例子中,梯度占用了 20% 的内存,这是一个相当大的比例!

对于您来说可能并非如此,例如,如果您的权重已经很小(例如,由于应用了 LoRa),那么梯度在您的训练循环中不会占用太多空间,收益就会小得多。 在这种情况下,您应该首先尝试其他技术,例如激活值检查点、分布式训练、量化或减少批次大小。 然后,当梯度再次成为瓶颈的一部分时,再回到本教程!

还在吗? 太棒了,让我们介绍一下新的 register_post_accumulate_grad_hook(hook) API 在 Tensor 上的使用。

Tensor.register_post_accumulate_grad_hook(hook) API 及其技术

我们的技术依赖于无需在 backward() 期间保存梯度。 相反,一旦梯度累积,我们将立即将优化器应用于相应的参数,并完全丢弃该梯度! 这消除了将大量梯度缓冲区保留到优化器步骤的需要。

那么,我们如何才能更积极地应用优化器的行为呢? 在我们的 2.1 版本中,我们添加了一个新的 API torch.Tensor.register_post_accumulate_grad_hook(),它允许我们在张量的 .grad 字段累积后向其添加一个钩子。 我们将把优化器步骤封装到这个钩子中。 如何实现?

10 行代码如何实现

还记得我们一开始的模型和优化器设置吗? 我将在下面将其注释掉,这样我们就不必浪费资源重新运行代码。

model = models.vit_l_16(weights='DEFAULT').cuda()
optimizer = torch.optim.Adam(model.parameters())
# Instead of having just *one* optimizer, we will have a ``dict`` of optimizers
# for every parameter so we could reference them in our hook.
optimizer_dict = {p: torch.optim.Adam([p], foreach=False) for p in model.parameters()}

# Define our hook, which will call the optimizer ``step()`` and ``zero_grad()``
def optimizer_hook(parameter) -> None:
  optimizer_dict[parameter].step()
  optimizer_dict[parameter].zero_grad()

# Register the hook onto every parameter
for p in model.parameters():
   p.register_post_accumulate_grad_hook(optimizer_hook)

# Now remember our previous ``train()`` function? Since the optimizer has been
# fused into the backward, we can remove the optimizer step and zero_grad calls.
def train(model):
  # create our fake image input: tensor shape is batch_size, channels, height, width
  fake_image = torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE).cuda()

  # call our forward and backward
  loss = model.forward(fake_image)
  loss.sum().backward()

  # optimizer update --> no longer needed!
  # optimizer.step()
  # optimizer.zero_grad()

我们的示例模型只更改了约 10 行代码,这很方便。 但是,对于真实的模型来说,将优化器替换为优化器字典可能是一项相当侵入性的更改,尤其是对于那些使用 ``LRScheduler`` 或在整个训练 epochs 中操作优化器配置的人来说。 使用这些更改来处理这个 API 将更加复杂,并且可能需要将更多配置转移到全局状态,但这并非不可能。 也就是说,PyTorch 的下一步是使这个 API 更容易与 LRSchedulers 和您已经习惯的其他功能一起使用。

但让我回到说服您这项技术是值得的。 我们将咨询我们的朋友,内存快照。

# delete optimizer memory from before to get a clean slate for the next
# memory snapshot
del optimizer

# tell CUDA to start recording memory allocations
torch.cuda.memory._record_memory_history(enabled='all')

# train 3 steps. note that we no longer pass the optimizer into train()
for _ in range(3):
  train(model)

# save a snapshot of the memory allocations
s = torch.cuda.memory._snapshot()
with open(f"snapshot-opt-in-bwd.pickle", "wb") as f:
    dump(s, f)

# tell CUDA to stop recording memory allocations now
torch.cuda.memory._record_memory_history(enabled=None)

是的,花点时间将您的快照拖放到 CUDA 内存可视化器中。

snapshot.png loaded into CUDA Memory Visualizer
一些主要观察结果
  1. 不再有优化器步骤! 没错...我们将其融合到了反向传播中。

  2. 同样,反向传播持续时间更长,并且为中间值分配了更多随机内存。 这是预期的,因为优化器步骤需要中间值。

  3. 最重要的是! 峰值内存更低! 现在大约是 4GB(我希望与您之前的预期非常接近)。

请注意,与之前相比,不再为梯度分配任何大的内存块,这节省了约 1.2GB 的内存。 相反,我们通过尽可能提前移动优化器步骤,在计算后立即释放了每个梯度。 太棒了! 顺便说一下,其他约 1.2GB 的内存节省来自将优化器分解为每个参数优化器,因此中间值的大小也按比例缩小了。 这个细节 *不如* 梯度内存节省 *重要*,因为您只需关闭 foreach=False 就可以获得优化器中间值的节省,而无需使用这项技术。

您可能会想知道:如果我们节省了 2.4GB 的内存,为什么峰值内存不是 6GB - 2.4GB = 3.6GB? 嗯,峰值已经移动了! 峰值现在位于反向传播步骤的开始附近,此时我们仍然有激活值在内存中,而之前,峰值是在优化器步骤期间,当时激活值已经被释放。 因此,约 0.4GB 的差异(约 4.0GB - 约 3.6GB)是由于激活值内存造成的。 然后可以想象,这项技术可以与激活值检查点相结合,以实现更多内存节省。

结论

在本教程中,我们学习了将优化器融合到反向传播步骤中以节省内存的技术,该技术通过新的 Tensor.register_post_accumulate_grad_hook() API 实现,以及 *何时* 应用这项技术(当梯度内存很大时)。 在此过程中,我们还学习了有关内存快照的信息,这在内存优化中通常很有用。

脚本的总运行时间:(0 分钟 31.817 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面的开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源