在使用 GPU 上的 PyTorch 时,您可能熟悉这个常见的错误消息
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 512.00 MiB. GPU 0 has a total capacity of 79.32 GiB of which 401.56 MiB is free.
在本系列中,我们将展示如何使用内存工具,包括内存快照、内存分析器和引用循环检测器来调试内存不足错误并改进内存使用。
内存快照工具提供细粒度的 GPU 内存可视化,用于调试 GPU OOM。捕获的内存快照将显示内存事件,包括分配、释放和 OOM,以及它们的堆栈跟踪。
在快照中,每个张量的内存分配都以不同的颜色编码。x 轴是时间,y 轴是 GPU 内存量(MB)。快照是交互式的,因此我们可以通过鼠标悬停来观察任何分配的堆栈跟踪。您可以在 https://github.com/pytorch/pytorch.github.io/blob/site/assets/images/understanding-gpu-memory-1/snapshot.html 亲自试用。
在此快照中,有 3 个峰值显示了 3 个训练迭代中的内存分配(这是可配置的)。查看峰值时,很容易看到前向传播中内存的上升以及反向传播期间的下降,因为梯度被计算出来。还可以看到程序在迭代之间具有相同的内存使用模式。一个突出的地方是许多微小的内存尖峰,通过鼠标悬停,我们看到它们是卷积运算符临时使用的缓冲区。
捕获内存快照
捕获内存快照的 API 非常简单,可在 torch.cuda.memory 中找到
- 开始:
torch.cuda.memory._record_memory_history(max_entries=100000)
- 保存:
torch.cuda.memory._dump_snapshot(file_name)
- 停止:
torch.cuda.memory._record_memory_history(enabled=None)
代码片段(完整代码示例,请参阅附录 A)
# Start recording memory snapshot history, initialized with a buffer
# capacity of 100,000 memory events, via the `max_entries` field.
torch.cuda.memory._record_memory_history(
max_entries=MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT
)
# Run your PyTorch Model.
# At any point in time, save a snapshot to file for later.
for _ in range(5):
pred = model(inputs)
loss_fn(pred, labels).backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)
# In this sample, we save the snapshot after running 5 iterations.
# - Save as many snapshots as you'd like.
# - Snapshots will save last `max_entries` number of memory events
# (100,000 in this example).
try:
torch.cuda.memory._dump_snapshot(f"{file_prefix}.pickle")
except Exception as e:
logger.error(f"Failed to capture memory snapshot {e}")
# Stop recording memory snapshot history.
torch.cuda.memory._record_memory_history(enabled=None)
要可视化快照文件,我们有一个托管在 https://pytorch.ac.cn/memory_viz 的工具。在那里,您可以拖放保存的快照文件,它将绘制每个分配随时间的变化。隐私声明:该工具不会保存您的快照。
或者,您可以使用 pytorch/torch/cuda/_memory_viz.py 中的脚本从 .pickle 生成 HTML,这是一个示例
python torch/cuda/_memory_viz.py trace_plot snapshot.pickle -o snapshot.html
调试 CUDA OOM
让我们看看如何使用内存快照工具来回答
- 为什么会发生 CUDA OOM?
- GPU 内存被用在哪里?
带有错误的 ResNet50
我们已经查看了第一个快照中正常工作的模型。现在,让我们看看一个带有错误的训练示例,请参阅快照
请注意,第二个迭代比第一个迭代使用更多的内存。如果这个模型更大,它可能在第二个迭代中 CUDA OOM'd,而没有更多关于原因的洞察。
在进一步检查此快照时,我们可以清楚地看到,一些张量从第一次迭代到第二次迭代以及后续迭代都保持活动状态。如果我们将鼠标悬停在其中一个张量上,它将显示一个堆栈跟踪,表明这些是梯度张量。
实际上,如果我们查看代码,我们可以看到它没有清除梯度张量,而它本可以在前向传播之前清除它们。
之前
for _ in range(num_iters):
pred = model(inputs)
loss_fn(pred, labels).backward()
optimizer.step()
之后
for _ in range(num_iters):
pred = model(inputs)
loss_fn(pred, labels).backward()
optimizer.step()
# Add this line to clear grad tensors
optimizer.zero_grad(set_to_none=True)
我们可以简单地添加一个 optimizer.zero_grad(set_to_none=True)
指令来清除迭代之间的梯度张量(有关为什么我们需要将梯度归零的更多详细信息,请参阅此处:https://pytorch.ac.cn/tutorials/recipes/recipes/zeroing_out_gradients.html)。
这是我们在更复杂的程序中使用此工具发现的错误的简化版本。我们鼓励您在 GPU 内存问题上试用内存快照,并告诉我们效果如何。
修复错误后的 ResNet50
应用修复程序后,快照似乎正在清除梯度。
我们现在有了正常工作的 ResNet50 模型的快照。亲自试用代码(请参阅附录 A 中的代码示例)。
但您可能想知道,为什么在第一次迭代后内存仍然增加? 为了回答这个问题,让我们在下一节中访问内存分析器。
分类的内存使用情况
内存分析器是 PyTorch 分析器的一个新增功能,它可以分类随时间推移的内存使用情况。我们仍然依赖内存快照来获取堆栈跟踪,以便深入研究内存分配。
要生成内存时间线,这是一个代码片段(完整代码示例在附录 B 中)
# Initialize the profiler context with record_shapes, profile_memory,
# and with_stack set to True.
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
schedule=torch.profiler.schedule(wait=0, warmup=0, active=6, repeat=1),
record_shapes=True,
profile_memory=True,
with_stack=True,
on_trace_ready=trace_handler,
) as prof:
# Run the PyTorch Model inside the profile context.
for _ in range(5):
prof.step()
with record_function("## forward ##"):
pred = model(inputs)
with record_function("## backward ##"):
loss_fn(pred, labels).backward()
with record_function("## optimizer ##"):
optimizer.step()
optimizer.zero_grad(set_to_none=True)
# Construct the memory timeline HTML plot.
prof.export_memory_timeline(f"{file_prefix}.html", device="cuda:0")
有关更多参考,请参阅 https://pytorch.ac.cn/docs/main/profiler.html。
内存分析器根据分析期间记录的张量操作图自动生成类别。
在使用内存分析器收集的此内存时间线中,我们有与之前相同的训练示例。我们可以观察到蓝色中的梯度现在在迭代之间被清除。我们还可以注意到黄色中的优化器状态是在第一次迭代后分配的,并在作业的其余部分保持不变。
此优化器状态是 GPU 内存从第一次迭代增加到第二次迭代的原因。亲自试用代码(请参阅附录 B 中的代码示例)。内存分析器有助于提高训练内存理解,以便模型作者可以找出哪些类别正在使用最多的 GPU 内存。
我在哪里可以找到这些工具?
我们希望这些工具将大大提高您调试 CUDA OOM 的能力,并帮助您按类别了解内存使用情况。
内存快照和内存分析器在 PyTorch v2.1 版本中作为实验性功能提供。
- 有关内存快照的更多信息,请参阅 此处的 PyTorch 内存文档。
- 有关内存分析器的更多详细信息,请参阅 此处的 PyTorch 分析器文档。
反馈
我们期待收到您的反馈,了解我们的工具在增强功能、错误或帮助解决的内存问题方面的信息!与往常一样,请随时在 PyTorch 的 Github 页面上提出新问题。
我们也欢迎来自 OSS 社区的贡献,请随时在任何 Github PR 中标记 Aaron Shi 和 Zachary DeVito 进行审核。
致谢
非常感谢内容审阅者 Mark Saroufim 和 Gregory Chanan 审阅本文并提高其可读性。
非常感谢 Adnan Aziz 和 Lei Tian 的代码审查和反馈。
附录
附录 A - ResNet50 内存快照代码示例
# (c) Meta Platforms, Inc. and affiliates.
import logging
import socket
from datetime import datetime, timedelta
import torch
from torchvision import models
logging.basicConfig(
format="%(levelname)s:%(asctime)s %(message)s",
level=logging.INFO,
datefmt="%Y-%m-%d %H:%M:%S",
)
logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)
TIME_FORMAT_STR: str = "%b_%d_%H_%M_%S"
# Keep a max of 100,000 alloc/free events in the recorded history
# leading up to the snapshot.
MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT: int = 100000
def start_record_memory_history() -> None:
if not torch.cuda.is_available():
logger.info("CUDA unavailable. Not recording memory history")
return
logger.info("Starting snapshot record_memory_history")
torch.cuda.memory._record_memory_history(
max_entries=MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT
)
def stop_record_memory_history() -> None:
if not torch.cuda.is_available():
logger.info("CUDA unavailable. Not recording memory history")
return
logger.info("Stopping snapshot record_memory_history")
torch.cuda.memory._record_memory_history(enabled=None)
def export_memory_snapshot() -> None:
if not torch.cuda.is_available():
logger.info("CUDA unavailable. Not exporting memory snapshot")
return
# Prefix for file names.
host_name = socket.gethostname()
timestamp = datetime.now().strftime(TIME_FORMAT_STR)
file_prefix = f"{host_name}_{timestamp}"
try:
logger.info(f"Saving snapshot to local file: {file_prefix}.pickle")
torch.cuda.memory._dump_snapshot(f"{file_prefix}.pickle")
except Exception as e:
logger.error(f"Failed to capture memory snapshot {e}")
return
# Simple Resnet50 example to demonstrate how to capture memory visuals.
def run_resnet50(num_iters=5, device="cuda:0"):
model = models.resnet50().to(device=device)
inputs = torch.randn(1, 3, 224, 224, device=device)
labels = torch.rand_like(model(inputs))
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
loss_fn = torch.nn.CrossEntropyLoss()
# Start recording memory snapshot history
start_record_memory_history()
for _ in range(num_iters):
pred = model(inputs)
loss_fn(pred, labels).backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)
# Create the memory snapshot file
export_memory_snapshot()
# Stop recording memory snapshot history
stop_record_memory_history()
if __name__ == "__main__":
# Run the resnet50 model
run_resnet50()
附录 B - ResNet50 内存分析器代码示例
# (c) Meta Platforms, Inc. and affiliates.
import logging
import socket
from datetime import datetime, timedelta
import torch
from torch.autograd.profiler import record_function
from torchvision import models
logging.basicConfig(
format="%(levelname)s:%(asctime)s %(message)s",
level=logging.INFO,
datefmt="%Y-%m-%d %H:%M:%S",
)
logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)
TIME_FORMAT_STR: str = "%b_%d_%H_%M_%S"
def trace_handler(prof: torch.profiler.profile):
# Prefix for file names.
host_name = socket.gethostname()
timestamp = datetime.now().strftime(TIME_FORMAT_STR)
file_prefix = f"{host_name}_{timestamp}"
# Construct the trace file.
prof.export_chrome_trace(f"{file_prefix}.json.gz")
# Construct the memory timeline file.
prof.export_memory_timeline(f"{file_prefix}.html", device="cuda:0")
def run_resnet50(num_iters=5, device="cuda:0"):
model = models.resnet50().to(device=device)
inputs = torch.randn(1, 3, 224, 224, device=device)
labels = torch.rand_like(model(inputs))
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
loss_fn = torch.nn.CrossEntropyLoss()
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
schedule=torch.profiler.schedule(wait=0, warmup=0, active=6, repeat=1),
record_shapes=True,
profile_memory=True,
with_stack=True,
on_trace_ready=trace_handler,
) as prof:
for _ in range(num_iters):
prof.step()
with record_function("## forward ##"):
pred = model(inputs)
with record_function("## backward ##"):
loss_fn(pred, labels).backward()
with record_function("## optimizer ##"):
optimizer.step()
optimizer.zero_grad(set_to_none=True)
if __name__ == "__main__":
# Warm up
run_resnet50()
# Run the resnet50 model
run_resnet50()