作者:Aaron Shi, Zachary DeVito

在使用 GPU 上的 PyTorch 时,您可能熟悉这个常见的错误消息

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 512.00 MiB. GPU 0 has a total capacity of 79.32 GiB of which 401.56 MiB is free.

在本系列中,我们将展示如何使用显存工具,包括 Memory Snapshot、Memory Profiler 和 Reference Cycle Detector,以调试显存不足错误并改进显存使用。

Memory Timeline

Memory Snapshot 工具提供细粒度的 GPU 显存可视化,用于调试 GPU OOM (显存不足) 错误。捕获的显存快照将显示显存事件,包括分配、释放和 OOM,以及它们的堆栈跟踪。

在快照中,每个张量的显存分配都以不同颜色编码。x 轴表示时间,y 轴表示以 MB 为单位的 GPU 显存量。快照是交互式的,因此我们可以通过鼠标悬停来查看任何分配的堆栈跟踪。您可以在 https://github.com/pytorch/pytorch.github.io/blob/site/assets/images/understanding-gpu-memory-1/snapshot.html 亲自尝试。

在此快照中,有 3 个峰值,显示了 3 次训练迭代期间的显存分配(这是可配置的)。观察这些峰值,可以轻松看到正向传播过程中显存的上升以及计算梯度时反向传播过程中的下降。还可以看到程序在迭代之间具有相同的显存使用模式。值得注意的是,显存中存在许多微小的峰值,通过鼠标悬停在上面,我们可以看到它们是卷积操作符临时使用的缓冲区。

捕获显存快照

捕获显存快照的 API 相当简单,可在 torch.cuda.memory 中使用

  • 开始: torch.cuda.memory._record_memory_history(max_entries=100000)
  • 保存: torch.cuda.memory._dump_snapshot(file_name)
  • 停止: torch.cuda.memory._record_memory_history(enabled=None)

代码片段(完整代码示例请参见附录 A

   # Start recording memory snapshot history, initialized with a buffer
   # capacity of 100,000 memory events, via the `max_entries` field.
   torch.cuda.memory._record_memory_history(
       max_entries=MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT
   )

   # Run your PyTorch Model.
   # At any point in time, save a snapshot to file for later.
   for _ in range(5):
       pred = model(inputs)
       loss_fn(pred, labels).backward()
       optimizer.step()
       optimizer.zero_grad(set_to_none=True)

   # In this sample, we save the snapshot after running 5 iterations.
   #   - Save as many snapshots as you'd like.
   #   - Snapshots will save last `max_entries` number of memory events
   #     (100,000 in this example).
   try:
       torch.cuda.memory._dump_snapshot(f"{file_prefix}.pickle")
   except Exception as e:
       logger.error(f"Failed to capture memory snapshot {e}")

   # Stop recording memory snapshot history.
   torch.cuda.memory._record_memory_history(enabled=None)

为了可视化快照文件,我们提供了一个工具,托管在 https://pytorch.ac.cn/memory_viz。您可以在那里拖放您保存的快照文件,它将随时间绘制每个分配。隐私注意事项:该工具不会保存您的快照。

Memory Timeline

或者,您可以使用 pytorch/torch/cuda/_memory_viz.py 中的脚本从 .pickle 文件生成 HTML,示例如下

python torch/cuda/_memory_viz.py trace_plot snapshot.pickle -o snapshot.html

调试 CUDA OOM 错误

让我们看看如何使用显存快照工具来回答

  1. 为什么发生了 CUDA OOM
  2. GPU 显存在哪里被使用?

带有 bug 的 ResNet50

我们在第一个快照中查看了一个正常工作的模型。现在,让我们看看一个带有 bug 的训练示例,请看快照

Memory Timeline

注意第二次迭代使用的显存比第一次迭代多得多。如果模型更大,它可能在第二次迭代时就发生了 CUDA OOM,而难以找出原因。

Memory Timeline

进一步检查此快照时,我们可以清楚地看到,有几个张量从第一次迭代到第二次以及后续迭代都一直存在。如果我们将鼠标悬停在其中一个张量上,它会显示一个堆栈跟踪,表明这些是梯度张量

如果我们查看代码,确实可以看到它没有清除梯度张量,而它本可以在正向传播之前清除它们

修改前

        for _ in range(num_iters):
          pred = model(inputs)
          loss_fn(pred, labels).backward()
          optimizer.step()

修改后

        for _ in range(num_iters):
          pred = model(inputs)
          loss_fn(pred, labels).backward()
          optimizer.step()
          # Add this line to clear grad tensors
          optimizer.zero_grad(set_to_none=True)

我们可以简单地添加一个 optimizer.zero_grad(set_to_none=True) 指令来清除迭代之间的梯度张量(关于为何需要将梯度归零的更多详细信息请参见此处:https://pytorch.ac.cn/tutorials/recipes/recipes/zeroing_out_gradients.html)。

这是我们使用此工具在更复杂的程序中发现的一个 bug 的简化示例。我们鼓励您在遇到 GPU 显存问题时尝试使用 Memory Snapshot,并告知我们效果如何。

修复 bug 后的 ResNet50

应用修复后,快照显示梯度现在正在被清除。

Memory Timeline

现在我们有了正常工作的 ResNet50 模型的快照。您可以自己尝试运行代码(参见附录 A 中的代码示例)。

但您可能想知道,为什么在第一次迭代后显存仍然会增加?为了回答这个问题,让我们在下一节看看 Memory Profiler

分类显存使用

Memory Profiler 是 PyTorch Profiler 的一个新增功能,可以按时间分类显存使用情况。我们仍然依赖 Memory Snapshot 的堆栈跟踪来深入分析显存分配。

要生成显存时间线,以下是一个代码片段(完整代码示例请参见附录 B

   # Initialize the profiler context with record_shapes, profile_memory,
   # and with_stack set to True.
   with torch.profiler.profile(
       activities=[
           torch.profiler.ProfilerActivity.CPU,
           torch.profiler.ProfilerActivity.CUDA,
       ],
       schedule=torch.profiler.schedule(wait=0, warmup=0, active=6, repeat=1),
       record_shapes=True,
       profile_memory=True,
       with_stack=True,
       on_trace_ready=trace_handler,
   ) as prof:
       # Run the PyTorch Model inside the profile context.
       for _ in range(5):
           prof.step()
           with record_function("## forward ##"):
               pred = model(inputs)

           with record_function("## backward ##"):
               loss_fn(pred, labels).backward()

           with record_function("## optimizer ##"):
               optimizer.step()
               optimizer.zero_grad(set_to_none=True)

   # Construct the memory timeline HTML plot.
   prof.export_memory_timeline(f"{file_prefix}.html", device="cuda:0")

更多参考信息,请参见 https://pytorch.ac.cn/docs/main/profiler.html

Memory Profiler 根据分析期间记录的张量操作图自动生成类别。

Memory Timeline

使用 Memory Profiler 收集的此显存时间线显示了与之前相同的训练示例。我们可以观察到蓝色的梯度现在在迭代之间被清除了。我们还可以注意到,黄色的优化器状态在第一次迭代后被分配,并在作业的其余部分保持不变。

这个优化器状态是 GPU 显存从第一次迭代到第二次增加的原因。您可以自己尝试运行代码(参见附录 B 中的代码示例)。Memory Profiler 有助于提高训练时的显存理解,以便模型作者能够找出哪些类别使用了最多的 GPU 显存。

在哪里可以找到这些工具?

我们希望这些工具能够极大地提高您调试 CUDA OOM 并按类别理解显存使用情况的能力。

Memory Snapshot 和 Memory Profiler 作为实验性功能已在 PyTorch v2.1 版本中提供。

反馈

我们期待听到关于我们的工具帮助解决的任何增强功能、bug 或显存故事!一如既往,欢迎随时在 PyTorch 的 Github 页面上提交新问题。

我们也欢迎 OSS (开源软件) 社区的贡献,欢迎在任何 Github PR (Pull Request) 中标记 Aaron ShiZachary DeVito 进行审阅。

致谢

非常感谢内容审阅者,Mark SaroufimGregory Chanan 审阅此文章并提高了其可读性。

非常感谢 Adnan AzizLei Tian 的代码审阅和反馈。

附录

附录 A - ResNet50 显存快照代码示例

# (c) Meta Platforms, Inc. and affiliates. 
import logging
import socket
from datetime import datetime, timedelta

import torch

from torchvision import models

logging.basicConfig(
   format="%(levelname)s:%(asctime)s %(message)s",
   level=logging.INFO,
   datefmt="%Y-%m-%d %H:%M:%S",
)
logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)

TIME_FORMAT_STR: str = "%b_%d_%H_%M_%S"

# Keep a max of 100,000 alloc/free events in the recorded history
# leading up to the snapshot.
MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT: int = 100000

def start_record_memory_history() -> None:
   if not torch.cuda.is_available():
       logger.info("CUDA unavailable. Not recording memory history")
       return

   logger.info("Starting snapshot record_memory_history")
   torch.cuda.memory._record_memory_history(
       max_entries=MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT
   )

def stop_record_memory_history() -> None:
   if not torch.cuda.is_available():
       logger.info("CUDA unavailable. Not recording memory history")
       return

   logger.info("Stopping snapshot record_memory_history")
   torch.cuda.memory._record_memory_history(enabled=None)

def export_memory_snapshot() -> None:
   if not torch.cuda.is_available():
       logger.info("CUDA unavailable. Not exporting memory snapshot")
       return

   # Prefix for file names.
   host_name = socket.gethostname()
   timestamp = datetime.now().strftime(TIME_FORMAT_STR)
   file_prefix = f"{host_name}_{timestamp}"

   try:
       logger.info(f"Saving snapshot to local file: {file_prefix}.pickle")
       torch.cuda.memory._dump_snapshot(f"{file_prefix}.pickle")
   except Exception as e:
       logger.error(f"Failed to capture memory snapshot {e}")
       return

# Simple Resnet50 example to demonstrate how to capture memory visuals.
def run_resnet50(num_iters=5, device="cuda:0"):
   model = models.resnet50().to(device=device)
   inputs = torch.randn(1, 3, 224, 224, device=device)
   labels = torch.rand_like(model(inputs))
   optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
   loss_fn = torch.nn.CrossEntropyLoss()

   # Start recording memory snapshot history
   start_record_memory_history()

   for _ in range(num_iters):
       pred = model(inputs)
       loss_fn(pred, labels).backward()
       optimizer.step()
       optimizer.zero_grad(set_to_none=True)

   # Create the memory snapshot file
   export_memory_snapshot()

   # Stop recording memory snapshot history
   stop_record_memory_history()

if __name__ == "__main__":
    # Run the resnet50 model
    run_resnet50()

附录 B - ResNet50 显存分析器代码示例

# (c) Meta Platforms, Inc. and affiliates. 
import logging
import socket
from datetime import datetime, timedelta

import torch

from torch.autograd.profiler import record_function
from torchvision import models

logging.basicConfig(
   format="%(levelname)s:%(asctime)s %(message)s",
   level=logging.INFO,
   datefmt="%Y-%m-%d %H:%M:%S",
)
logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)

TIME_FORMAT_STR: str = "%b_%d_%H_%M_%S"

def trace_handler(prof: torch.profiler.profile):
   # Prefix for file names.
   host_name = socket.gethostname()
   timestamp = datetime.now().strftime(TIME_FORMAT_STR)
   file_prefix = f"{host_name}_{timestamp}"

   # Construct the trace file.
   prof.export_chrome_trace(f"{file_prefix}.json.gz")

   # Construct the memory timeline file.
   prof.export_memory_timeline(f"{file_prefix}.html", device="cuda:0")

def run_resnet50(num_iters=5, device="cuda:0"):
   model = models.resnet50().to(device=device)
   inputs = torch.randn(1, 3, 224, 224, device=device)
   labels = torch.rand_like(model(inputs))
   optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
   loss_fn = torch.nn.CrossEntropyLoss()

   with torch.profiler.profile(
       activities=[
           torch.profiler.ProfilerActivity.CPU,
           torch.profiler.ProfilerActivity.CUDA,
       ],
       schedule=torch.profiler.schedule(wait=0, warmup=0, active=6, repeat=1),
       record_shapes=True,
       profile_memory=True,
       with_stack=True,
       on_trace_ready=trace_handler,
   ) as prof:
       for _ in range(num_iters):
           prof.step()
           with record_function("## forward ##"):
               pred = model(inputs)

           with record_function("## backward ##"):
               loss_fn(pred, labels).backward()

           with record_function("## optimizer ##"):
               optimizer.step()
               optimizer.zero_grad(set_to_none=True)

if __name__ == "__main__":
    # Warm up
    run_resnet50()
    # Run the resnet50 model
    run_resnet50()