快捷方式

自动混合精度

PyTorch/XLA 的 AMP 扩展了 PyTorch 的 AMP 包,支持在 XLA:GPUXLA:TPU 设备上进行自动混合精度计算。AMP 通过在 float32 中执行某些操作,在较低精度数据类型(根据硬件支持,可以是 float16bfloat16)中执行其他操作来加速训练和推理。本文档描述了如何在 XLA 设备上使用 AMP 以及最佳实践。

用于 XLA:TPU 的 AMP

由于 TPU 本身支持 bfloat16,TPU 上的 AMP 会自动将操作转换为在 float32bfloat16 中运行。一个简单的 TPU AMP 示例如下

# Creates model and optimizer in default precision
model = Net().to(xm.xla_device())
# Pytorch/XLA provides sync-free optimizers for improved performance
optimizer = syncfree.SGD(model.parameters(), ...)

for input, target in data:
    optimizer.zero_grad()

    # Enables autocasting for the forward pass
    with autocast(xm.xla_device()):
        output = model(input)
        loss = loss_fn(output, target)

    # Exits the context manager before backward()
    loss.backward()
    xm.optimizer_step.(optimizer)

当 XLA 设备是 TPU 时,autocast(xm.xla_device()) 别名为 torch.autocast('xla')。或者,如果脚本仅用于 TPU,则可以直接使用 torch.autocast('xla', dtype=torch.bfloat16)

如果存在某个操作应该被自动类型转换但未包含在内,请提交问题或拉取请求。

用于 XLA:TPU 的 AMP 最佳实践

  1. autocast 应仅包裹网络的正向传播和损失计算。反向传播操作与其对应的正向传播操作使用相同的类型运行。

  2. 由于 TPU 使用 bfloat16 混合精度,因此无需进行梯度缩放。

  3. PyTorch/XLA 提供了 优化器 的修改版本,避免了设备与主机之间的额外同步。

支持的操作符

TPU 上的 AMP 工作方式类似于 PyTorch 的 AMP。自动类型转换的应用规则总结如下

只有非就地 (out-of-place) 操作和 Tensor 方法符合自动类型转换的条件。就地 (in-place) 变体和明确提供 out=... Tensor 的调用在启用自动类型转换的区域中是被允许的,但不会经过自动类型转换。例如,在启用自动类型转换的区域中,a.addmm(b, c) 可以自动类型转换,但 a.addmm_(b, c) 和 a.addmm(b, c, out=d) 不会。为了获得最佳性能和稳定性,在启用自动类型转换的区域中优先使用非就地操作。

以 float64 或非浮点 dtype 运行的操作符不符合条件,无论是否启用自动类型转换,都将以这些类型运行。此外,使用显式 dtype=... 参数调用的操作符不符合条件,并将生成遵循 dtype 参数的输出。

未在下方列出的操作符不会经过自动类型转换。它们以其输入定义的类型运行。如果未列出的操作符是自动类型转换操作符的下游,则自动类型转换仍可能改变其运行类型。

自动类型转换为 ``bfloat16`` 的操作符

__matmul__, addbmm, addmm, addmv, addr, baddbmm,bmm, conv1d, conv2d, conv3d, conv_transpose1d, conv_transpose2d, conv_transpose3d, linear, matmul, mm, relu, prelu, max_pool2d

自动类型转换为 ``float32`` 的操作符

batch_norm, log_softmax, binary_cross_entropy, binary_cross_entropy_with_logits, prod, cdist, trace, chloesky ,inverse, reflection_pad, replication_pad, mse_loss, cosine_embbeding_loss, nll_loss, multilabel_margin_loss, qr, svd, triangular_solve, linalg_svd, linalg_inv_ex

自动类型转换为最宽输入类型的操作符

stack, cat, index_copy

用于 XLA:GPU 的 AMP

XLA:GPU 设备上的 AMP 复用 PyTorch 的 AMP 规则。关于 CUDA 特定行为,请参阅 PyTorch 的 AMP 文档。一个简单的 CUDA AMP 示例如下

# Creates model and optimizer in default precision
model = Net().to(xm.xla_device())
# Pytorch/XLA provides sync-free optimizers for improved performance
optimizer = syncfree.SGD(model.parameters(), ...)
scaler = GradScaler()

for input, target in data:
    optimizer.zero_grad()

    # Enables autocasting for the forward pass
    with autocast(xm.xla_device()):
        output = model(input)
        loss = loss_fn(output, target)

    # Exits the context manager before backward pass
    scaler.scale(loss).backward()
    gradients = xm._fetch_gradients(optimizer)
    xm.all_reduce('sum', gradients, scale=1.0 / xr.world_size())
    scaler.step(optimizer)
    scaler.update()

当 XLA 设备是 CUDA 设备 (XLA:GPU) 时,autocast(xm.xla_device()) 别名为 torch.cuda.amp.autocast()。或者,如果脚本仅用于 CUDA 设备,则可以直接使用 torch.cuda.amp.autocast,但这要求 torch 是用 cuda 支持 torch.bfloat16 数据类型编译的。我们建议在 XLA:GPU 上使用 autocast(xm.xla_device()),因为它不需要 torch.cuda 支持任何数据类型,包括 torch.bfloat16

用于 XLA:GPU 的 AMP 最佳实践

  1. autocast 应仅包裹网络的正向传播和损失计算。反向传播操作与其对应的正向传播操作使用相同的类型运行。

  2. 在 Cuda 设备上使用 AMP 时,不要设置 XLA_USE_F16 标志。这会覆盖 AMP 提供的每个操作符的精度设置,并导致所有操作符以 float16 执行。

  3. 使用梯度缩放以防止 float16 梯度下溢。

  4. PyTorch/XLA 提供了 优化器 的修改版本,避免了设备与主机之间的额外同步。

示例

我们的 mnist 训练脚本imagenet 训练脚本 展示了如何在 TPU 和 GPU 上使用 AMP。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源