自动混合精度示例¶
通常,“自动混合精度训练”是指同时使用 torch.autocast 和 torch.amp.GradScaler 进行训练。
torch.autocast 的实例允许对选定的区域进行自动类型转换。自动类型转换会自动选择操作的精度,以提高性能同时保持准确性。
torch.amp.GradScaler 的实例有助于方便地执行梯度缩放步骤。梯度缩放通过最小化梯度下溢来提高使用 float16(在 CUDA 和 XPU 上默认为此类型)梯度的网络的收敛性,详情请参阅此处。
torch.autocast 和 torch.amp.GradScaler 是模块化的。在下面的示例中,它们的使用方式都遵循各自文档的建议。
(此处的示例仅作说明。请参阅自动混合精度秘籍以获取可运行的演练。)
典型的混合精度训练¶
# Creates model and optimizer in default precision
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
# Creates a GradScaler once at the beginning of training.
scaler = GradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
# Runs the forward pass with autocasting.
with autocast(device_type='cuda', dtype=torch.float16):
output = model(input)
loss = loss_fn(output, target)
# Scales loss. Calls backward() on scaled loss to create scaled gradients.
# Backward passes under autocast are not recommended.
# Backward ops run in the same dtype autocast chose for corresponding forward ops.
scaler.scale(loss).backward()
# scaler.step() first unscales the gradients of the optimizer's assigned params.
# If these gradients do not contain infs or NaNs, optimizer.step() is then called,
# otherwise, optimizer.step() is skipped.
scaler.step(optimizer)
# Updates the scale for next iteration.
scaler.update()
处理未缩放的梯度¶
scaler.scale(loss).backward() 产生的全部梯度都是缩放后的。如果您希望在 backward() 和 scaler.step(optimizer) 之间修改或检查参数的 .grad 属性,您应该先取消缩放它们。例如,梯度裁剪会操作一组梯度,使得它们的全局范数(参见torch.nn.utils.clip_grad_norm_())或最大幅度(参见torch.nn.utils.clip_grad_value_()) 某个用户设定的阈值。如果您尝试在不取消缩放的情况下进行裁剪,则梯度的范数/最大幅度也会被缩放,因此您请求的阈值(原意是用于未缩放梯度的阈值)将无效。
scaler.unscale_(optimizer) 会取消缩放由 optimizer 分配的参数所持有的梯度。如果您的模型或模型包含分配给另一个优化器(例如 optimizer2)的其他参数,您可以单独调用 scaler.unscale_(optimizer2) 来取消缩放这些参数的梯度。
梯度裁剪¶
在裁剪之前调用 scaler.unscale_(optimizer) 使您能够像往常一样裁剪未缩放的梯度
scaler = GradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
with autocast(device_type='cuda', dtype=torch.float16):
output = model(input)
loss = loss_fn(output, target)
scaler.scale(loss).backward()
# Unscales the gradients of optimizer's assigned params in-place
scaler.unscale_(optimizer)
# Since the gradients of optimizer's assigned params are unscaled, clips as usual:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
# optimizer's gradients are already unscaled, so scaler.step does not unscale them,
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
scaler.step(optimizer)
# Updates the scale for next iteration.
scaler.update()
scaler 会记录本轮迭代中已经为此优化器调用了 scaler.unscale_(optimizer),因此 scaler.step(optimizer) 知道在(内部)调用 optimizer.step() 之前不需要重复取消缩放梯度。
警告
unscale_ 应在每次 step 调用中针对每个优化器仅调用一次,并且仅在该优化器分配的参数的所有梯度累积完成后调用。step 调用之间为给定优化器调用 unscale_ 两次将触发 RuntimeError。
处理缩放后的梯度¶
梯度累积¶
梯度累积将梯度累加到有效批次大小为 batch_per_iter * iters_to_accumulate(如果是分布式则 * num_procs)的批次上。缩放应针对有效批次进行校准,这意味着 inf/NaN 检查、如果找到 inf/NaN 梯度则跳过步骤,以及缩放更新应以有效批次粒度进行。此外,在给定有效批次的梯度累积期间,梯度应保持缩放状态,并且缩放因子应保持不变。如果在累积完成之前梯度被取消缩放(或缩放因子改变),则下一次反向传播会将缩放后的梯度加到未缩放的梯度(或按不同因子缩放的梯度)上,此后无法恢复 step 必须应用的累积的未缩放梯度。
因此,如果您想取消缩放梯度(例如,允许裁剪未缩放的梯度),请在 step 之前调用 unscale_,在即将到来的 step 的所有(缩放后)梯度累积完成后进行。此外,仅在您为完整的有效批次调用了 step 的迭代结束时调用 update
scaler = GradScaler()
for epoch in epochs:
for i, (input, target) in enumerate(data):
with autocast(device_type='cuda', dtype=torch.float16):
output = model(input)
loss = loss_fn(output, target)
loss = loss / iters_to_accumulate
# Accumulates scaled gradients.
scaler.scale(loss).backward()
if (i + 1) % iters_to_accumulate == 0:
# may unscale_ here if desired (e.g., to allow clipping unscaled gradients)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
梯度惩罚¶
梯度惩罚实现通常使用 torch.autograd.grad() 创建梯度,将它们组合以创建惩罚值,并将惩罚值添加到损失中。
这里是一个没有梯度缩放或自动类型转换的普通 L2 惩罚示例
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
# Creates gradients
grad_params = torch.autograd.grad(outputs=loss,
inputs=model.parameters(),
create_graph=True)
# Computes the penalty term and adds it to the loss
grad_norm = 0
for grad in grad_params:
grad_norm += grad.pow(2).sum()
grad_norm = grad_norm.sqrt()
loss = loss + grad_norm
loss.backward()
# clip gradients here, if desired
optimizer.step()
要实现带有梯度缩放的梯度惩罚,传递给 torch.autograd.grad() 的 outputs Tensor 应被缩放。因此,产生的梯度将被缩放,并且在组合以创建惩罚值之前应取消缩放。
此外,惩罚项计算是前向传播的一部分,因此应包含在 autocast 上下文中。
对于相同的 L2 惩罚,它看起来是这样的
scaler = GradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
with autocast(device_type='cuda', dtype=torch.float16):
output = model(input)
loss = loss_fn(output, target)
# Scales the loss for autograd.grad's backward pass, producing scaled_grad_params
scaled_grad_params = torch.autograd.grad(outputs=scaler.scale(loss),
inputs=model.parameters(),
create_graph=True)
# Creates unscaled grad_params before computing the penalty. scaled_grad_params are
# not owned by any optimizer, so ordinary division is used instead of scaler.unscale_:
inv_scale = 1./scaler.get_scale()
grad_params = [p * inv_scale for p in scaled_grad_params]
# Computes the penalty term and adds it to the loss
with autocast(device_type='cuda', dtype=torch.float16):
grad_norm = 0
for grad in grad_params:
grad_norm += grad.pow(2).sum()
grad_norm = grad_norm.sqrt()
loss = loss + grad_norm
# Applies scaling to the backward call as usual.
# Accumulates leaf gradients that are correctly scaled.
scaler.scale(loss).backward()
# may unscale_ here if desired (e.g., to allow clipping unscaled gradients)
# step() and update() proceed as usual.
scaler.step(optimizer)
scaler.update()
使用多个模型、损失和优化器¶
如果您的网络有多个损失,您必须单独对每个损失调用 scaler.scale。如果您的网络有多个优化器,您可以单独对其中任何一个调用 scaler.unscale_,并且必须单独对每个优化器调用 scaler.step。
然而,scaler.update 应该只调用一次,在本次迭代中使用的所有优化器都已执行 step 后调用
scaler = torch.amp.GradScaler()
for epoch in epochs:
for input, target in data:
optimizer0.zero_grad()
optimizer1.zero_grad()
with autocast(device_type='cuda', dtype=torch.float16):
output0 = model0(input)
output1 = model1(input)
loss0 = loss_fn(2 * output0 + 3 * output1, target)
loss1 = loss_fn(3 * output0 - 5 * output1, target)
# (retain_graph here is unrelated to amp, it's present because in this
# example, both backward() calls share some sections of graph.)
scaler.scale(loss0).backward(retain_graph=True)
scaler.scale(loss1).backward()
# You can choose which optimizers receive explicit unscaling, if you
# want to inspect or modify the gradients of the params they own.
scaler.unscale_(optimizer0)
scaler.step(optimizer0)
scaler.step(optimizer1)
scaler.update()
每个优化器都会检查其梯度是否存在 inf/NaN,并独立决定是否跳过该步骤。这可能导致一个优化器跳过步骤而另一个不跳过。由于跳过步骤很少发生(每几百次迭代一次),这不应阻碍收敛。如果您在为多优化器模型添加梯度缩放后观察到收敛不良,请报告错误。
使用多个 GPU¶
此处描述的问题仅影响 autocast。GradScaler 的用法保持不变。
单个进程中的 DataParallel¶
即使 torch.nn.DataParallel 会生成线程在每个设备上运行前向传播。autocast 状态会在每个线程中传播,以下代码将起作用
model = MyModel()
dp_model = nn.DataParallel(model)
# Sets autocast in the main thread
with autocast(device_type='cuda', dtype=torch.float16):
# dp_model's internal threads will autocast.
output = dp_model(input)
# loss_fn also autocast
loss = loss_fn(output)
DistributedDataParallel,每个进程一个 GPU¶
torch.nn.parallel.DistributedDataParallel 的文档建议每个进程使用一个 GPU 以获得最佳性能。在这种情况下,DistributedDataParallel 不会在内部生成线程,因此 autocast 和 GradScaler 的用法不受影响。
DistributedDataParallel,每个进程多个 GPU¶
在这种情况下,torch.nn.parallel.DistributedDataParallel 可能会生成一个子线程来在每个设备上运行前向传播,类似于 torch.nn.DataParallel。解决方法是相同的:将 autocast 应用为模型的 forward 方法的一部分,以确保它在子线程中启用。
Autocast 和自定义 Autograd 函数¶
如果您的网络使用了自定义 autograd 函数(torch.autograd.Function 的子类),则如果任何函数满足以下条件,就需要进行更改以实现 autocast 兼容性
接受多个浮点 Tensor 输入,
包装任何可自动类型转换的操作(参见Autocast 操作参考),或
需要特定的
dtype(例如,如果它包装了仅针对该dtype编译的CUDA 扩展)。
在所有情况下,如果您正在导入函数且无法更改其定义,一个安全的备用方案是在发生错误的任何使用点禁用 autocast 并强制以 float32(或 dtype)执行
with autocast(device_type='cuda', dtype=torch.float16):
...
with autocast(device_type='cuda', dtype=torch.float16, enabled=False):
output = imported_function(input1.float(), input2.float())
如果您是该函数的作者(或可以更改其定义),一个更好的解决方案是使用 torch.amp.custom_fwd() 和 torch.amp.custom_bwd() 装饰器,如下面的相关示例所示。
具有多个输入或可自动类型转换操作的函数¶
分别将 custom_fwd 和 custom_bwd(不带参数)应用于 forward 和 backward。这确保 forward 在当前的 autocast 状态下执行,而 backward 在与 forward 相同的 autocast 状态下执行(这可以防止类型不匹配错误)
class MyMM(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, a, b):
ctx.save_for_backward(a, b)
return a.mm(b)
@staticmethod
@custom_bwd
def backward(ctx, grad):
a, b = ctx.saved_tensors
return grad.mm(b.t()), a.t().mm(grad)
现在可以在任何地方调用 MyMM,无需禁用 autocast 或手动转换输入
mymm = MyMM.apply
with autocast(device_type='cuda', dtype=torch.float16):
output = mymm(input1, input2)
需要特定 dtype 的函数¶
考虑一个需要 torch.float32 输入的自定义函数。将 custom_fwd(device_type='cuda', cast_inputs=torch.float32) 应用于 forward,将 custom_bwd(device_type='cuda') 应用于 backward。如果 forward 在启用了 autocast 的区域运行,装饰器会将浮点 Tensor 输入转换为由参数 device_type(本例中为 CUDA)指定的设备上的 float32,并在 forward 和 backward 期间本地禁用 autocast
class MyFloat32Func(torch.autograd.Function):
@staticmethod
@custom_fwd(device_type='cuda', cast_inputs=torch.float32)
def forward(ctx, input):
ctx.save_for_backward(input)
...
return fwd_output
@staticmethod
@custom_bwd(device_type='cuda')
def backward(ctx, grad):
...
现在可以在任何地方调用 MyFloat32Func,无需手动禁用 autocast 或转换输入
func = MyFloat32Func.apply
with autocast(device_type='cuda', dtype=torch.float16):
# func will run in float32, regardless of the surrounding autocast state
output = func(input)