自动混合精度示例¶
通常,“自动混合精度训练”意味着同时使用 torch.autocast
和 torch.amp.GradScaler
进行训练。
torch.autocast
的实例为选定区域启用自动类型转换。自动类型转换自动选择运算精度,以提高性能并保持准确性。
torch.amp.GradScaler
的实例有助于方便地执行梯度缩放步骤。梯度缩放通过最大限度地减少梯度下溢来提高具有 float16
(CUDA 和 XPU 上默认为此类型)梯度的网络的收敛性,具体说明请参阅此处。
torch.autocast
和 torch.amp.GradScaler
是模块化的。在下面的示例中,每个都按照其各自的文档建议使用。
(此处的示例仅为说明性目的。有关可运行的演练,请参阅自动混合精度配方。)
典型的混合精度训练¶
# Creates model and optimizer in default precision
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
# Creates a GradScaler once at the beginning of training.
scaler = GradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
# Runs the forward pass with autocasting.
with autocast(device_type='cuda', dtype=torch.float16):
output = model(input)
loss = loss_fn(output, target)
# Scales loss. Calls backward() on scaled loss to create scaled gradients.
# Backward passes under autocast are not recommended.
# Backward ops run in the same dtype autocast chose for corresponding forward ops.
scaler.scale(loss).backward()
# scaler.step() first unscales the gradients of the optimizer's assigned params.
# If these gradients do not contain infs or NaNs, optimizer.step() is then called,
# otherwise, optimizer.step() is skipped.
scaler.step(optimizer)
# Updates the scale for next iteration.
scaler.update()
处理未缩放的梯度¶
scaler.scale(loss).backward()
产生的所有梯度都是缩放的。如果您希望在 backward()
和 scaler.step(optimizer)
之间修改或检查参数的 .grad
属性,则应首先取消缩放它们。例如,梯度裁剪会操作一组梯度,使其全局范数(参见 torch.nn.utils.clip_grad_norm_()
)或最大幅度(参见 torch.nn.utils.clip_grad_value_()
) 某个用户强加的阈值。如果您尝试在不取消缩放的情况下裁剪,则梯度的范数/最大幅度也会被缩放,因此您请求的阈值(旨在作为未缩放梯度的阈值)将无效。
scaler.unscale_(optimizer)
取消缩放 optimizer
的已分配参数持有的梯度。如果您的模型包含分配给另一个优化器(例如 optimizer2
)的其他参数,您可以单独调用 scaler.unscale_(optimizer2)
以取消缩放这些参数的梯度。
梯度裁剪¶
在裁剪之前调用 scaler.unscale_(optimizer)
使您能够像往常一样裁剪未缩放的梯度
scaler = GradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
with autocast(device_type='cuda', dtype=torch.float16):
output = model(input)
loss = loss_fn(output, target)
scaler.scale(loss).backward()
# Unscales the gradients of optimizer's assigned params in-place
scaler.unscale_(optimizer)
# Since the gradients of optimizer's assigned params are unscaled, clips as usual:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
# optimizer's gradients are already unscaled, so scaler.step does not unscale them,
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
scaler.step(optimizer)
# Updates the scale for next iteration.
scaler.update()
scaler
记录 scaler.unscale_(optimizer)
已在此迭代中为此优化器调用,因此 scaler.step(optimizer)
知道在(内部)调用 optimizer.step()
之前不要冗余地取消缩放梯度。
警告
unscale_
对于每个优化器,每次 step
调用只能调用一次,并且只能在累积了该优化器分配的参数的所有梯度之后调用。在每次 step
之间为给定的优化器调用 unscale_
两次会触发 RuntimeError。
处理缩放的梯度¶
梯度累积¶
梯度累积在有效批大小 batch_per_iter * iters_to_accumulate
(如果分布式,则为 * num_procs
)上添加梯度。应该为有效批校准缩放,这意味着 inf/NaN 检查、如果发现 inf/NaN 梯度则跳过步骤以及缩放更新应以有效批粒度发生。此外,在累积给定有效批的梯度时,梯度应保持缩放,并且缩放因子应保持不变。如果在累积完成之前梯度被取消缩放(或缩放因子发生变化),则下一个反向传播将缩放梯度添加到未缩放梯度(或由不同因子缩放的梯度),之后无法恢复必须应用的累积未缩放梯度 step
。
因此,如果您想 unscale_
梯度(例如,允许裁剪未缩放的梯度),请在 step
之前调用 unscale_
,在即将到来的 step
的所有(缩放的)梯度都已累积之后。此外,仅在您为完整有效批调用 step
的迭代结束时调用 update
scaler = GradScaler()
for epoch in epochs:
for i, (input, target) in enumerate(data):
with autocast(device_type='cuda', dtype=torch.float16):
output = model(input)
loss = loss_fn(output, target)
loss = loss / iters_to_accumulate
# Accumulates scaled gradients.
scaler.scale(loss).backward()
if (i + 1) % iters_to_accumulate == 0:
# may unscale_ here if desired (e.g., to allow clipping unscaled gradients)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
梯度惩罚¶
梯度惩罚实现通常使用 torch.autograd.grad()
创建梯度,将它们组合以创建惩罚值,并将惩罚值添加到损失中。
这是一个没有梯度缩放或自动类型转换的 L2 惩罚的普通示例
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
# Creates gradients
grad_params = torch.autograd.grad(outputs=loss,
inputs=model.parameters(),
create_graph=True)
# Computes the penalty term and adds it to the loss
grad_norm = 0
for grad in grad_params:
grad_norm += grad.pow(2).sum()
grad_norm = grad_norm.sqrt()
loss = loss + grad_norm
loss.backward()
# clip gradients here, if desired
optimizer.step()
要使用梯度缩放实现梯度惩罚,传递给 torch.autograd.grad()
的 outputs
张量应进行缩放。因此,生成的梯度将被缩放,并且应在组合以创建惩罚值之前取消缩放。
此外,惩罚项计算是前向传播的一部分,因此应在 autocast
上下文中。
以下是相同 L2 惩罚的外观
scaler = GradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
with autocast(device_type='cuda', dtype=torch.float16):
output = model(input)
loss = loss_fn(output, target)
# Scales the loss for autograd.grad's backward pass, producing scaled_grad_params
scaled_grad_params = torch.autograd.grad(outputs=scaler.scale(loss),
inputs=model.parameters(),
create_graph=True)
# Creates unscaled grad_params before computing the penalty. scaled_grad_params are
# not owned by any optimizer, so ordinary division is used instead of scaler.unscale_:
inv_scale = 1./scaler.get_scale()
grad_params = [p * inv_scale for p in scaled_grad_params]
# Computes the penalty term and adds it to the loss
with autocast(device_type='cuda', dtype=torch.float16):
grad_norm = 0
for grad in grad_params:
grad_norm += grad.pow(2).sum()
grad_norm = grad_norm.sqrt()
loss = loss + grad_norm
# Applies scaling to the backward call as usual.
# Accumulates leaf gradients that are correctly scaled.
scaler.scale(loss).backward()
# may unscale_ here if desired (e.g., to allow clipping unscaled gradients)
# step() and update() proceed as usual.
scaler.step(optimizer)
scaler.update()
处理多个模型、损失和优化器¶
如果您的网络有多个损失,则必须单独对每个损失调用 scaler.scale
。如果您的网络有多个优化器,您可以单独对任何一个优化器调用 scaler.unscale_
,并且必须单独对每个优化器调用 scaler.step
。
但是,scaler.update
应该只调用一次,在本次迭代中使用的所有优化器都已执行 step 之后
scaler = torch.amp.GradScaler()
for epoch in epochs:
for input, target in data:
optimizer0.zero_grad()
optimizer1.zero_grad()
with autocast(device_type='cuda', dtype=torch.float16):
output0 = model0(input)
output1 = model1(input)
loss0 = loss_fn(2 * output0 + 3 * output1, target)
loss1 = loss_fn(3 * output0 - 5 * output1, target)
# (retain_graph here is unrelated to amp, it's present because in this
# example, both backward() calls share some sections of graph.)
scaler.scale(loss0).backward(retain_graph=True)
scaler.scale(loss1).backward()
# You can choose which optimizers receive explicit unscaling, if you
# want to inspect or modify the gradients of the params they own.
scaler.unscale_(optimizer0)
scaler.step(optimizer0)
scaler.step(optimizer1)
scaler.update()
每个优化器都会检查其梯度中是否存在 inf/NaN,并独立决定是否跳过 step。这可能会导致一个优化器跳过 step,而另一个优化器不跳过。由于 step 跳过很少发生(每几百次迭代一次),因此这不应妨碍收敛。如果您在向多优化器模型添加梯度缩放后观察到收敛性差,请报告错误。
处理多个 GPU¶
此处描述的问题仅影响 autocast
。GradScaler
的用法不变。
单进程中的 DataParallel¶
即使 torch.nn.DataParallel
产生线程以在每个设备上运行前向传播。自动类型转换状态在每个线程中传播,以下代码将起作用
model = MyModel()
dp_model = nn.DataParallel(model)
# Sets autocast in the main thread
with autocast(device_type='cuda', dtype=torch.float16):
# dp_model's internal threads will autocast.
output = dp_model(input)
# loss_fn also autocast
loss = loss_fn(output)
DistributedDataParallel,每个进程一个 GPU¶
torch.nn.parallel.DistributedDataParallel
的文档建议每个进程使用一个 GPU 以获得最佳性能。在这种情况下,DistributedDataParallel
不会在内部产生线程,因此 autocast
和 GradScaler
的用法不受影响。
DistributedDataParallel,每个进程多个 GPU¶
在这里,torch.nn.parallel.DistributedDataParallel
可能会产生一个辅助线程,以像 torch.nn.DataParallel
一样在每个设备上运行前向传播。修复方法相同:将 autocast 应用为模型 forward
方法的一部分,以确保在辅助线程中启用它。
Autocast 和自定义 Autograd 函数¶
如果您的网络使用自定义 autograd 函数(torch.autograd.Function
的子类),如果任何函数满足以下条件,则需要进行更改以实现 autocast 兼容性
接受多个浮点张量输入,
包装任何可自动类型转换的运算(参见自动类型转换运算参考),或
需要特定的
dtype
(例如,如果它包装了仅针对dtype
编译的 CUDA 扩展)。
在所有情况下,如果您要导入该函数并且无法更改其定义,则安全的后备方案是在发生错误的使用点禁用 autocast 并强制以 float32
(或 dtype
)执行
with autocast(device_type='cuda', dtype=torch.float16):
...
with autocast(device_type='cuda', dtype=torch.float16, enabled=False):
output = imported_function(input1.float(), input2.float())
如果您是函数的作者(或可以更改其定义),则更好的解决方案是使用 torch.amp.custom_fwd()
和 torch.amp.custom_bwd()
装饰器,如下面相关案例所示。
具有多个输入或可自动类型转换运算的函数¶
分别将 custom_fwd
和 custom_bwd
(不带参数)应用于 forward
和 backward
。这些确保 forward
以当前的 autocast 状态执行,并且 backward
以与 forward
相同的 autocast 状态执行(这可以防止类型不匹配错误)
class MyMM(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, a, b):
ctx.save_for_backward(a, b)
return a.mm(b)
@staticmethod
@custom_bwd
def backward(ctx, grad):
a, b = ctx.saved_tensors
return grad.mm(b.t()), a.t().mm(grad)
现在可以在任何地方调用 MyMM
,而无需禁用 autocast 或手动转换输入
mymm = MyMM.apply
with autocast(device_type='cuda', dtype=torch.float16):
output = mymm(input1, input2)
需要特定 dtype
的函数¶
考虑一个需要 torch.float32
输入的自定义函数。将 custom_fwd(device_type='cuda', cast_inputs=torch.float32)
应用于 forward
,将 custom_bwd(device_type='cuda')
应用于 backward
。如果 forward
在启用 autocast 的区域中运行,则装饰器会将浮点张量输入转换为参数 device_type(在本例中为 CUDA)指定的设备上的 float32
,并在 forward
和 backward
期间本地禁用 autocast
class MyFloat32Func(torch.autograd.Function):
@staticmethod
@custom_fwd(device_type='cuda', cast_inputs=torch.float32)
def forward(ctx, input):
ctx.save_for_backward(input)
...
return fwd_output
@staticmethod
@custom_bwd(device_type='cuda')
def backward(ctx, grad):
...
现在可以在任何地方调用 MyFloat32Func
,而无需手动禁用 autocast 或转换输入
func = MyFloat32Func.apply
with autocast(device_type='cuda', dtype=torch.float16):
# func will run in float32, regardless of the surrounding autocast state
output = func(input)