快捷方式

自动混合精度

创建于: 2020 年 9 月 15 日 | 最后更新于: 2025 年 1 月 30 日 | 最后验证于: 2024 年 11 月 5 日

作者: Michael Carilli

torch.cuda.amp 提供了方便的混合精度方法,其中一些操作使用 torch.float32 (float) 数据类型,而其他操作使用 torch.float16 (half)。一些算子(如线性层和卷积)在 float16bfloat16 中要快得多。其他算子(如规约操作)通常需要 float32 的动态范围。混合精度试图将每个算子与其适当的数据类型匹配,这可以减少网络的运行时和内存占用。

通常,“自动混合精度训练”会同时使用 torch.autocasttorch.cuda.amp.GradScaler

本代码示例测量了一个简单网络在默认精度下的性能,然后逐步讲解如何添加 autocastGradScaler,以便在混合精度下以改进的性能运行同一网络。

你可以将此代码示例下载并作为独立的 Python 脚本运行。唯一的要求是 PyTorch 1.6 或更高版本以及一个支持 CUDA 的 GPU。

混合精度主要受益于启用 Tensor Core 的架构(Volta, Turing, Ampere)。此代码示例在这些架构上应显示显著(2-3 倍)的加速。在更早期的架构(Kepler, Maxwell, Pascal)上,你可能会观察到适度的加速。运行 nvidia-smi 以显示你的 GPU 架构。

import torch, time, gc

# Timing utilities
start_time = None

def start_timer():
    global start_time
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.synchronize()
    start_time = time.time()

def end_timer_and_print(local_msg):
    torch.cuda.synchronize()
    end_time = time.time()
    print("\n" + local_msg)
    print("Total execution time = {:.3f} sec".format(end_time - start_time))
    print("Max memory used by tensors = {} bytes".format(torch.cuda.max_memory_allocated()))

一个简单网络

以下由线性层和 ReLU 组成的序列网络在混合精度下应显示出加速效果。

def make_model(in_size, out_size, num_layers):
    layers = []
    for _ in range(num_layers - 1):
        layers.append(torch.nn.Linear(in_size, in_size))
        layers.append(torch.nn.ReLU())
    layers.append(torch.nn.Linear(in_size, out_size))
    return torch.nn.Sequential(*tuple(layers)).cuda()

batch_sizein_sizeout_sizenum_layers 的选择应足够大,以使 GPU 饱和。通常,当 GPU 饱和时,混合精度能提供最大的加速。小型网络可能是 CPU 瓶颈,在这种情况下,混合精度不会改善性能。尺寸也选择为线性层相关维度是 8 的倍数,以便在支持 Tensor Core 的 GPU 上使用 Tensor Core(参见下面的故障排除)。

练习:改变相关尺寸,观察混合精度加速如何变化。

batch_size = 512 # Try, for example, 128, 256, 513.
in_size = 4096
out_size = 4096
num_layers = 3
num_batches = 50
epochs = 3

device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.set_default_device(device)

# Creates data in default precision.
# The same data is used for both default and mixed precision trials below.
# You don't need to manually change inputs' ``dtype`` when enabling mixed precision.
data = [torch.randn(batch_size, in_size) for _ in range(num_batches)]
targets = [torch.randn(batch_size, out_size) for _ in range(num_batches)]

loss_fn = torch.nn.MSELoss().cuda()

默认精度

如果没有 torch.cuda.amp,以下简单网络将所有操作以默认精度 (torch.float32) 执行。

net = make_model(in_size, out_size, num_layers)
opt = torch.optim.SGD(net.parameters(), lr=0.001)

start_timer()
for epoch in range(epochs):
    for input, target in zip(data, targets):
        output = net(input)
        loss = loss_fn(output, target)
        loss.backward()
        opt.step()
        opt.zero_grad() # set_to_none=True here can modestly improve performance
end_timer_and_print("Default precision:")

添加 torch.autocast

torch.autocast 的实例充当上下文管理器,允许脚本的某些区域以混合精度运行。

在这些区域中,CUDA 算子以 autocast 选择的 dtype 运行,以提高性能同时保持精度。有关 autocast 为每个算子选择何种精度以及在何种情况下的详细信息,请参阅Autocast 算子参考

for epoch in range(0): # 0 epochs, this section is for illustration only
    for input, target in zip(data, targets):
        # Runs the forward pass under ``autocast``.
        with torch.autocast(device_type=device, dtype=torch.float16):
            output = net(input)
            # output is float16 because linear layers ``autocast`` to float16.
            assert output.dtype is torch.float16

            loss = loss_fn(output, target)
            # loss is float32 because ``mse_loss`` layers ``autocast`` to float32.
            assert loss.dtype is torch.float32

        # Exits ``autocast`` before backward().
        # Backward passes under ``autocast`` are not recommended.
        # Backward ops run in the same ``dtype`` ``autocast`` chose for corresponding forward ops.
        loss.backward()
        opt.step()
        opt.zero_grad() # set_to_none=True here can modestly improve performance

添加 GradScaler

梯度缩放有助于防止在使用混合精度训练时,梯度因幅度过小而变为零(“下溢”)。

torch.cuda.amp.GradScaler 方便地执行梯度缩放步骤。

# Constructs a ``scaler`` once, at the beginning of the convergence run, using default arguments.
# If your network fails to converge with default ``GradScaler`` arguments, please file an issue.
# The same ``GradScaler`` instance should be used for the entire convergence run.
# If you perform multiple convergence runs in the same script, each run should use
# a dedicated fresh ``GradScaler`` instance. ``GradScaler`` instances are lightweight.
scaler = torch.amp.GradScaler("cuda")

for epoch in range(0): # 0 epochs, this section is for illustration only
    for input, target in zip(data, targets):
        with torch.autocast(device_type=device, dtype=torch.float16):
            output = net(input)
            loss = loss_fn(output, target)

        # Scales loss. Calls ``backward()`` on scaled loss to create scaled gradients.
        scaler.scale(loss).backward()

        # ``scaler.step()`` first unscales the gradients of the optimizer's assigned parameters.
        # If these gradients do not contain ``inf``s or ``NaN``s, optimizer.step() is then called,
        # otherwise, optimizer.step() is skipped.
        scaler.step(opt)

        # Updates the scale for next iteration.
        scaler.update()

        opt.zero_grad() # set_to_none=True here can modestly improve performance

全部整合:“自动混合精度”

(以下还展示了 enabled,它是 autocastGradScaler 的一个可选便利参数。如果设置为 False,autocastGradScaler 的调用将变为无操作。这允许在默认精度和混合精度之间切换,而无需使用 if/else 语句。)

use_amp = True

net = make_model(in_size, out_size, num_layers)
opt = torch.optim.SGD(net.parameters(), lr=0.001)
scaler = torch.amp.GradScaler("cuda" ,enabled=use_amp)

start_timer()
for epoch in range(epochs):
    for input, target in zip(data, targets):
        with torch.autocast(device_type=device, dtype=torch.float16, enabled=use_amp):
            output = net(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        opt.zero_grad() # set_to_none=True here can modestly improve performance
end_timer_and_print("Mixed precision:")

检查/修改梯度(例如,裁剪)

scaler.scale(loss).backward() 产生的梯度都是经过缩放的。如果你希望在 backward()scaler.step(optimizer) 之间修改或检查参数的 .grad 属性,应首先使用 scaler.unscale_(optimizer) 取消缩放。

for epoch in range(0): # 0 epochs, this section is for illustration only
    for input, target in zip(data, targets):
        with torch.autocast(device_type=device, dtype=torch.float16):
            output = net(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()

        # Unscales the gradients of optimizer's assigned parameters in-place
        scaler.unscale_(opt)

        # Since the gradients of optimizer's assigned parameters are now unscaled, clips as usual.
        # You may use the same value for max_norm here as you would without gradient scaling.
        torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=0.1)

        scaler.step(opt)
        scaler.update()
        opt.zero_grad() # set_to_none=True here can modestly improve performance

保存/恢复

要以比特级精度保存/恢复启用 Amp 的运行,请使用 scaler.state_dictscaler.load_state_dict

保存时,将 scaler 的 state dict 与通常的模型和优化器 state dicts 一起保存。可以在迭代开始、任何前向传播之前进行此操作,或者在迭代结束、调用 scaler.update() 之后进行。

checkpoint = {"model": net.state_dict(),
              "optimizer": opt.state_dict(),
              "scaler": scaler.state_dict()}
# Write checkpoint as desired, e.g.,
# torch.save(checkpoint, "filename")

恢复时,加载 scaler 的 state dict 以及模型和优化器 state dicts。按需读取检查点,例如

dev = torch.cuda.current_device()
checkpoint = torch.load("filename",
                        map_location = lambda storage, loc: storage.cuda(dev))
net.load_state_dict(checkpoint["model"])
opt.load_state_dict(checkpoint["optimizer"])
scaler.load_state_dict(checkpoint["scaler"])

如果检查点是来自启用 Amp 的运行,并且你想启用 Amp 恢复训练,则照常从检查点加载模型和优化器状态。检查点不会包含已保存的 scaler 状态,因此请使用一个新的 GradScaler 实例。

如果检查点是来自启用 Amp 的运行,并且你想禁用 Amp 恢复训练,则照常从检查点加载模型和优化器状态,并忽略已保存的 scaler 状态。

推理/评估

autocast 可单独用于包装推理或评估的前向传播。GradScaler 不是必需的。

高级主题

有关高级用例,包括以下内容,请参阅自动混合精度示例

  • 梯度累积

  • 梯度惩罚/二次反向传播

  • 具有多个模型、优化器或损失函数的网络

  • 多 GPU(torch.nn.DataParalleltorch.nn.parallel.DistributedDataParallel

  • 自定义自动微分函数(torch.autograd.Function 的子类)

如果在同一个脚本中执行多次收敛运行,则每次运行都应使用一个专用的全新 GradScaler 实例。GradScaler 实例是轻量级的。

如果使用 dispatcher 注册自定义 C++ 算子,请参阅 dispatcher 教程的autocast 部分

故障排除

使用 Amp 加速效果不明显

  1. 你的网络可能未能使 GPU 饱和工作,因此受限于 CPU。Amp 对 GPU 性能的影响将无关紧要。

    • 使 GPU 饱和的一个经验法则是,在不发生 OOM 的前提下,尽可能增加 batch 大小和/或网络大小。

    • 尽量避免过多的 CPU-GPU 同步(例如 .item() 调用或打印 CUDA 张量的值)。

    • 尽量避免连续执行许多小的 CUDA 算子(如果可能,将它们合并成几个大的 CUDA 算子)。

  2. 你的网络可能是 GPU 计算瓶颈(有大量 matmuls/卷积),但你的 GPU 不具备 Tensor Cores。在这种情况下,预期加速效果会降低。

  3. matmul 的维度对 Tensor Core 不友好。确保 matmuls 相关维度是 8 的倍数。(对于带 encoder/decoder 的 NLP 模型,这可能比较微妙。此外,卷积过去也有类似的 Tensor Core 使用尺寸限制,但对于 CuDNN 7.3 及更高版本,不再存在此类限制。请参阅此处获取指导。)

损失为 inf/NaN

首先,检查你的网络是否符合高级用例。另请参阅优先使用 binary_cross_entropy_with_logits 而非 binary_cross_entropy

如果你确信 Amp 的使用方法正确,可能需要提交一个问题,但在提交之前,收集以下信息会很有帮助

  1. 单独禁用 autocastGradScaler(通过在其构造函数中传入 enabled=False),查看 infs/NaNs 是否仍然存在。

  2. 如果你怀疑网络中的某个部分(例如,复杂的损失函数)溢出,请将该前向区域在 float32 中运行,查看 infs/NaNs 是否仍然存在。autocast docstring 的最后一个代码片段展示了如何强制某个子区域在 float32 中运行(通过局部禁用 autocast 并转换该子区域的输入)。

类型不匹配错误(可能表现为 CUDNN_STATUS_BAD_PARAM

Autocast 试图涵盖所有受益于或需要类型转换的算子。获得明确覆盖的算子是根据数值特性以及经验选择的。如果在启用了 autocast 的前向区域或其后的反向传播中看到类型不匹配错误,可能是 autocast 遗漏了某个算子。

请提交一个包含错误回溯的 issue。在运行脚本之前设置环境变量 export TORCH_SHOW_CPP_STACKTRACES=1,可以提供关于哪个后端算子失败的详细信息。

脚本总运行时间: ( 0 分钟 0.000 秒)

由 Sphinx-Gallery 生成的图库

文档

查阅全面的 PyTorch 开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源