自动混合精度¶
Pytorch/XLA 的 AMP 扩展了 Pytorch 的 AMP 包,增加了对 XLA:GPU
和 XLA:TPU
设备上自动混合精度的支持。AMP 用于通过以 float32
执行某些操作,并以较低精度数据类型(float16
或 bfloat16
,取决于硬件支持)执行其他操作来加速训练和推理。本文档介绍了如何在 XLA 设备上使用 AMP 以及最佳实践。
XLA:TPU 的 AMP¶
TPU 上的 AMP 自动将操作转换为以 float32
或 bfloat16
运行,因为 TPU 原生支持 bfloat16。下面是一个简单的 TPU AMP 示例
# Creates model and optimizer in default precision
model = Net().to(xm.xla_device())
# Pytorch/XLA provides sync-free optimizers for improved performance
optimizer = syncfree.SGD(model.parameters(), ...)
for input, target in data:
optimizer.zero_grad()
# Enables autocasting for the forward pass
with autocast(xm.xla_device()):
output = model(input)
loss = loss_fn(output, target)
# Exits the context manager before backward()
loss.backward()
xm.optimizer_step.(optimizer)
autocast(xm.xla_device())
在 XLA 设备是 TPU 时,别名为 torch.autocast('xla')
。或者,如果脚本仅与 TPU 一起使用,则可以直接使用 torch.autocast('xla', dtype=torch.bfloat16)
。
如果存在应自动转换但未包含的运算符,请提交问题或拉取请求。
XLA:TPU 的 AMP 最佳实践¶
autocast
应仅包装网络的前向传递和损失计算。反向操作以 autocast 用于相应前向操作的相同类型运行。由于 TPU 使用 bfloat16 混合精度,因此无需梯度缩放。
Pytorch/XLA 提供了 optimizers 的修改版本,避免了设备和主机之间的额外同步。
支持的运算符¶
TPU 上的 AMP 像 Pytorch 的 AMP 一样运行。下面总结了如何应用自动转换的规则
只有 out-of-place 操作和 Tensor 方法有资格进行自动转换。In-place 变体和显式提供 out=… Tensor 的调用在启用自动转换的区域中是允许的,但不会进行自动转换。例如,在启用自动转换的区域中,a.addmm(b, c) 可以自动转换,但 a.addmm_(b, c) 和 a.addmm(b, c, out=d) 不能。为了获得最佳性能和稳定性,请在启用自动转换的区域中首选 out-of-place 操作。
以 float64 或非浮点数据类型运行的操作不符合条件,无论是否启用自动转换,都将以这些类型运行。此外,使用显式 dtype=… 参数调用的操作也不符合条件,并将生成符合 dtype 参数的输出。
下面未列出的操作不会进行自动转换。它们以输入定义的类型运行。如果未列出的操作位于自动转换操作的下游,则自动转换仍可能更改这些操作的运行类型。
自动转换为 ``bfloat16`` 的操作
__matmul__
, addbmm
, addmm
, addmv
, addr
, baddbmm
,bmm
, conv1d
, conv2d
, conv3d
, conv_transpose1d
, conv_transpose2d
, conv_transpose3d
, linear
, matmul
, mm
, relu
, prelu
, max_pool2d
自动转换为 ``float32`` 的操作
batch_norm
, log_softmax
, binary_cross_entropy
, binary_cross_entropy_with_logits
, prod
, cdist
, trace
, chloesky
,inverse
, reflection_pad
, replication_pad
, mse_loss
, cosine_embbeding_loss
, nll_loss
, multilabel_margin_loss
, qr
, svd
, triangular_solve
, linalg_svd
, linalg_inv_ex
自动转换为最宽输入类型的操作
stack
, cat
, index_copy
XLA:GPU 的 AMP¶
XLA:GPU 设备上的 AMP 重用 Pytorch 的 AMP 规则。有关 CUDA 特定行为,请参阅 Pytorch 的 AMP 文档。下面是一个简单的 CUDA AMP 示例
# Creates model and optimizer in default precision
model = Net().to(xm.xla_device())
# Pytorch/XLA provides sync-free optimizers for improved performance
optimizer = syncfree.SGD(model.parameters(), ...)
scaler = GradScaler()
for input, target in data:
optimizer.zero_grad()
# Enables autocasting for the forward pass
with autocast(xm.xla_device()):
output = model(input)
loss = loss_fn(output, target)
# Exits the context manager before backward pass
scaler.scale(loss).backward()
gradients = xm._fetch_gradients(optimizer)
xm.all_reduce('sum', gradients, scale=1.0 / xr.world_size())
scaler.step(optimizer)
scaler.update()
autocast(xm.xla_device())
在 XLA 设备是 CUDA 设备 (XLA:GPU) 时,别名为 torch.cuda.amp.autocast()
。或者,如果脚本仅与 CUDA 设备一起使用,则可以直接使用 torch.cuda.amp.autocast
,但需要 torch
在编译时支持 cuda
数据类型 torch.bfloat16
。我们建议在 XLA:GPU 上使用 autocast(xm.xla_device())
,因为它不需要 torch.cuda
支持任何数据类型,包括 torch.bfloat16
。
XLA:GPU 的 AMP 最佳实践¶
autocast
应仅包装网络的前向传递和损失计算。反向操作以 autocast 用于相应前向操作的相同类型运行。在 Cuda 设备上使用 AMP 时,请勿设置
XLA_USE_F16
标志。这将覆盖 AMP 提供的每个运算符的精度设置,并导致所有运算符以 float16 执行。使用梯度缩放来防止 float16 梯度下溢。
Pytorch/XLA 提供了 optimizers 的修改版本,避免了设备和主机之间的额外同步。
示例¶
我们的 mnist 训练脚本 和 imagenet 训练脚本 演示了如何在 TPU 和 GPU 上使用 AMP。