• 教程 >
  • PyTorch 中 non_blockingpin_memory() 的良好使用指南
快捷方式

PyTorch 中 non_blockingpin_memory() 的良好使用指南

作者: Vincent Moens

简介

在许多 PyTorch 应用程序中,将数据从 CPU 传输到 GPU 是基本操作。用户必须了解用于在设备之间移动数据的最有效工具和选项。本教程考察了 PyTorch 中用于设备到设备数据传输的两种关键方法:pin_memory()to() 以及 non_blocking=True 选项。

您将学到什么

通过异步传输和内存固定,可以优化张量从 CPU 到 GPU 的传输。但是,有一些重要的注意事项

  • 使用 tensor.pin_memory().to(device, non_blocking=True) 的速度可能比简单的 tensor.to(device) 慢两倍。

  • 通常,tensor.to(device, non_blocking=True) 是提高传输速度的有效选择。

  • 虽然 cpu_tensor.to("cuda", non_blocking=True).mean() 正确执行,但尝试 cuda_tensor.to("cpu", non_blocking=True).mean() 将导致错误的输出。

前言

本教程中报告的性能取决于用于构建教程的系统。虽然结论适用于不同的系统,但具体观察结果可能会略有不同,具体取决于可用的硬件,尤其是在较旧的硬件上。本教程的主要目的是提供一个理论框架来理解 CPU 到 GPU 的数据传输。但是,任何设计决策都应针对具体情况进行调整,并以基准测试的吞吐量测量以及手头任务的具体要求为指导。

import torch

assert torch.cuda.is_available(), "A cuda device is required to run this tutorial"

本教程需要安装 tensordict。如果您在环境中还没有 tensordict,请在单独的单元格中运行以下命令来安装它

# Install tensordict with the following command
!pip3 install tensordict

我们首先概述围绕这些概念的理论,然后转向这些功能的具体测试示例。

背景

内存管理基础

当在 PyTorch 中创建一个 CPU 张量时,该张量的内容需要放置在内存中。我们这里所说的内存是一个相当复杂的概念,值得仔细研究。我们区分由内存管理单元处理的两种类型的内存:RAM(为简单起见)和磁盘上的交换空间(可能是硬盘驱动器,也可能不是)。磁盘和 RAM(物理内存)中可用的空间共同构成了虚拟内存,它是可用总资源的抽象。简而言之,虚拟内存使得可用空间大于 RAM 本身所能找到的空间,并创造了主内存大于实际大小的错觉。

在正常情况下,普通的 CPU 张量是可分页的,这意味着它被划分为称为页面的块,这些页面可以位于虚拟内存中的任何地方(无论是在 RAM 中还是在磁盘上)。如前所述,这具有使内存看起来大于主内存实际大小的优点。

通常,当程序访问不在 RAM 中的页面时,就会发生“页面错误”,然后操作系统 (OS) 将该页面调回 RAM(“调入”或“调页”)。反过来,操作系统可能必须调出(或“调页”)另一个页面来为新页面腾出空间。

与可分页内存相比,固定(或页面锁定或不可分页)内存是一种不能交换到磁盘的内存类型。它允许更快的和更可预测的访问时间,但缺点是它比可分页内存(即主内存)更受限制。

CUDA 和(非)可分页内存

为了理解 CUDA 如何将张量从 CPU 复制到 CUDA,让我们考虑以上两种情况

  • 如果内存是页面锁定的,设备可以直接访问主内存中的内存。内存地址是明确定义的,需要读取这些数据的函数可以显着加速。

  • 如果内存是可分页的,则所有页面都必须被调入主内存,然后再被发送到 GPU。此操作可能需要时间,并且比在页面锁定张量上执行时不可预测。

更准确地说,当 CUDA 将可分页数据从 CPU 发送到 GPU 时,它必须首先创建该数据的页面锁定副本,然后再进行传输。

使用 non_blocking=True(CUDA cudaMemcpyAsync)的异步与同步操作

当从主机(例如 CPU)到设备(例如 GPU)执行复制操作时,CUDA 工具包提供以相对于主机同步或异步方式执行这些操作的方式。

在实践中,当调用 to() 时,PyTorch 始终调用 cudaMemcpyAsync。如果 non_blocking=False(默认值),则在每次 cudaMemcpyAsync 之后都会调用 cudaStreamSynchronize,使对 to() 的调用在主线程中阻塞。如果 non_blocking=True,则不会触发任何同步操作,并且主线程不会在主机上阻塞。因此,从主机的角度来看,可以同时将多个张量发送到设备,因为线程不需要等待一个传输完成才能启动另一个传输。

注意

通常,传输在设备端是阻塞的(即使在主机端不是):设备上的复制操作不能在执行另一个操作时发生。但是,在某些高级场景中,可以同时在 GPU 端执行复制操作和内核执行。正如以下示例所示,必须满足三个要求才能启用此功能

  1. 设备必须至少有一个空闲的 DMA(直接内存访问)引擎。现代 GPU 架构(如 Volterra、Tesla 或 H100 设备)具有多个 DMA 引擎。

  2. 传输必须在单独的非默认 cuda 流上完成。在 PyTorch 中,cuda 流可以使用 Stream 处理。

  3. 源数据必须位于固定内存中。

我们通过对以下脚本运行概要分析来演示这一点。

import contextlib

from torch.cuda import Stream


s = Stream()

torch.manual_seed(42)
t1_cpu_pinned = torch.randn(1024**2 * 5, pin_memory=True)
t2_cpu_paged = torch.randn(1024**2 * 5, pin_memory=False)
t3_cuda = torch.randn(1024**2 * 5, device="cuda:0")

assert torch.cuda.is_available()
device = torch.device("cuda", torch.cuda.current_device())


# The function we want to profile
def inner(pinned: bool, streamed: bool):
    with torch.cuda.stream(s) if streamed else contextlib.nullcontext():
        if pinned:
            t1_cuda = t1_cpu_pinned.to(device, non_blocking=True)
        else:
            t2_cuda = t2_cpu_paged.to(device, non_blocking=True)
        t_star_cuda_h2d_event = s.record_event()
    # This operation can be executed during the CPU to GPU copy if and only if the tensor is pinned and the copy is
    #  done in the other stream
    t3_cuda_mul = t3_cuda * t3_cuda * t3_cuda
    t3_cuda_h2d_event = torch.cuda.current_stream().record_event()
    t_star_cuda_h2d_event.synchronize()
    t3_cuda_h2d_event.synchronize()


# Our profiler: profiles the `inner` function and stores the results in a .json file
def benchmark_with_profiler(
    pinned,
    streamed,
) -> None:
    torch._C._profiler._set_cuda_sync_enabled_val(True)
    wait, warmup, active = 1, 1, 2
    num_steps = wait + warmup + active
    rank = 0
    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        schedule=torch.profiler.schedule(
            wait=wait, warmup=warmup, active=active, repeat=1, skip_first=1
        ),
    ) as prof:
        for step_idx in range(1, num_steps + 1):
            inner(streamed=streamed, pinned=pinned)
            if rank is None or rank == 0:
                prof.step()
    prof.export_chrome_trace(f"trace_streamed{int(streamed)}_pinned{int(pinned)}.json")

在 chrome (chrome://tracing) 中加载这些概要分析跟踪显示以下结果:首先,让我们看看如果在将可分页张量发送到主流中的 GPU 之后执行 t3_cuda 上的算术运算会发生什么

benchmark_with_profiler(streamed=False, pinned=False)

使用固定张量不会改变跟踪太多,两个操作仍然是连续执行的

benchmark_with_profiler(streamed=False, pinned=True)

将可分页张量发送到单独流上的 GPU 也是一个阻塞操作

benchmark_with_profiler(streamed=True, pinned=False)

只有将固定张量复制到单独流上的 GPU 才能与在主流上执行的另一个 cuda 内核重叠

benchmark_with_profiler(streamed=True, pinned=True)

PyTorch 视角

pin_memory()

PyTorch 提供了通过 pin_memory() 方法和构造函数参数创建并将张量发送到页面锁定内存的可能性。在初始化了 CUDA 的机器上,CPU 张量可以通过 pin_memory() 方法转换为固定内存。重要的是,pin_memory 在主机的主线程上是阻塞的:它将等待张量被复制到页面锁定内存中,然后再执行下一个操作。可以使用 zeros()ones() 等构造函数直接在固定内存中创建新的张量。

让我们检查一下固定内存并将张量发送到 CUDA 的速度

import torch
import gc
from torch.utils.benchmark import Timer
import matplotlib.pyplot as plt


def timer(cmd):
    median = (
        Timer(cmd, globals=globals())
        .adaptive_autorange(min_run_time=1.0, max_run_time=20.0)
        .median
        * 1000
    )
    print(f"{cmd}: {median: 4.4f} ms")
    return median


# A tensor in pageable memory
pageable_tensor = torch.randn(1_000_000)

# A tensor in page-locked (pinned) memory
pinned_tensor = torch.randn(1_000_000, pin_memory=True)

# Runtimes:
pageable_to_device = timer("pageable_tensor.to('cuda:0')")
pinned_to_device = timer("pinned_tensor.to('cuda:0')")
pin_mem = timer("pageable_tensor.pin_memory()")
pin_mem_to_device = timer("pageable_tensor.pin_memory().to('cuda:0')")

# Ratios:
r1 = pinned_to_device / pageable_to_device
r2 = pin_mem_to_device / pageable_to_device

# Create a figure with the results
fig, ax = plt.subplots()

xlabels = [0, 1, 2]
bar_labels = [
    "pageable_tensor.to(device) (1x)",
    f"pinned_tensor.to(device) ({r1:4.2f}x)",
    f"pageable_tensor.pin_memory().to(device) ({r2:4.2f}x)"
    f"\npin_memory()={100*pin_mem/pin_mem_to_device:.2f}% of runtime.",
]
values = [pageable_to_device, pinned_to_device, pin_mem_to_device]
colors = ["tab:blue", "tab:red", "tab:orange"]
ax.bar(xlabels, values, label=bar_labels, color=colors)

ax.set_ylabel("Runtime (ms)")
ax.set_title("Device casting runtime (pin-memory)")
ax.set_xticks([])
ax.legend()

plt.show()

# Clear tensors
del pageable_tensor, pinned_tensor
_ = gc.collect()
Device casting runtime (pin-memory)
pageable_tensor.to('cuda:0'):  0.4660 ms
pinned_tensor.to('cuda:0'):  0.3703 ms
pageable_tensor.pin_memory():  0.3580 ms
pageable_tensor.pin_memory().to('cuda:0'):  0.7199 ms

我们可以观察到,将固定内存张量转换为 GPU 确实比可分页张量快得多,因为在幕后,可分页张量必须在被发送到 GPU 之前复制到固定内存中。

但是,与一种普遍的看法相反,在将可分页张量转换为 GPU 之前在它上面调用 pin_memory() 不会带来任何显著的加速,相反,此调用通常比仅执行传输要慢。这是有道理的,因为我们实际上是在要求 Python 执行 CUDA ohnehin 在将数据从主机复制到设备之前执行的操作。

注意

PyTorch 实现的 pin_memory 依赖于通过 cudaHostAlloc 在固定内存中创建一个全新的存储,在极少数情况下可能比 cudaMemcpy 所做的分块数据转换要快。同样,观察结果可能取决于可用的硬件、正在发送的张量的大小或可用的 RAM 量。

non_blocking=True

如前所述,许多 PyTorch 操作可以通过 non_blocking 参数以相对于主机异步的方式执行。

在这里,为了准确地说明使用 non_blocking 的好处,我们将设计一个稍微复杂的实验,因为我们要评估使用和不使用 non_blocking 将多个张量发送到 GPU 的速度。

# A simple loop that copies all tensors to cuda
def copy_to_device(*tensors):
    result = []
    for tensor in tensors:
        result.append(tensor.to("cuda:0"))
    return result


# A loop that copies all tensors to cuda asynchronously
def copy_to_device_nonblocking(*tensors):
    result = []
    for tensor in tensors:
        result.append(tensor.to("cuda:0", non_blocking=True))
    # We need to synchronize
    torch.cuda.synchronize()
    return result


# Create a list of tensors
tensors = [torch.randn(1000) for _ in range(1000)]
to_device = timer("copy_to_device(*tensors)")
to_device_nonblocking = timer("copy_to_device_nonblocking(*tensors)")

# Ratio
r1 = to_device_nonblocking / to_device

# Plot the results
fig, ax = plt.subplots()

xlabels = [0, 1]
bar_labels = [f"to(device) (1x)", f"to(device, non_blocking=True) ({r1:4.2f}x)"]
colors = ["tab:blue", "tab:red"]
values = [to_device, to_device_nonblocking]

ax.bar(xlabels, values, label=bar_labels, color=colors)

ax.set_ylabel("Runtime (ms)")
ax.set_title("Device casting runtime (non-blocking)")
ax.set_xticks([])
ax.legend()

plt.show()
Device casting runtime (non-blocking)
copy_to_device(*tensors):  25.5406 ms
copy_to_device_nonblocking(*tensors):  18.7735 ms

为了更好地了解这里发生了什么,让我们对这两个函数进行概要分析

from torch.profiler import profile, ProfilerActivity


def profile_mem(cmd):
    with profile(activities=[ProfilerActivity.CPU]) as prof:
        exec(cmd)
    print(cmd)
    print(prof.key_averages().table(row_limit=10))

让我们看看使用常规的 to(device) 的调用堆栈

print("Call to `to(device)`", profile_mem("copy_to_device(*tensors)"))
copy_to_device(*tensors)
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                 aten::to         4.16%       1.340ms       100.00%      32.237ms      32.237us          1000
           aten::_to_copy        12.86%       4.146ms        95.84%      30.897ms      30.897us          1000
      aten::empty_strided        24.00%       7.738ms        24.00%       7.738ms       7.738us          1000
              aten::copy_        19.01%       6.130ms        58.98%      19.013ms      19.013us          1000
          cudaMemcpyAsync        18.59%       5.992ms        18.59%       5.992ms       5.992us          1000
    cudaStreamSynchronize        21.38%       6.892ms        21.38%       6.892ms       6.892us          1000
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 32.237ms

Call to `to(device)` None

现在是 non_blocking 版本

print(
    "Call to `to(device, non_blocking=True)`",
    profile_mem("copy_to_device_nonblocking(*tensors)"),
)
copy_to_device_nonblocking(*tensors)
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                 aten::to         4.83%       1.137ms        99.87%      23.519ms      23.519us          1000
           aten::_to_copy        16.65%       3.920ms        95.04%      22.382ms      22.382us          1000
      aten::empty_strided        31.88%       7.508ms        31.88%       7.508ms       7.508us          1000
              aten::copy_        23.04%       5.425ms        46.52%      10.954ms      10.954us          1000
          cudaMemcpyAsync        23.48%       5.528ms        23.48%       5.528ms       5.528us          1000
    cudaDeviceSynchronize         0.13%      29.825us         0.13%      29.825us      29.825us             1
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 23.549ms

Call to `to(device, non_blocking=True)` None

毫无疑问,使用 non_blocking=True 的结果更好,因为所有传输都在主机端同时启动,并且只执行一次同步。

好处将根据张量的数量和大小以及使用的硬件而有所不同。

注意

有趣的是,阻塞 to("cuda") 实际上执行与 non_blocking=True 相同的异步设备转换操作 (cudaMemcpyAsync),并在每次复制后都有一个同步点。

协同作用

现在我们已经证明将已位于固定内存中的张量传输到 GPU 比从可分页内存传输更快,并且我们知道异步传输也比同步传输更快,我们可以对这些方法的组合进行基准测试。首先,让我们编写几个新函数,这些函数将在每个张量上调用 pin_memoryto(device)

def pin_copy_to_device(*tensors):
    result = []
    for tensor in tensors:
        result.append(tensor.pin_memory().to("cuda:0"))
    return result


def pin_copy_to_device_nonblocking(*tensors):
    result = []
    for tensor in tensors:
        result.append(tensor.pin_memory().to("cuda:0", non_blocking=True))
    # We need to synchronize
    torch.cuda.synchronize()
    return result

使用 pin_memory() 的优势对于相当大的大型张量批次更为明显。

tensors = [torch.randn(1_000_000) for _ in range(1000)]
page_copy = timer("copy_to_device(*tensors)")
page_copy_nb = timer("copy_to_device_nonblocking(*tensors)")

tensors_pinned = [torch.randn(1_000_000, pin_memory=True) for _ in range(1000)]
pinned_copy = timer("copy_to_device(*tensors_pinned)")
pinned_copy_nb = timer("copy_to_device_nonblocking(*tensors_pinned)")

pin_and_copy = timer("pin_copy_to_device(*tensors)")
pin_and_copy_nb = timer("pin_copy_to_device_nonblocking(*tensors)")

# Plot
strategies = ("pageable copy", "pinned copy", "pin and copy")
blocking = {
    "blocking": [page_copy, pinned_copy, pin_and_copy],
    "non-blocking": [page_copy_nb, pinned_copy_nb, pin_and_copy_nb],
}

x = torch.arange(3)
width = 0.25
multiplier = 0


fig, ax = plt.subplots(layout="constrained")

for attribute, runtimes in blocking.items():
    offset = width * multiplier
    rects = ax.bar(x + offset, runtimes, width, label=attribute)
    ax.bar_label(rects, padding=3, fmt="%.2f")
    multiplier += 1

# Add some text for labels, title and custom x-axis tick labels, etc.
ax.set_ylabel("Runtime (ms)")
ax.set_title("Runtime (pin-mem and non-blocking)")
ax.set_xticks([0, 1, 2])
ax.set_xticklabels(strategies)
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
ax.legend(loc="upper left", ncols=3)

plt.show()

del tensors, tensors_pinned
_ = gc.collect()
Runtime (pin-mem and non-blocking)
copy_to_device(*tensors):  607.5763 ms
copy_to_device_nonblocking(*tensors):  533.4217 ms
copy_to_device(*tensors_pinned):  371.5275 ms
copy_to_device_nonblocking(*tensors_pinned):  345.8600 ms
pin_copy_to_device(*tensors):  970.9513 ms
pin_copy_to_device_nonblocking(*tensors):  653.2347 ms

其他复制方向 (GPU -> CPU, CPU -> MPS)

到目前为止,我们一直假设从 CPU 到 GPU 的异步复制是安全的。这通常是正确的,因为 CUDA 会自动处理同步以确保访问的数据在读取时有效。但是,此保证不适用于相反方向的传输,即从 GPU 到 CPU。在没有显式同步的情况下,这些传输无法保证复制在数据访问时将完成。因此,主机上的数据可能不完整或不正确,实际上使其成为垃圾。

tensor = (
    torch.arange(1, 1_000_000, dtype=torch.double, device="cuda")
    .expand(100, 999999)
    .clone()
)
torch.testing.assert_close(
    tensor.mean(), torch.tensor(500_000, dtype=torch.double, device="cuda")
), tensor.mean()
try:
    i = -1
    for i in range(100):
        cpu_tensor = tensor.to("cpu", non_blocking=True)
        torch.testing.assert_close(
            cpu_tensor.mean(), torch.tensor(500_000, dtype=torch.double)
        )
    print("No test failed with non_blocking")
except AssertionError:
    print(f"{i}th test failed with non_blocking. Skipping remaining tests")
try:
    i = -1
    for i in range(100):
        cpu_tensor = tensor.to("cpu", non_blocking=True)
        torch.cuda.synchronize()
        torch.testing.assert_close(
            cpu_tensor.mean(), torch.tensor(500_000, dtype=torch.double)
        )
    print("No test failed with synchronize")
except AssertionError:
    print(f"One test failed with synchronize: {i}th assertion!")
0th test failed with non_blocking. Skipping remaining tests
No test failed with synchronize

相同的考虑适用于从 CPU 到非 CUDA 设备(例如 MPS)的复制。通常,仅当目标是支持 CUDA 的设备时,异步复制到设备才在没有显式同步的情况下安全。

总之,使用 non_blocking=True 从 CPU 到 GPU 复制数据是安全的,但对于任何其他方向,non_blocking=True 仍然可以使用,但用户必须确保在访问数据之前执行设备同步。

实用建议

现在我们可以根据我们的观察结果总结一些早期建议。

通常,non_blocking=True 将提供良好的吞吐量,无论原始张量是否位于固定内存中。如果张量已经位于固定内存中,则可以加速传输,但从 Python 主线程手动将其发送到固定内存是一个阻塞操作,因此将消除使用 non_blocking=True 的大部分好处(因为 CUDA 也会执行 pin_memory 传输)。

现在有人可能会问 pin_memory() 方法有什么用。在下一节中,我们将进一步探讨如何使用它来进一步加速数据传输。

其他注意事项

众所周知,PyTorch 提供了一个 DataLoader 类,其构造函数接受一个 pin_memory 参数。考虑到我们之前关于 pin_memory 的讨论,您可能想知道如果内存固定本质上是阻塞的,DataLoader 如何设法加速数据传输。

关键在于 DataLoader 使用单独的线程来处理从可分页内存到固定内存的数据传输,从而防止主线程中的任何阻塞。

为了说明这一点,我们将使用同名库中的 TensorDict 原语。当调用 to() 时,默认行为是异步地将张量发送到设备,然后之后调用一次 torch.device.synchronize()

此外,TensorDict.to() 包含一个 non_blocking_pin 选项,该选项会启动多个线程来执行 pin_memory(),然后再继续执行 to(device)。这种方法可以进一步加速数据传输,如以下示例所示。

from tensordict import TensorDict
import torch
from torch.utils.benchmark import Timer
import matplotlib.pyplot as plt

# Create the dataset
td = TensorDict({str(i): torch.randn(1_000_000) for i in range(1000)})

# Runtimes
copy_blocking = timer("td.to('cuda:0', non_blocking=False)")
copy_non_blocking = timer("td.to('cuda:0')")
copy_pin_nb = timer("td.to('cuda:0', non_blocking_pin=True, num_threads=0)")
copy_pin_multithread_nb = timer("td.to('cuda:0', non_blocking_pin=True, num_threads=4)")

# Rations
r1 = copy_non_blocking / copy_blocking
r2 = copy_pin_nb / copy_blocking
r3 = copy_pin_multithread_nb / copy_blocking

# Figure
fig, ax = plt.subplots()

xlabels = [0, 1, 2, 3]
bar_labels = [
    "Blocking copy (1x)",
    f"Non-blocking copy ({r1:4.2f}x)",
    f"Blocking pin, non-blocking copy ({r2:4.2f}x)",
    f"Non-blocking pin, non-blocking copy ({r3:4.2f}x)",
]
values = [copy_blocking, copy_non_blocking, copy_pin_nb, copy_pin_multithread_nb]
colors = ["tab:blue", "tab:red", "tab:orange", "tab:green"]

ax.bar(xlabels, values, label=bar_labels, color=colors)

ax.set_ylabel("Runtime (ms)")
ax.set_title("Device casting runtime")
ax.set_xticks([])
ax.legend()

plt.show()
Device casting runtime
td.to('cuda:0', non_blocking=False):  616.1530 ms
td.to('cuda:0'):  540.1283 ms
td.to('cuda:0', non_blocking_pin=True, num_threads=0):  662.1348 ms
td.to('cuda:0', non_blocking_pin=True, num_threads=4):  363.3530 ms

在这个示例中,我们正在将许多大型张量从 CPU 传输到 GPU。这种情况非常适合利用多线程 pin_memory(),这可以显着提高性能。但是,如果张量很小,与多线程相关的开销可能超过其好处。同样,如果只有几个张量,则在单独线程上固定张量的优势就会变得有限。

另外,虽然在固定内存中创建永久缓冲区以在将张量传输到 GPU 之前从可分页内存中中转张量似乎很有优势,但这种策略并不一定能加快计算速度。复制数据到固定内存的固有瓶颈仍然是一个限制因素。

此外,将位于磁盘上的数据(无论是在共享内存中还是在文件中)传输到 GPU 通常需要将数据复制到固定内存(位于 RAM 中)的中间步骤。在这种情况下,对大型数据传输使用非阻塞可能会显着增加 RAM 使用量,从而可能导致不利影响。

在实践中,没有万能的解决方案。使用多线程 pin_memory 结合 non_blocking 传输的有效性取决于各种因素,包括特定系统、操作系统、硬件以及正在执行的任务的性质。以下列出了在尝试加快 CPU 和 GPU 之间的数据传输速度或比较跨场景的吞吐量时要检查的因素列表。

  • 可用核心数量

    有多少个 CPU 核心可用?系统是否与可能竞争资源的其他用户或进程共享?

  • 核心利用率

    其他进程是否大量使用 CPU 核心?应用程序是否与数据传输同时执行其他 CPU 密集型任务?

  • 内存使用情况

    目前正在使用多少可分页内存和锁定页内存?是否有足够的可用内存来分配额外的固定内存,而不会影响系统性能?请记住,没有免费的午餐,例如,pin_memory 会消耗 RAM,并可能影响其他任务。

  • CUDA 设备功能

    GPU 是否支持用于并发数据传输的多个 DMA 引擎?正在使用的 CUDA 设备的具体功能和限制是什么?

  • 要发送的张量数量

    在典型的操作中传输了多少个张量?

  • 要发送的张量的尺寸

    正在传输的张量的尺寸是多少?少量大型张量或大量小型张量可能不会从相同的传输程序中受益。

  • 系统架构

    系统架构如何影响数据传输速度(例如,总线速度、网络延迟)?

此外,在固定内存中分配大量张量或大型张量可能会占用很大一部分 RAM。这会减少其他关键操作(例如分页)的可用内存,从而可能对算法的整体性能产生负面影响。

结论

在本教程中,我们探讨了影响将张量从主机发送到设备时的传输速度和内存管理的几个关键因素。我们了解到使用 non_blocking=True 通常会加速数据传输,并且 pin_memory() 如果正确实现,也可以提高性能。但是,这些技术需要仔细设计和校准才能有效。

请记住,对代码进行性能分析并密切关注内存使用情况对于优化资源使用和获得最佳性能至关重要。

其他资源

如果您在使用 CUDA 设备时遇到内存复制问题,或者想了解有关本教程中讨论的内容的更多信息,请查看以下参考资料。

脚本的总运行时间:( 1 分 16.874 秒)

由 Sphinx-Gallery 生成的画廊

文档

访问 PyTorch 的综合开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源