自动混合精度示例¶
通常,“自动混合精度训练”是指同时使用 torch.autocast
和 torch.amp.GradScaler
进行训练。
torch.autocast
的实例允许对选定的区域进行自动类型转换。自动类型转换会自动选择操作的精度,以提高性能同时保持准确性。
torch.amp.GradScaler
的实例有助于方便地执行梯度缩放步骤。梯度缩放通过最小化梯度下溢来提高使用 float16
(在 CUDA 和 XPU 上默认为此类型)梯度的网络的收敛性,详情请参阅此处。
torch.autocast
和 torch.amp.GradScaler
是模块化的。在下面的示例中,它们的使用方式都遵循各自文档的建议。
(此处的示例仅作说明。请参阅自动混合精度秘籍以获取可运行的演练。)
典型的混合精度训练¶
# Creates model and optimizer in default precision
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
# Creates a GradScaler once at the beginning of training.
scaler = GradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
# Runs the forward pass with autocasting.
with autocast(device_type='cuda', dtype=torch.float16):
output = model(input)
loss = loss_fn(output, target)
# Scales loss. Calls backward() on scaled loss to create scaled gradients.
# Backward passes under autocast are not recommended.
# Backward ops run in the same dtype autocast chose for corresponding forward ops.
scaler.scale(loss).backward()
# scaler.step() first unscales the gradients of the optimizer's assigned params.
# If these gradients do not contain infs or NaNs, optimizer.step() is then called,
# otherwise, optimizer.step() is skipped.
scaler.step(optimizer)
# Updates the scale for next iteration.
scaler.update()
处理未缩放的梯度¶
scaler.scale(loss).backward()
产生的全部梯度都是缩放后的。如果您希望在 backward()
和 scaler.step(optimizer)
之间修改或检查参数的 .grad
属性,您应该先取消缩放它们。例如,梯度裁剪会操作一组梯度,使得它们的全局范数(参见torch.nn.utils.clip_grad_norm_()
)或最大幅度(参见torch.nn.utils.clip_grad_value_()
) 某个用户设定的阈值。如果您尝试在不取消缩放的情况下进行裁剪,则梯度的范数/最大幅度也会被缩放,因此您请求的阈值(原意是用于未缩放梯度的阈值)将无效。
scaler.unscale_(optimizer)
会取消缩放由 optimizer
分配的参数所持有的梯度。如果您的模型或模型包含分配给另一个优化器(例如 optimizer2
)的其他参数,您可以单独调用 scaler.unscale_(optimizer2)
来取消缩放这些参数的梯度。
梯度裁剪¶
在裁剪之前调用 scaler.unscale_(optimizer)
使您能够像往常一样裁剪未缩放的梯度
scaler = GradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
with autocast(device_type='cuda', dtype=torch.float16):
output = model(input)
loss = loss_fn(output, target)
scaler.scale(loss).backward()
# Unscales the gradients of optimizer's assigned params in-place
scaler.unscale_(optimizer)
# Since the gradients of optimizer's assigned params are unscaled, clips as usual:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
# optimizer's gradients are already unscaled, so scaler.step does not unscale them,
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
scaler.step(optimizer)
# Updates the scale for next iteration.
scaler.update()
scaler
会记录本轮迭代中已经为此优化器调用了 scaler.unscale_(optimizer)
,因此 scaler.step(optimizer)
知道在(内部)调用 optimizer.step()
之前不需要重复取消缩放梯度。
警告
unscale_
应在每次 step
调用中针对每个优化器仅调用一次,并且仅在该优化器分配的参数的所有梯度累积完成后调用。step
调用之间为给定优化器调用 unscale_
两次将触发 RuntimeError。
处理缩放后的梯度¶
梯度累积¶
梯度累积将梯度累加到有效批次大小为 batch_per_iter * iters_to_accumulate
(如果是分布式则 * num_procs
)的批次上。缩放应针对有效批次进行校准,这意味着 inf/NaN 检查、如果找到 inf/NaN 梯度则跳过步骤,以及缩放更新应以有效批次粒度进行。此外,在给定有效批次的梯度累积期间,梯度应保持缩放状态,并且缩放因子应保持不变。如果在累积完成之前梯度被取消缩放(或缩放因子改变),则下一次反向传播会将缩放后的梯度加到未缩放的梯度(或按不同因子缩放的梯度)上,此后无法恢复 step
必须应用的累积的未缩放梯度。
因此,如果您想取消缩放梯度(例如,允许裁剪未缩放的梯度),请在 step
之前调用 unscale_
,在即将到来的 step
的所有(缩放后)梯度累积完成后进行。此外,仅在您为完整的有效批次调用了 step
的迭代结束时调用 update
scaler = GradScaler()
for epoch in epochs:
for i, (input, target) in enumerate(data):
with autocast(device_type='cuda', dtype=torch.float16):
output = model(input)
loss = loss_fn(output, target)
loss = loss / iters_to_accumulate
# Accumulates scaled gradients.
scaler.scale(loss).backward()
if (i + 1) % iters_to_accumulate == 0:
# may unscale_ here if desired (e.g., to allow clipping unscaled gradients)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
梯度惩罚¶
梯度惩罚实现通常使用 torch.autograd.grad()
创建梯度,将它们组合以创建惩罚值,并将惩罚值添加到损失中。
这里是一个没有梯度缩放或自动类型转换的普通 L2 惩罚示例
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
# Creates gradients
grad_params = torch.autograd.grad(outputs=loss,
inputs=model.parameters(),
create_graph=True)
# Computes the penalty term and adds it to the loss
grad_norm = 0
for grad in grad_params:
grad_norm += grad.pow(2).sum()
grad_norm = grad_norm.sqrt()
loss = loss + grad_norm
loss.backward()
# clip gradients here, if desired
optimizer.step()
要实现带有梯度缩放的梯度惩罚,传递给 torch.autograd.grad()
的 outputs
Tensor 应被缩放。因此,产生的梯度将被缩放,并且在组合以创建惩罚值之前应取消缩放。
此外,惩罚项计算是前向传播的一部分,因此应包含在 autocast
上下文中。
对于相同的 L2 惩罚,它看起来是这样的
scaler = GradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
with autocast(device_type='cuda', dtype=torch.float16):
output = model(input)
loss = loss_fn(output, target)
# Scales the loss for autograd.grad's backward pass, producing scaled_grad_params
scaled_grad_params = torch.autograd.grad(outputs=scaler.scale(loss),
inputs=model.parameters(),
create_graph=True)
# Creates unscaled grad_params before computing the penalty. scaled_grad_params are
# not owned by any optimizer, so ordinary division is used instead of scaler.unscale_:
inv_scale = 1./scaler.get_scale()
grad_params = [p * inv_scale for p in scaled_grad_params]
# Computes the penalty term and adds it to the loss
with autocast(device_type='cuda', dtype=torch.float16):
grad_norm = 0
for grad in grad_params:
grad_norm += grad.pow(2).sum()
grad_norm = grad_norm.sqrt()
loss = loss + grad_norm
# Applies scaling to the backward call as usual.
# Accumulates leaf gradients that are correctly scaled.
scaler.scale(loss).backward()
# may unscale_ here if desired (e.g., to allow clipping unscaled gradients)
# step() and update() proceed as usual.
scaler.step(optimizer)
scaler.update()
使用多个模型、损失和优化器¶
如果您的网络有多个损失,您必须单独对每个损失调用 scaler.scale
。如果您的网络有多个优化器,您可以单独对其中任何一个调用 scaler.unscale_
,并且必须单独对每个优化器调用 scaler.step
。
然而,scaler.update
应该只调用一次,在本次迭代中使用的所有优化器都已执行 step 后调用
scaler = torch.amp.GradScaler()
for epoch in epochs:
for input, target in data:
optimizer0.zero_grad()
optimizer1.zero_grad()
with autocast(device_type='cuda', dtype=torch.float16):
output0 = model0(input)
output1 = model1(input)
loss0 = loss_fn(2 * output0 + 3 * output1, target)
loss1 = loss_fn(3 * output0 - 5 * output1, target)
# (retain_graph here is unrelated to amp, it's present because in this
# example, both backward() calls share some sections of graph.)
scaler.scale(loss0).backward(retain_graph=True)
scaler.scale(loss1).backward()
# You can choose which optimizers receive explicit unscaling, if you
# want to inspect or modify the gradients of the params they own.
scaler.unscale_(optimizer0)
scaler.step(optimizer0)
scaler.step(optimizer1)
scaler.update()
每个优化器都会检查其梯度是否存在 inf/NaN,并独立决定是否跳过该步骤。这可能导致一个优化器跳过步骤而另一个不跳过。由于跳过步骤很少发生(每几百次迭代一次),这不应阻碍收敛。如果您在为多优化器模型添加梯度缩放后观察到收敛不良,请报告错误。
使用多个 GPU¶
此处描述的问题仅影响 autocast
。GradScaler
的用法保持不变。
单个进程中的 DataParallel¶
即使 torch.nn.DataParallel
会生成线程在每个设备上运行前向传播。autocast 状态会在每个线程中传播,以下代码将起作用
model = MyModel()
dp_model = nn.DataParallel(model)
# Sets autocast in the main thread
with autocast(device_type='cuda', dtype=torch.float16):
# dp_model's internal threads will autocast.
output = dp_model(input)
# loss_fn also autocast
loss = loss_fn(output)
DistributedDataParallel,每个进程一个 GPU¶
torch.nn.parallel.DistributedDataParallel
的文档建议每个进程使用一个 GPU 以获得最佳性能。在这种情况下,DistributedDataParallel
不会在内部生成线程,因此 autocast
和 GradScaler
的用法不受影响。
DistributedDataParallel,每个进程多个 GPU¶
在这种情况下,torch.nn.parallel.DistributedDataParallel
可能会生成一个子线程来在每个设备上运行前向传播,类似于 torch.nn.DataParallel
。解决方法是相同的:将 autocast 应用为模型的 forward
方法的一部分,以确保它在子线程中启用。
Autocast 和自定义 Autograd 函数¶
如果您的网络使用了自定义 autograd 函数(torch.autograd.Function
的子类),则如果任何函数满足以下条件,就需要进行更改以实现 autocast 兼容性
接受多个浮点 Tensor 输入,
包装任何可自动类型转换的操作(参见Autocast 操作参考),或
需要特定的
dtype
(例如,如果它包装了仅针对该dtype
编译的CUDA 扩展)。
在所有情况下,如果您正在导入函数且无法更改其定义,一个安全的备用方案是在发生错误的任何使用点禁用 autocast 并强制以 float32
(或 dtype
)执行
with autocast(device_type='cuda', dtype=torch.float16):
...
with autocast(device_type='cuda', dtype=torch.float16, enabled=False):
output = imported_function(input1.float(), input2.float())
如果您是该函数的作者(或可以更改其定义),一个更好的解决方案是使用 torch.amp.custom_fwd()
和 torch.amp.custom_bwd()
装饰器,如下面的相关示例所示。
具有多个输入或可自动类型转换操作的函数¶
分别将 custom_fwd
和 custom_bwd
(不带参数)应用于 forward
和 backward
。这确保 forward
在当前的 autocast 状态下执行,而 backward
在与 forward
相同的 autocast 状态下执行(这可以防止类型不匹配错误)
class MyMM(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, a, b):
ctx.save_for_backward(a, b)
return a.mm(b)
@staticmethod
@custom_bwd
def backward(ctx, grad):
a, b = ctx.saved_tensors
return grad.mm(b.t()), a.t().mm(grad)
现在可以在任何地方调用 MyMM
,无需禁用 autocast 或手动转换输入
mymm = MyMM.apply
with autocast(device_type='cuda', dtype=torch.float16):
output = mymm(input1, input2)
需要特定 dtype
的函数¶
考虑一个需要 torch.float32
输入的自定义函数。将 custom_fwd(device_type='cuda', cast_inputs=torch.float32)
应用于 forward
,将 custom_bwd(device_type='cuda')
应用于 backward
。如果 forward
在启用了 autocast 的区域运行,装饰器会将浮点 Tensor 输入转换为由参数 device_type(本例中为 CUDA)指定的设备上的 float32
,并在 forward
和 backward
期间本地禁用 autocast
class MyFloat32Func(torch.autograd.Function):
@staticmethod
@custom_fwd(device_type='cuda', cast_inputs=torch.float32)
def forward(ctx, input):
ctx.save_for_backward(input)
...
return fwd_output
@staticmethod
@custom_bwd(device_type='cuda')
def backward(ctx, grad):
...
现在可以在任何地方调用 MyFloat32Func
,无需手动禁用 autocast 或转换输入
func = MyFloat32Func.apply
with autocast(device_type='cuda', dtype=torch.float16):
# func will run in float32, regardless of the surrounding autocast state
output = func(input)