快捷方式

CUDA 语义

torch.cuda 用于设置和运行 CUDA 操作。它会跟踪当前选定的 GPU,并且你分配的所有 CUDA 张量默认都会在该设备上创建。可以使用 torch.cuda.device 上下文管理器来更改选定的设备。

但是,一旦张量被分配,你就可以在其上执行操作,无论当前选定的设备是哪个,结果将始终位于与该张量相同的设备上。

默认情况下,不允许跨 GPU 操作,但 copy_() 以及其他具有类似复制功能的 methods (例如 to()cuda()) 除外。除非启用点对点内存访问,否则任何尝试在分布于不同设备上的张量上启动操作都会引发错误。

下面是一个展示此情况的小例子

cuda = torch.device('cuda')     # Default CUDA device
cuda0 = torch.device('cuda:0')
cuda2 = torch.device('cuda:2')  # GPU 2 (these are 0-indexed)

x = torch.tensor([1., 2.], device=cuda0)
# x.device is device(type='cuda', index=0)
y = torch.tensor([1., 2.]).cuda()
# y.device is device(type='cuda', index=0)

with torch.cuda.device(1):
    # allocates a tensor on GPU 1
    a = torch.tensor([1., 2.], device=cuda)

    # transfers a tensor from CPU to GPU 1
    b = torch.tensor([1., 2.]).cuda()
    # a.device and b.device are device(type='cuda', index=1)

    # You can also use ``Tensor.to`` to transfer a tensor:
    b2 = torch.tensor([1., 2.]).to(device=cuda)
    # b.device and b2.device are device(type='cuda', index=1)

    c = a + b
    # c.device is device(type='cuda', index=1)

    z = x + y
    # z.device is device(type='cuda', index=0)

    # even within a context, you can specify the device
    # (or give a GPU index to the .cuda call)
    d = torch.randn(2, device=cuda2)
    e = torch.randn(2).to(cuda2)
    f = torch.randn(2).cuda(cuda2)
    # d.device, e.device, and f.device are all device(type='cuda', index=2)

Ampere(及更高版本)设备上的 TensorFloat-32 (TF32)

从 PyTorch 1.7 版本开始,引入了一个名为 allow_tf32 的新标志。在 PyTorch 1.7 到 1.11 版本中,此标志默认为 True;在 PyTorch 1.12 及更高版本中,则默认为 False。此标志控制 PyTorch 是否允许在内部使用 TensorFloat32 (TF32) 张量核心(自 Ampere 架构以来在 NVIDIA GPU 上可用)来计算 matmul(矩阵乘法和批量矩阵乘法)和卷积。

TF32 张量核心旨在通过将输入数据舍入到具有 10 位尾数,并以 FP32 精度累积结果,从而保持 FP32 动态范围,以实现在 torch.float32 张量上的 matmul 和卷积运算获得更好的性能。

matmul 和卷积分别控制,其对应的标志可通过以下方式访问:

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

matmul 的精度也可以通过 set_float_32_matmul_precision() 进行更广泛的设置(不限于 CUDA)。请注意,除了 matmul 和卷积本身,内部使用 matmul 或卷积的函数和 nn 模块也会受到影响。其中包括 nn.Linearnn.Conv*、cdist、tensordot、affine grid 和 grid sample、adaptive log softmax、GRU 和 LSTM。

要了解精度和速度,请参阅下面的示例代码和(在 A100 上的)基准测试数据

a_full = torch.randn(10240, 10240, dtype=torch.double, device='cuda')
b_full = torch.randn(10240, 10240, dtype=torch.double, device='cuda')
ab_full = a_full @ b_full
mean = ab_full.abs().mean()  # 80.7277

a = a_full.float()
b = b_full.float()

# Do matmul at TF32 mode.
torch.backends.cuda.matmul.allow_tf32 = True
ab_tf32 = a @ b  # takes 0.016s on GA100
error = (ab_tf32 - ab_full).abs().max()  # 0.1747
relative_error = error / mean  # 0.0022

# Do matmul with TF32 disabled.
torch.backends.cuda.matmul.allow_tf32 = False
ab_fp32 = a @ b  # takes 0.11s on GA100
error = (ab_fp32 - ab_full).abs().max()  # 0.0031
relative_error = error / mean  # 0.000039

从上面的示例可以看出,启用 TF32 后,在 A100 上的速度提高了约 7 倍,并且相对于双精度浮点的相对误差大约大两个数量级。请注意,TF32 与单精度速度的准确比率取决于硬件代系,因为内存带宽与计算的比率以及 TF32 与 FP32 matmul 吞吐量的比率可能因代系或模型而异。如果需要完整的 FP32 精度,用户可以通过以下方式禁用 TF32:

torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

要在 C++ 中关闭 TF32 标志,可以使用:

at::globalContext().setAllowTF32CuBLAS(false);
at::globalContext().setAllowTF32CuDNN(false);

有关 TF32 的更多信息,请参阅:

FP16 GEMM 中的降低精度约简

(与旨在用于 FP16 累积吞吐量高于 FP32 累积吞吐量的硬件上的完整 FP16 累积不同,请参阅 Full FP16 accumulation)

fp16 GEMM 运算可能伴随一些中间降低精度约简(例如,使用 fp16 而非 fp32)。这些选择性的精度约简可以在某些工作负载(尤其是那些具有较大 k 维度的)和 GPU 架构上实现更高的性能,但代价是数值精度和潜在的溢出风险。

V100 上的一些基准测试数据示例:

[--------------------------- bench_gemm_transformer --------------------------]
      [  m ,  k  ,  n  ]    |  allow_fp16_reduc=True  |  allow_fp16_reduc=False
1 threads: --------------------------------------------------------------------
      [4096, 4048, 4096]    |           1634.6        |           1639.8
      [4096, 4056, 4096]    |           1670.8        |           1661.9
      [4096, 4080, 4096]    |           1664.2        |           1658.3
      [4096, 4096, 4096]    |           1639.4        |           1651.0
      [4096, 4104, 4096]    |           1677.4        |           1674.9
      [4096, 4128, 4096]    |           1655.7        |           1646.0
      [4096, 4144, 4096]    |           1796.8        |           2519.6
      [4096, 5096, 4096]    |           2094.6        |           3190.0
      [4096, 5104, 4096]    |           2144.0        |           2663.5
      [4096, 5112, 4096]    |           2149.1        |           2766.9
      [4096, 5120, 4096]    |           2142.8        |           2631.0
      [4096, 9728, 4096]    |           3875.1        |           5779.8
      [4096, 16384, 4096]   |           6182.9        |           9656.5
(times in microseconds).

如果需要完整精度约简,用户可以通过以下方式禁用 fp16 GEMM 中的降低精度约简:

torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

要在 C++ 中切换降低精度约简标志,可以使用:

at::globalContext().setAllowFP16ReductionCuBLAS(false);

BF16 GEMM 中的降低精度约简

对于 BFloat16 GEMM 也存在一个类似的标志(如上所示)。请注意,对于 BF16,此开关默认设置为 True,如果你在工作负载中观察到数值不稳定,可能需要将其设置为 False

如果不需要降低精度约简,用户可以通过以下方式禁用 bf16 GEMM 中的降低精度约简:

torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False

要在 C++ 中切换降低精度约简标志,可以使用:

at::globalContext().setAllowBF16ReductionCuBLAS(true);

FP16 GEMM 中的完整 FP16 累积

某些 GPU 在进行 _所有_ FP16 GEMM 累积都在 FP16 中完成时,性能会提高,但代价是数值精度降低和溢出可能性增加。请注意,此设置仅对计算能力为 7.0 (Volta) 或更高版本的 GPU 有效。

可通过以下方式启用此行为:

torch.backends.cuda.matmul.allow_fp16_accumulation = True

要在 C++ 中切换降低精度约简标志,可以使用:

at::globalContext().setAllowFP16AccumulationCuBLAS(true);

异步执行

默认情况下,GPU 操作是异步的。当你调用一个使用 GPU 的函数时,这些操作会被 *排队* 到特定的设备上,但不必立即执行,而是在稍后执行。这使得我们可以并行执行更多计算,包括在 CPU 或其他 GPU 上的操作。

通常,异步计算的效果对调用者是不可见的,因为 (1) 每个设备按照它们被排队的顺序执行操作,并且 (2) PyTorch 在 CPU 和 GPU 之间或两个 GPU 之间复制数据时会自动执行必要的同步。因此,计算将按照仿佛每个操作都是同步执行的方式进行。

可以通过设置环境变量 CUDA_LAUNCH_BLOCKING=1 来强制进行同步计算。这在 GPU 上发生错误时非常方便。(在异步执行中,此类错误直到操作实际执行后才会报告,因此堆栈跟踪不会显示请求发生的位置。)

异步计算的一个结果是,没有同步的时间测量不准确。为了获得精确的测量结果,应该在测量之前调用 torch.cuda.synchronize(),或者使用 torch.cuda.Event 记录时间,如下所示:

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()

# Run some things here

end_event.record()
torch.cuda.synchronize()  # Wait for the events to be recorded!
elapsed_time_ms = start_event.elapsed_time(end_event)

作为例外,一些函数(例如 to()copy_())接受一个显式的 non_blocking 参数,该参数允许调用者在不需要时绕过同步。另一个例外是 CUDA 流,将在下面解释。

CUDA 流

CUDA 流是属于特定设备的线性执行序列。通常无需显式创建:默认情况下,每个设备都有自己的“默认”流。

每个流内部的操作按创建顺序串行执行,但来自不同流的操作可以以任何相对顺序并发执行,除非使用了显式同步函数(例如 synchronize()wait_stream())。例如,以下代码是错误的:

cuda = torch.device('cuda')
s = torch.cuda.Stream()  # Create a new stream.
A = torch.empty((100, 100), device=cuda).normal_(0.0, 1.0)
with torch.cuda.stream(s):
    # sum() may start execution before normal_() finishes!
    B = torch.sum(A)

当“当前流”是默认流时,PyTorch 会在数据移动时自动执行必要的同步,如上所述。但是,在使用非默认流时,用户有责任确保正确的同步。此示例的修正版本是:

cuda = torch.device('cuda')
s = torch.cuda.Stream()  # Create a new stream.
A = torch.empty((100, 100), device=cuda).normal_(0.0, 1.0)
s.wait_stream(torch.cuda.default_stream(cuda))  # NEW!
with torch.cuda.stream(s):
    B = torch.sum(A)
A.record_stream(s)  # NEW!

新增了两个部分。torch.cuda.Stream.wait_stream() 调用确保 normal_() 执行完成后,我们才在侧流上开始运行 sum(A)torch.Tensor.record_stream()(详见)确保在 sum(A) 完成之前,我们不会释放 A。你也可以稍后在某个时间点手动等待流完成,使用 torch.cuda.default_stream(cuda).wait_stream(s)(请注意,立即等待是没意义的,因为那会阻止流执行与默认流上的其他工作并行运行)。关于何时使用其中一种或另一种方法的更多细节,请参阅 torch.Tensor.record_stream() 的文档。

请注意,即使没有读取依赖,这种同步也是必要的,例如在这个例子中所示:

cuda = torch.device('cuda')
s = torch.cuda.Stream()  # Create a new stream.
A = torch.empty((100, 100), device=cuda)
s.wait_stream(torch.cuda.default_stream(cuda))  # STILL REQUIRED!
with torch.cuda.stream(s):
    A.normal_(0.0, 1.0)
    A.record_stream(s)

尽管在 s 上的计算没有读取 A 的内容,并且没有其他地方使用 A,但仍然需要同步,因为 A 可能对应于 CUDA 缓存分配器重新分配的内存,而旧的(已释放的)内存上可能还有待处理的操作。

反向传播的流语义

每个 CUDA 反向操作都在与其对应的正向操作相同的流上运行。如果你的正向传播在不同的流上并行运行独立操作,这有助于反向传播利用相同的并行性。

反向调用相对于周围操作的流语义与任何其他调用相同。即使如前一段所述,反向操作在多个流上运行,反向传播也会插入内部同步以确保这一点。更具体地说,当调用 autograd.backwardautograd.gradtensor.backward,并选择性地提供 CUDA 张量作为初始梯度时(例如 autograd.backward(..., grad_tensors=initial_grads)autograd.grad(..., grad_outputs=initial_grads)tensor.backward(..., gradient=initial_grad)),以下行为:

  1. 可选地填充初始梯度,

  2. 调用反向传播,以及

  3. 使用梯度

与任何一组操作具有相同的流语义关系

s = torch.cuda.Stream()

# Safe, grads are used in the same stream context as backward()
with torch.cuda.stream(s):
    loss.backward()
    use grads

# Unsafe
with torch.cuda.stream(s):
    loss.backward()
use grads

# Safe, with synchronization
with torch.cuda.stream(s):
    loss.backward()
torch.cuda.current_stream().wait_stream(s)
use grads

# Safe, populating initial grad and invoking backward are in the same stream context
with torch.cuda.stream(s):
    loss.backward(gradient=torch.ones_like(loss))

# Unsafe, populating initial_grad and invoking backward are in different stream contexts,
# without synchronization
initial_grad = torch.ones_like(loss)
with torch.cuda.stream(s):
    loss.backward(gradient=initial_grad)

# Safe, with synchronization
initial_grad = torch.ones_like(loss)
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    initial_grad.record_stream(s)
    loss.backward(gradient=initial_grad)

BC 注意:在默认流上使用梯度

在早期版本的 PyTorch(1.9 及更早版本)中,autograd 引擎始终将默认流与所有反向操作同步,因此以下模式:

with torch.cuda.stream(s):
    loss.backward()
use grads

是安全的,只要 use grads 发生在默认流上。在当前版本的 PyTorch 中,该模式不再安全。如果 backward()use grads 在不同的流上下文中,你必须同步流:

with torch.cuda.stream(s):
    loss.backward()
torch.cuda.current_stream().wait_stream(s)
use grads

即使 use grads 位于默认流上。

内存管理

PyTorch 使用缓存内存分配器来加速内存分配。这允许在没有设备同步的情况下快速释放内存。但是,由分配器管理的未使用内存仍然会在 nvidia-smi 中显示为已使用。你可以使用 memory_allocated()max_memory_allocated() 监控张量占用的内存,并使用 memory_reserved()max_memory_reserved() 监控由缓存分配器管理的总内存量。调用 empty_cache() 会释放 PyTorch 中所有 未使用 的缓存内存,以便其他 GPU 应用程序可以使用它们。但是,张量占用的 GPU 内存不会被释放,因此它不能增加 PyTorch 可用的 GPU 内存量。

为了更好地了解 CUDA 内存随时间的使用情况,理解 CUDA 内存使用 介绍了用于捕获和可视化内存使用轨迹的工具。

对于更高级的用户,我们通过 memory_stats() 提供更全面的内存基准测试。我们还提供通过 memory_snapshot() 捕获内存分配器完整快照的功能,这有助于你理解代码产生的底层分配模式。

使用 PYTORCH_CUDA_ALLOC_CONF 优化内存使用

使用缓存分配器可能会干扰 cuda-memcheck 等内存检查工具。要使用 cuda-memcheck 调试内存错误,请在你的环境中设置 PYTORCH_NO_CUDA_MEMORY_CACHING=1 以禁用缓存。

可以通过环境变量 PYTORCH_CUDA_ALLOC_CONF 控制缓存分配器的行为。格式为 PYTORCH_CUDA_ALLOC_CONF=<选项>:<值>,<选项2>:<值2>... 可用选项:

  • backend 允许选择底层分配器实现。当前有效选项是 native,它使用 PyTorch 的原生实现,以及 cudaMallocAsync,它使用 CUDA 内置的异步分配器(参见此处)。cudaMallocAsync 需要 CUDA 11.4 或更高版本。默认值为 nativebackend 适用于进程使用的所有设备,不能按设备指定。

  • max_split_size_mb 防止原生分配器分割大于此大小(以 MB 为单位)的块。这可以减少碎片化,并可能允许某些边界工作负载在不耗尽内存的情况下完成。性能开销从“零”到“相当大”不等,具体取决于分配模式。默认值是无限制,即所有块都可以分割。memory_stats()memory_summary() 方法对于调优很有用。此选项应作为因“内存不足”而中止且显示大量非活动分割块的工作负载的最后手段。max_split_size_mb 仅在 backend:native 时有意义。使用 backend:cudaMallocAsync 时,max_split_size_mb 将被忽略。

  • roundup_power2_divisions 有助于将请求的分配大小舍入到最接近的 2 的幂次划分,从而更好地利用块。在原生 CUDACachingAllocator 中,大小会按 512 的块大小进行向上舍入,因此对于较小的大小工作良好。然而,对于较大的临近分配,这可能效率低下,因为每个分配会分配到不同大小的块,导致这些块的重用最小化。这可能会产生大量未使用的块并浪费 GPU 内存容量。此选项启用将分配大小舍入到最接近的 2 的幂次划分。例如,如果我们需要舍入大小 1200,并且划分数为 4,则大小 1200 位于 1024 和 2048 之间,如果我们在它们之间进行 4 次划分,得到的值是 1024、1280、1536 和 1792。因此,分配大小 1200 将被舍入到最接近的 2 的幂次划分上限 1280。可以指定一个单一值应用于所有分配大小,或者指定一个键值对数组,以便为每个 2 的幂次区间单独设置划分数。例如,要为所有小于 256MB 的分配设置 1 次划分,为 256MB 到 512MB 之间的分配设置 2 次划分,为 512MB 到 1GB 之间的分配设置 4 次划分,以及为所有更大的分配设置 8 次划分,请将旋钮值设置为:[256:1,512:2,1024:4,>:8]。roundup_power2_divisions 仅在 backend:native 时有意义。使用 backend:cudaMallocAsync 时,roundup_power2_divisions 将被忽略。

  • max_non_split_rounding_mb 将允许非分割块以更好地重用,例如,

    一个 1024MB 的缓存块可以被重用于 512MB 的分配请求。在默认情况下,我们只允许非分割块最多舍入 20MB,因此 512MB 的块只能由大小在 512-532 MB 之间的块服务。如果我们将此选项的值设置为 1024,它将允许大小在 512-1536 MB 之间的块用于 512MB 的块,从而增加了对较大块的重用。这也有助于减少由于避免昂贵的 cudaMalloc 调用而导致的停顿。

  • garbage_collection_threshold 有助于主动回收未使用的 GPU 内存,以避免触发昂贵的同步并回收所有内存的操作 (release_cached_blocks),这对于对延迟敏感的 GPU 应用程序(例如服务器)可能不利。设置此阈值(例如 0.8)后,如果 GPU 内存容量使用量超过此阈值(即分配给 GPU 应用程序的总内存的 80%),分配器将开始回收 GPU 内存块。算法优先释放旧的且未使用的块,以避免释放正在积极重用的块。阈值应大于 0.0 且小于 1.0。garbage_collection_threshold 仅在 backend:native 时有意义。使用 backend:cudaMallocAsync 时,garbage_collection_threshold 将被忽略。

  • expandable_segments(实验性功能,默认值:False)如果设置为 True,此设置会指示分配器创建可后续扩展的 CUDA 分配,以便更好地处理频繁更改分配大小的情况,例如批量大小发生变化。通常对于大型(>2MB)分配,分配器会调用 cudaMalloc 获取与用户请求大小相同的分配。将来,这些分配的某些部分在空闲时可以重用于其他请求。当程序发出许多大小完全相同或为该大小的偶数倍的请求时,这种方法效果很好。许多深度学习模型都遵循这种行为。然而,一个常见的例外是当批量大小在一次迭代到下一次迭代之间略有变化时,例如在批量推理中。当程序最初以批量大小 N 运行时,它将进行适合该大小的分配。如果将来以大小 N - 1 运行,现有的分配仍然足够大。但是,如果以大小 N + 1 运行,它将不得不进行稍微大一些的新分配。并非所有张量都具有相同的大小。有些可能是 (N + 1)*A,而另一些可能是 (N + 1)*A*B,其中 AB 是模型中的非批量维度。由于分配器会在现有分配足够大时重用它们,所以一些 (N + 1)*A 的分配实际上会适合已有的 N*B*A 段,尽管不是完全匹配。随着模型的运行,它会部分填充所有这些段,并在这些段的末尾留下不可用的空闲内存片段。分配器在某个时候将需要 cudaMalloc 一个新的 (N + 1)*A*B 段。如果内存不足,现在就无法恢复现有段末尾的空闲内存片段。对于深度超过 50 层的模型,这种模式可能会重复 50 次以上,产生许多碎片。

    expandable_segments 允许分配器最初创建一个段,然后在需要更多内存时再扩展其大小。它不像每进行一次分配就创建一个段,而是尝试创建一个(每流)随着需要增长的段。现在,当 N + 1 的情况运行时,分配将很好地填充到这个大段中,直到它被填满。然后,会请求更多内存并附加到段的末尾。这个过程不会产生那么多不可用的内存碎片,因此更有可能成功找到所需的内存。

    pinned_use_cuda_host_register 选项是一个布尔标志,用于确定是否使用 CUDA API 的 cudaHostRegister 函数来分配 pinned memory(固定内存),而不是默认的 cudaHostAlloc。当设置为 True 时,内存使用常规的 malloc 分配,然后在调用 cudaHostRegister 之前将页面映射到内存。这种页面的预映射有助于减少执行 cudaHostRegister 期间的锁定时间。

    pinned_num_register_threads 选项仅在 pinned_use_cuda_host_register 设置为 True 时有效。默认情况下,使用一个线程来映射页面。此选项允许使用更多线程来并行执行页面映射操作,以减少 pinned memory 的总分配时间。根据基准测试结果,此选项的良好值为 8。

    pinned_use_background_threads 选项是一个布尔标志,用于启用后台线程来处理事件。这避免了在快速分配路径中查询/处理事件相关的任何慢速路径。此功能默认禁用。

注意

CUDA 内存管理 API 报告的一些统计信息是 backend:native 特有的,对于 backend:cudaMallocAsync 没有意义。有关详细信息,请参阅每个函数的 docstring。

使用 CUDA 的自定义内存分配器

可以将分配器定义为 C/C++ 中的简单函数,并将其编译为共享库,下面的代码展示了一个只跟踪所有内存操作的基本分配器。

#include <sys/types.h>
#include <cuda_runtime_api.h>
#include <iostream>
// Compile with g++ alloc.cc -o alloc.so -I/usr/local/cuda/include -shared -fPIC
extern "C" {
void* my_malloc(ssize_t size, int device, cudaStream_t stream) {
   void *ptr;
   cudaMalloc(&ptr, size);
   std::cout<<"alloc "<<ptr<<size<<std::endl;
   return ptr;
}

void my_free(void* ptr, ssize_t size, int device, cudaStream_t stream) {
   std::cout<<"free "<<ptr<< " "<<stream<<std::endl;
   cudaFree(ptr);
}
}

这可以通过 torch.cuda.memory.CUDAPluggableAllocator 在 python 中使用。用户负责提供 .so 文件的路径以及与上述签名匹配的 alloc/free 函数的名称。

import torch

# Load the allocator
new_alloc = torch.cuda.memory.CUDAPluggableAllocator(
    'alloc.so', 'my_malloc', 'my_free')
# Swap the current allocator
torch.cuda.memory.change_current_allocator(new_alloc)
# This will allocate memory in the device using the new allocator
b = torch.zeros(10, device='cuda')
import torch

# Do an initial memory allocator
b = torch.zeros(10, device='cuda')
# Load the allocator
new_alloc = torch.cuda.memory.CUDAPluggableAllocator(
    'alloc.so', 'my_malloc', 'my_free')
# This will error since the current allocator was already instantiated
torch.cuda.memory.change_current_allocator(new_alloc)

在同一个程序中混合不同的 CUDA 系统分配器

根据您的用例,change_current_allocator() 可能不是您想使用的,因为它会替换整个程序的 CUDA 分配器(类似于 PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync)。例如,如果替换的分配器没有缓存机制,您将失去 PyTorch 的 CUDACachingAllocator 的所有好处。相反,您可以使用 torch.cuda.MemPool 选择性地标记一段 PyTorch 代码来使用自定义分配器。这将允许您在同一个 PyTorch 程序中使用多个 CUDA 系统分配器,并保留 CUDACachingAllocator 的大部分优点(例如缓存)。使用 torch.cuda.MemPool,您可以利用自定义分配器启用多种功能,例如

  • 使用 ncclMemAlloc 分配器为 all-reduce 分配输出缓冲区可以启用 NVLink Switch Reductions (NVLS)。这可以减少 GPU 资源(SM 和 Copy Engines)上重叠计算和通信内核之间的竞争,尤其是在 tensor-parallel 工作负载上。

  • 对于基于 Grace CPU 的系统,使用 cuMemCreate 并指定 CU_MEM_LOCATION_TYPE_HOST_NUMA 为 all-gather 分配主机输出缓冲区可以启用基于 Extended GPU Memory (EGM) 的内存传输,从源 GPU 到目标 CPU。这加速了 all-gather,因为传输通过 NVLinks 进行,否则将通过带宽受限的网络接口卡 (NIC) 链接进行。这种加速的 all-gather 反过来可以加快模型检查点保存。

  • 如果您正在构建模型,并且最初不想考虑内存密集型模块(例如嵌入表)的最佳内存放置,或者您有一个对性能不敏感且无法容纳在 GPU 中的模块,那么您可以直接使用 cudaMallocManaged 并指定首选 CPU 位置来分配该模块,从而首先让您的模型运行起来。

注意

虽然 cudaMallocManaged 使用 CUDA 统一虚拟内存 (UVM) 提供了便捷的自动内存管理,但它不推荐用于 DL 工作负载。对于适合 GPU 内存的 DL 工作负载,显式放置始终优于 UVM,因为没有页面错误且访问模式保持可预测。当 GPU 内存饱和时,UVM 必须执行代价高昂的双重传输,在引入新页面之前将旧页面逐出到 CPU。

下面的代码展示了将 ncclMemAlloc 包装在 torch.cuda.memory.CUDAPluggableAllocator 中的示例。

import os

import torch
import torch.distributed as dist
from torch.cuda.memory import CUDAPluggableAllocator
from torch.distributed.distributed_c10d import _get_default_group
from torch.utils import cpp_extension


# create allocator
nccl_allocator_source = """
#include <nccl.h>
#include <iostream>
extern "C" {

void* nccl_alloc_plug(size_t size, int device, void* stream) {
  std::cout << "Using ncclMemAlloc" << std::endl;
  void* ptr;
  ncclResult_t err = ncclMemAlloc(&ptr, size);
  return ptr;

}

void nccl_free_plug(void* ptr, size_t size, int device, void* stream) {
  std::cout << "Using ncclMemFree" << std::endl;
  ncclResult_t err = ncclMemFree(ptr);
}

}
"""
nccl_allocator_libname = "nccl_allocator"
nccl_allocator = torch.utils.cpp_extension.load_inline(
    name=nccl_allocator_libname,
    cpp_sources=nccl_allocator_source,
    with_cuda=True,
    extra_ldflags=["-lnccl"],
    verbose=True,
    is_python_module=False,
    build_directory="./",
)

allocator = CUDAPluggableAllocator(
    f"./{nccl_allocator_libname}.so", "nccl_alloc_plug", "nccl_free_plug"
).allocator()

# setup distributed
rank = int(os.getenv("RANK"))
local_rank = int(os.getenv("LOCAL_RANK"))
world_size = int(os.getenv("WORLD_SIZE"))
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl")
device = torch.device(f"cuda:{local_rank}")
default_pg = _get_default_group()
backend = default_pg._get_backend(device)

# Note: for convenience, ProcessGroupNCCL backend provides
# the ncclMemAlloc allocator as backend.mem_allocator
allocator = backend.mem_allocator

现在,您可以通过将此分配器传递给 torch.cuda.MemPool 来定义新的内存池

pool = torch.cuda.MemPool(allocator)

然后可以将该内存池与 torch.cuda.use_mem_pool 上下文管理器一起使用,将张量分配到该内存池中

with torch.cuda.use_mem_pool(pool):
    # tensor gets allocated with ncclMemAlloc passed in the pool
    tensor = torch.arange(1024 * 1024 * 2, device=device)
    print(f"tensor ptr on rank {rank} is {hex(tensor.data_ptr())}")

# register user buffers using ncclCommRegister (called under the hood)
backend.register_mem_pool(pool)

# Collective uses Zero Copy NVLS
dist.all_reduce(tensor[0:4])
torch.cuda.synchronize()
print(tensor[0:4])

请注意上面示例中 register_mem_pool 的用法。这是 NVLS reductions 的额外步骤,其中用户缓冲区需要向 NCCL 注册。用户可以使用类似的 deregister_mem_pool 调用来注销缓冲区。

要回收内存,用户首先需要确保没有任何内容正在使用该内存池。当没有张量持有对该内存池的引用时,在删除内存池时,内部将调用 empty_cache(),从而将所有内存返回给系统。

del tensor, del pool

以下 torch.cuda.MemPool.use_count()torch.cuda.MemPool.snapshot() API 可用于调试目的

pool = torch.cuda.MemPool(allocator)

# pool's use count should be 1 at this point as MemPool object
# holds a reference
assert pool.use_count() == 1

nelem_1mb = 1024 * 1024 // 4

with torch.cuda.use_mem_pool(pool):
    out_0 = torch.randn(nelem_1mb, device="cuda")

    # pool's use count should be 2 at this point as use_mem_pool
    # holds a reference
    assert pool.use_count() == 2

# pool's use count should be back to 1 at this point as use_mem_pool
# released its reference
assert pool.use_count() == 1

with torch.cuda.use_mem_pool(pool):
    # pool should have 1 segment since we made a small allocation (1 MB)
    # above and so the CUDACachingAllocator packed it into a 2 MB buffer
    assert len(pool.snapshot()) == 1

    out_1 = torch.randn(nelem_1mb, device="cuda")

    # pool should still have 1 segment since we made another small allocation
    # (1 MB) that got packed into the existing 2 MB buffer
    assert len(pool.snapshot()) == 1

    out_2 = torch.randn(nelem_1mb, device="cuda")

    # pool now should have 2 segments since the CUDACachingAllocator had
    # to make a new 2 MB buffer to accomodate out_2
    assert len(pool.snapshot()) == 2

注意

  • torch.cuda.MemPool 持有对内存池的引用。当您使用 torch.cuda.use_mem_pool 上下文管理器时,它也会获取对内存池的另一个引用。退出上下文管理器时,它会释放其引用。之后,理想情况下,应该只有张量持有对内存池的引用。一旦张量释放了它们的引用,内存池的使用计数将变为 1,这反映出只有 torch.cuda.MemPool 对象持有引用。只有在那时,当使用 del 调用内存池的析构函数时,内存池持有的内存才能返回给系统。

  • torch.cuda.MemPool 当前不支持 CUDACachingAllocator 的 expandable_segments 模式。

  • NCCL 对缓冲区与 NVLS reductions 的兼容性有特定要求。在动态工作负载中,这些要求可能会被违反,例如,CUDACachingAllocator 发送给 NCCL 的缓冲区可能被分割,因此未正确对齐。在这些情况下,NCCL 可以使用回退算法而不是 NVLS。

  • ncclMemAlloc 这样的分配器由于对齐要求(CU_MULTICAST_GRANULARITY_RECOMMENDED, CU_MULTICAST_GRANULARITY_MINIMUM)可能会使用比请求更多的内存,这可能导致您的工作负载内存不足。

cuBLAS 工作区

对于 cuBLAS 句柄和 CUDA 流的每种组合,如果该句柄和流组合执行需要工作区的 cuBLAS 内核,就会分配一个 cuBLAS 工作区。为了避免重复分配工作区,除非调用 torch._C._cuda_clearCublasWorkspaces(),否则这些工作区不会被释放。每次分配的工作区大小可以通过环境变量 CUBLAS_WORKSPACE_CONFIG 指定,格式为 :[SIZE]:[COUNT]。例如,每次分配的默认工作区大小是 CUBLAS_WORKSPACE_CONFIG=:4096:2:16:8,它指定了总大小为 2 * 4096 + 8 * 16 KiB。要强制 cuBLAS 避免使用工作区,请设置 CUBLAS_WORKSPACE_CONFIG=:0:0

cuFFT plan 缓存

对于每个 CUDA 设备,使用 cuFFT plans 的 LRU 缓存来加速在相同几何结构和相同配置的 CUDA 张量上重复运行 FFT 方法(例如 torch.fft.fft())。由于某些 cuFFT plans 可能会分配 GPU 内存,这些缓存具有最大容量。

您可以使用以下 API 控制和查询当前设备的缓存属性

  • torch.backends.cuda.cufft_plan_cache.max_size 提供缓存的容量(默认在 CUDA 10 及更新版本上为 4096,在较旧的 CUDA 版本上为 1023)。直接设置此值会修改容量。

  • torch.backends.cuda.cufft_plan_cache.size 提供当前驻留在缓存中的 plans 数量。

  • torch.backends.cuda.cufft_plan_cache.clear() 清空缓存。

要控制和查询非默认设备的 plan 缓存,您可以使用 torch.device 对象或设备索引来索引 torch.backends.cuda.cufft_plan_cache 对象,并访问上述属性之一。例如,要设置设备 1 的缓存容量,可以写 torch.backends.cuda.cufft_plan_cache[1].max_size = 10

即时编译

PyTorch 在对 CUDA 张量执行某些操作(例如 torch.special.zeta)时会进行即时编译。此编译可能非常耗时(根据您的硬件和软件,可能需要几秒钟),并且对于单个 operator 可能会发生多次,因为许多 PyTorch operator 实际上会从各种 kernels 中进行选择,每个 kernel 都必须根据其输入编译一次。这种编译每个进程发生一次,或者如果使用 kernel 缓存,则只发生一次。

默认情况下,如果定义了 XDG_CACHE_HOME,PyTorch 会在 $XDG_CACHE_HOME/torch/kernels 中创建 kernel 缓存;如果未定义,则在 $HOME/.cache/torch/kernels 中创建(Windows 除外,Windows 尚不支持 kernel 缓存)。缓存行为可以通过两个环境变量直接控制。如果将 USE_PYTORCH_KERNEL_CACHE 设置为 0,则不使用缓存;如果设置了 PYTORCH_KERNEL_CACHE_PATH,则将该路径用作 kernel 缓存,而不是默认位置。

最佳实践

设备无关代码

由于 PyTorch 的结构,您可能需要显式编写与设备无关(CPU 或 GPU)的代码;一个例子可能是创建新张量作为循环神经网络的初始隐藏状态。

第一步是确定是否应该使用 GPU。一种常见的模式是使用 Python 的 argparse 模块读取用户参数,并设置一个标志,可以结合 is_available() 来禁用 CUDA。在下文中,args.device 会产生一个 torch.device 对象,该对象可用于将张量移动到 CPU 或 CUDA。

import argparse
import torch

parser = argparse.ArgumentParser(description='PyTorch Example')
parser.add_argument('--disable-cuda', action='store_true',
                    help='Disable CUDA')
args = parser.parse_args()
args.device = None
if not args.disable_cuda and torch.cuda.is_available():
    args.device = torch.device('cuda')
else:
    args.device = torch.device('cpu')

注意

在给定环境中评估 CUDA 可用性(is_available())时,PyTorch 的默认行为是调用 CUDA Runtime API 方法 cudaGetDeviceCount。由于此调用反过来会初始化 CUDA Driver API(通过 cuInit),如果它尚未初始化,则运行过 is_available() 的进程的后续分支(fork)将因 CUDA 初始化错误而失败。

您可以在导入执行 is_available() 的 PyTorch 模块之前(或在直接执行它之前)在您的环境中设置 PYTORCH_NVML_BASED_CUDA_CHECK=1,以便指示 is_available() 尝试基于 NVML 的评估(nvmlDeviceGetCount_v2)。如果基于 NVML 的评估成功(即 NVML 发现/初始化未失败),则 is_available() 调用将不会影响后续的分支。

如果 NVML 发现/初始化失败,is_available() 将回退到标准的 CUDA Runtime API 评估,并且上述分支约束将适用。

请注意,上述基于 NVML 的 CUDA 可用性评估提供的保证比默认的 CUDA Runtime API 方法(需要 CUDA 初始化成功)要弱。在某些情况下,基于 NVML 的检查可能成功,而后续的 CUDA 初始化却失败。

现在我们有了 args.device,我们可以使用它在所需的设备上创建 Tensor。

x = torch.empty((8, 42), device=args.device)
net = Network().to(device=args.device)

这可以在多种情况下用于生成设备无关代码。下面是使用 dataloader 时的示例

cuda0 = torch.device('cuda:0')  # CUDA GPU 0
for i, x in enumerate(train_loader):
    x = x.to(cuda0)

在使用系统上的多个 GPU 时,您可以使用 CUDA_VISIBLE_DEVICES 环境变量标志来管理哪些 GPU 对 PyTorch 可用。如上所述,要手动控制张量在哪里创建,最佳实践是使用 torch.cuda.device 上下文管理器。

print("Outside device is 0")  # On device 0 (default in most scenarios)
with torch.cuda.device(1):
    print("Inside device is 1")  # On device 1
print("Outside device is still 0")  # On device 0

如果您有一个张量,并希望在同一个设备上创建相同类型的新张量,那么您可以使用 torch.Tensor.new_* 方法(参见 torch.Tensor)。虽然前面提到的 torch.* 工厂函数(创建操作)取决于当前的 GPU 上下文和您传入的属性参数,但 torch.Tensor.new_* 方法会保留张量的设备和其他属性。

这是在创建需要在前向传播期间在内部创建新张量的模块时推荐的做法。

cuda = torch.device('cuda')
x_cpu = torch.empty(2)
x_gpu = torch.empty(2, device=cuda)
x_cpu_long = torch.empty(2, dtype=torch.int64)

y_cpu = x_cpu.new_full([3, 2], fill_value=0.3)
print(y_cpu)

    tensor([[ 0.3000,  0.3000],
            [ 0.3000,  0.3000],
            [ 0.3000,  0.3000]])

y_gpu = x_gpu.new_full([3, 2], fill_value=-5)
print(y_gpu)

    tensor([[-5.0000, -5.0000],
            [-5.0000, -5.0000],
            [-5.0000, -5.0000]], device='cuda:0')

y_cpu_long = x_cpu_long.new_tensor([[1, 2, 3]])
print(y_cpu_long)

    tensor([[ 1,  2,  3]])

如果您想创建一个与另一个张量类型和大小相同,并用全一或全零填充的张量,提供了 ones_like()zeros_like() 作为方便的辅助函数(它们也会保留 Tensor 的 torch.devicetorch.dtype)。

x_cpu = torch.empty(2, 3)
x_gpu = torch.empty(2, 3)

y_cpu = torch.ones_like(x_cpu)
y_gpu = torch.zeros_like(x_gpu)

使用 pinned memory buffer

警告

这是一条高级提示。如果过度使用 pinned memory,在 RAM 不足时可能会导致严重问题,并且您应该知道 pinning 通常是一项昂贵的操作。

从 pinned(页面锁定)内存复制到 GPU 要快得多。CPU 张量和存储器暴露了一个 pin_memory() 方法,该方法返回对象的副本,并将数据放入 pinned 区域。

此外,一旦您 pin 了一个张量或存储器,您就可以使用异步 GPU 复制。只需向 to()cuda() 调用传递一个额外的 non_blocking=True 参数。这可以用于将数据传输与计算重叠。

您可以通过向 DataLoader 的构造函数传递 pin_memory=True,使其返回放置在 pinned memory 中的批次。

使用 `nn.parallel.DistributedDataParallel` 而非 `multiprocessing` 或 `nn.DataParallel`

大多数涉及批处理输入和多个 GPU 的用例应默认使用 DistributedDataParallel 来利用多个 GPU。

将 CUDA 模型与 multiprocessing 一起使用存在显著的注意事项;除非严格满足数据处理要求,否则您的程序很可能会出现不正确或未定义的行为。

建议使用 DistributedDataParallel 而不是 DataParallel 进行多 GPU 训练,即使只有一个节点。

DistributedDataParallelDataParallel 的区别在于:DistributedDataParallel 使用多进程,为每个 GPU 创建一个进程,而 DataParallel 使用多线程。通过使用多进程,每个 GPU 都有其专用进程,这避免了 Python 解释器 GIL 引起的性能开销。

如果您使用 DistributedDataParallel,您可以使用 torch.distributed.launch 工具来启动您的程序,参见 第三方后端

CUDA Graphs

CUDA Graph 是 CUDA 流及其依赖流执行的工作(主要是 kernels 及其参数)的记录。有关基本原理和底层 CUDA API 的详细信息,请参见 CUDA Graphs 入门 和 CUDA C 编程指南的 Graphs 部分

PyTorch 支持使用 stream capture 构建 CUDA graphs,这将 CUDA 流置于捕获模式(capture mode)。发送到捕获流的 CUDA 工作实际上不在 GPU 上运行。相反,工作会被记录在一个 graph 中。

捕获后,graph 可以被启动(launched)以根据需要多次运行 GPU 工作。每次重放(replay)都使用相同的参数运行相同的 kernels。对于指针参数,这意味着使用相同的内存地址。通过在每次重放之前用新数据(例如来自新批次)填充输入内存,您可以在新数据上重新运行相同的工作。

为何使用 CUDA Graphs?

重放 graph 以牺牲典型即时执行的动态灵活性为代价,换取大幅降低的 CPU 开销。graph 的参数和 kernels 是固定的,因此 graph 重放跳过了参数设置和 kernel 调度的所有层,包括 Python、C++ 和 CUDA driver 的开销。在底层,重放通过一次调用 cudaGraphLaunch 将整个 graph 的工作提交到 GPU。重放中的 kernels 在 GPU 上执行速度也稍快,但消除 CPU 开销是主要好处。

如果您的网络全部或部分是 graph-safe 的(通常这意味着静态形状和静态控制流,但请参阅其他约束),并且您怀疑其运行时至少在某种程度上受到 CPU 限制,则应尝试使用 CUDA graphs。

PyTorch API

警告

此 API 处于 beta 阶段,未来版本可能会发生变化。

PyTorch 通过原始的 torch.cuda.CUDAGraph 类和两个方便的包装器 torch.cuda.graphtorch.cuda.make_graphed_callables 来暴露 graphs。

torch.cuda.graph 是一个简单、通用的上下文管理器,它在其上下文中捕获 CUDA 工作。在捕获之前,通过运行几次 eager 迭代来热身要捕获的工作负载。热身必须在 side stream 上进行。由于 graph 在每次重放中都读取和写入相同的内存地址,因此在捕获期间必须保持对保存输入和输出数据的张量的长期引用。要在新输入数据上运行 graph,请将新数据复制到捕获的输入张量中,重放 graph,然后从捕获的输出张量中读取新输出。示例

g = torch.cuda.CUDAGraph()

# Placeholder input used for capture
static_input = torch.empty((5,), device="cuda")

# Warmup before capture
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        static_output = static_input * 2
torch.cuda.current_stream().wait_stream(s)

# Captures the graph
# To allow capture, automatically sets a side stream as the current stream in the context
with torch.cuda.graph(g):
    static_output = static_input * 2

# Fills the graph's input memory with new data to compute on
static_input.copy_(torch.full((5,), 3, device="cuda"))
g.replay()
# static_output holds the results
print(static_output)  # full of 3 * 2 = 6

# Fills the graph's input memory with more data to compute on
static_input.copy_(torch.full((5,), 4, device="cuda"))
g.replay()
print(static_output)  # full of 4 * 2 = 8

有关实际和高级模式,请参见全网络捕获与 torch.cuda.amp 的用法与多个 stream 的用法

make_graphed_callables 更复杂。make_graphed_callables 接受 Python 函数和 torch.nn.Module。对于每个传递的函数或 Module,它会创建单独的前向传播和后向传播工作的 graphs。参见部分网络捕获

约束

如果一组 ops 不违反以下任何约束,则它们是可捕获的(capturable)

约束适用于 torch.cuda.graph 上下文中的所有工作以及您传递给 torch.cuda.make_graphed_callables() 的任何可调用对象的前向和后向传播中的所有工作。

违反其中任何一项都很可能会导致运行时错误

违反其中任何一条都可能导致静默的数值错误或未定义行为。

  • 在一个进程内,一次只能进行一次捕获。

  • 捕获进行期间,此进程中(在任何线程上)不允许运行非捕获的 CUDA 工作。

  • CPU 工作不会被捕获。如果捕获的操作包含 CPU 工作,该工作在重放时将被省略。

  • 每次重放都读写相同的(虚拟)内存地址。

  • 禁止使用(基于 CPU 或 GPU 数据)的动态控制流。

  • 禁止使用动态形状。图假定捕获的操作序列中的每个张量在每次重放时都具有相同的大小和布局。

  • 在捕获中使用多个流是允许的,但存在限制

非限制

  • 捕获后,图可以在任何流上重放。

完整网络捕获

如果您的整个网络都可以捕获,您可以捕获并重放整个迭代。

N, D_in, H, D_out = 640, 4096, 2048, 1024
model = torch.nn.Sequential(torch.nn.Linear(D_in, H),
                            torch.nn.Dropout(p=0.2),
                            torch.nn.Linear(H, D_out),
                            torch.nn.Dropout(p=0.1)).cuda()
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Placeholders used for capture
static_input = torch.randn(N, D_in, device='cuda')
static_target = torch.randn(N, D_out, device='cuda')

# warmup
# Uses static_input and static_target here for convenience,
# but in a real setting, because the warmup includes optimizer.step()
# you must use a few batches of real data.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for i in range(3):
        optimizer.zero_grad(set_to_none=True)
        y_pred = model(static_input)
        loss = loss_fn(y_pred, static_target)
        loss.backward()
        optimizer.step()
torch.cuda.current_stream().wait_stream(s)

# capture
g = torch.cuda.CUDAGraph()
# Sets grads to None before capture, so backward() will create
# .grad attributes with allocations from the graph's private pool
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
    static_y_pred = model(static_input)
    static_loss = loss_fn(static_y_pred, static_target)
    static_loss.backward()
    optimizer.step()

real_inputs = [torch.rand_like(static_input) for _ in range(10)]
real_targets = [torch.rand_like(static_target) for _ in range(10)]

for data, target in zip(real_inputs, real_targets):
    # Fills the graph's input memory with new data to compute on
    static_input.copy_(data)
    static_target.copy_(target)
    # replay() includes forward, backward, and step.
    # You don't even need to call optimizer.zero_grad() between iterations
    # because the captured backward refills static .grad tensors in place.
    g.replay()
    # Params have been updated. static_y_pred, static_loss, and .grad
    # attributes hold values from computing on this iteration's data.

部分网络捕获

如果您的部分网络不适合捕获(例如,由于动态控制流、动态形状、CPU 同步或必要的 CPU 端逻辑),您可以立即运行不安全的部分,并使用 torch.cuda.make_graphed_callables() 仅对适合捕获的部分进行图化。

默认情况下,make_graphed_callables() 返回的可调用对象是 autograd 感知的,可以在训练循环中直接替代您传入的函数或 nn.Module

make_graphed_callables() 内部创建 CUDAGraph 对象,运行预热迭代,并根据需要维护静态输入和输出。因此(与使用 torch.cuda.graph 不同),您无需手动处理这些。

在下面的示例中,数据依赖的动态控制流意味着网络无法进行端到端捕获,但 make_graphed_callables() 无论如何都允许我们将适合图化的部分捕获并作为图运行。

N, D_in, H, D_out = 640, 4096, 2048, 1024

module1 = torch.nn.Linear(D_in, H).cuda()
module2 = torch.nn.Linear(H, D_out).cuda()
module3 = torch.nn.Linear(H, D_out).cuda()

loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(chain(module1.parameters(),
                                  module2.parameters(),
                                  module3.parameters()),
                            lr=0.1)

# Sample inputs used for capture
# requires_grad state of sample inputs must match
# requires_grad state of real inputs each callable will see.
x = torch.randn(N, D_in, device='cuda')
h = torch.randn(N, H, device='cuda', requires_grad=True)

module1 = torch.cuda.make_graphed_callables(module1, (x,))
module2 = torch.cuda.make_graphed_callables(module2, (h,))
module3 = torch.cuda.make_graphed_callables(module3, (h,))

real_inputs = [torch.rand_like(x) for _ in range(10)]
real_targets = [torch.randn(N, D_out, device="cuda") for _ in range(10)]

for data, target in zip(real_inputs, real_targets):
    optimizer.zero_grad(set_to_none=True)

    tmp = module1(data)  # forward ops run as a graph

    if tmp.sum().item() > 0:
        tmp = module2(tmp)  # forward ops run as a graph
    else:
        tmp = module3(tmp)  # forward ops run as a graph

    loss = loss_fn(tmp, target)
    # module2's or module3's (whichever was chosen) backward ops,
    # as well as module1's backward ops, run as graphs
    loss.backward()
    optimizer.step()

与 torch.cuda.amp 一起使用

对于典型的优化器,GradScaler.step 会使 CPU 与 GPU 同步,这在捕获期间是被禁止的。为避免错误,要么使用部分网络捕获,要么(如果前向、损失和后向是适合捕获的)捕获前向、损失和后向,但不捕获优化器步骤。

# warmup
# In a real setting, use a few batches of real data.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for i in range(3):
        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast():
            y_pred = model(static_input)
            loss = loss_fn(y_pred, static_target)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
torch.cuda.current_stream().wait_stream(s)

# capture
g = torch.cuda.CUDAGraph()
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
    with torch.cuda.amp.autocast():
        static_y_pred = model(static_input)
        static_loss = loss_fn(static_y_pred, static_target)
    scaler.scale(static_loss).backward()
    # don't capture scaler.step(optimizer) or scaler.update()

real_inputs = [torch.rand_like(static_input) for _ in range(10)]
real_targets = [torch.rand_like(static_target) for _ in range(10)]

for data, target in zip(real_inputs, real_targets):
    static_input.copy_(data)
    static_target.copy_(target)
    g.replay()
    # Runs scaler.step and scaler.update eagerly
    scaler.step(optimizer)
    scaler.update()

与多个流一起使用

捕获模式会自动传播到与捕获流同步的任何流。在捕获期间,您可以通过向不同流发出调用来暴露并行性,但总体的流依赖关系 DAG 必须在捕获开始后从初始捕获流分叉,并在捕获结束前重新汇合到初始流。

with torch.cuda.graph(g):
    # at context manager entrance, torch.cuda.current_stream()
    # is the initial capturing stream

    # INCORRECT (does not branch out from or rejoin initial stream)
    with torch.cuda.stream(s):
        cuda_work()

    # CORRECT:
    # branches out from initial stream
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        cuda_work()
    # rejoins initial stream before capture ends
    torch.cuda.current_stream().wait_stream(s)

注意

为了避免资深用户在 nsight systems 或 nvprof 中查看重放时感到困惑:与即时执行不同,图将捕获中的非平凡流 DAG 解释为提示,而不是命令。在重放期间,图可以将独立的操作重新组织到不同的流上或以不同的顺序排队(同时尊重您原始 DAG 的整体依赖关系)。

与 DistributedDataParallel 一起使用

NCCL < 2.9.6

NCCL 2.9.6 之前的版本不允许捕获集合操作 (collectives)。您必须使用部分网络捕获,这将 allreduce 操作推迟到后向传播图化部分之外进行。

在使用 DDP 包装网络*之前*,对可图化的网络部分调用 make_graphed_callables()

NCCL >= 2.9.6

NCCL 2.9.6 或更高版本允许在图中使用集合操作 (collectives)。捕获整个后向传播的方法是可行的选择,但需要三个设置步骤。

  1. 禁用 DDP 内部的异步错误处理。

    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
    torch.distributed.init_process_group(...)
    
  2. 在进行完整的后向传播捕获之前,必须在辅助流上下文中构建 DDP。

    with torch.cuda.stream(s):
        model = DistributedDataParallel(model)
    
  3. 在捕获之前,您的预热必须运行至少 11 次启用 DDP 的即时迭代。

图内存管理

捕获的图在每次重放时都操作相同的虚拟地址。如果 PyTorch 释放了内存,随后的重放可能会导致非法内存访问。如果 PyTorch 将内存重新分配给新的张量,重放可能会损坏这些张量看到的值。因此,图使用的虚拟地址必须在重放期间为图保留。PyTorch 缓存分配器通过检测何时正在进行捕获,并从一个图专用的内存池中满足捕获的分配来实现这一点。该私有池一直存在,直到其 CUDAGraph 对象以及在捕获期间创建的所有张量超出作用域为止。

私有池会自动维护。默认情况下,分配器会为每次捕获创建一个独立的私有池。如果您捕获多个图,这种保守的方法确保图重放不会相互损坏值,但有时会不必要地浪费内存。

跨捕获共享内存

为了节省私有池中存储的内存,torch.cuda.graphtorch.cuda.make_graphed_callables() 可选地允许不同的捕获共享同一个私有池。如果您知道一组图将始终按捕获时的顺序重放,并且不会并发重放,那么它们共享一个私有池是安全的。

torch.cuda.graphpool 参数是使用特定私有池的提示,可用于如图所示在不同图之间共享内存。

g1 = torch.cuda.CUDAGraph()
g2 = torch.cuda.CUDAGraph()

# (create static inputs for g1 and g2, run warmups of their workloads...)

# Captures g1
with torch.cuda.graph(g1):
    static_out_1 = g1_workload(static_in_1)

# Captures g2, hinting that g2 may share a memory pool with g1
with torch.cuda.graph(g2, pool=g1.pool()):
    static_out_2 = g2_workload(static_in_2)

static_in_1.copy_(real_data_1)
static_in_2.copy_(real_data_2)
g1.replay()
g2.replay()

使用 torch.cuda.make_graphed_callables() 时,如果您想对多个可调用对象进行图化,并且您知道它们将始终按相同顺序运行(且永不并发),请将它们按照在实际工作负载中运行的相同顺序作为元组传入,make_graphed_callables() 将使用共享私有池捕获它们的图。

如果在实际工作负载中,您的可调用对象将以偶尔变化的顺序运行,或者它们将并发运行,则不允许将它们作为元组传递给单个 make_graphed_callables() 的调用。相反,您必须为每个可调用对象单独调用 make_graphed_callables()

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并解答您的问题

查看资源