(beta) Running the compiled optimizer with an LR Scheduler
Created On: May 21, 2024 | Last Updated: May 21, 2024 | Last Verified: Nov 05, 2024
Author: Michael Lazos
The optimizer is a key algorithm for training any deep learning model. In this example, we will show how to pair an optimizer that has been compiled with torch.compile
with an LR scheduler to accelerate training convergence.
Note
This tutorial requires PyTorch 2.3.0 or later.
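If you want to confirm this requirement at runtime, a simple version check works; the snippet below is a minimal sketch and is not part of the original example.
import torch
# Minimal sketch (not in the original example): exit early if the installed
# PyTorch is older than 2.3.0.
major, minor = (int(part) for part in torch.__version__.split(".")[:2])
if (major, minor) < (2, 3):
    raise RuntimeError("This tutorial requires PyTorch 2.3.0 or later.")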
Model Setup
For this example, we'll use a simple sequence of linear layers.
import torch
# Create simple model
model = torch.nn.Sequential(
*[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)]
)
input = torch.rand(1024, device="cuda")
# run forward pass
output = model(input)
# run backward to populate the grads for our optimizer below
output.sum().backward()
Setting up and running the compiled optimizer with an LR Scheduler
In this section, we'll use the Adam optimizer with a LinearLR scheduler and create a helper function that wraps the step()
call of each of them in torch.compile()
.
Note
torch.compile
is only supported on CUDA devices that have a compute capability of 7.0 or higher.
# exit cleanly if we are on a device that doesn't support ``torch.compile``
if torch.cuda.get_device_capability() < (7, 0):
print("Exiting because torch.compile is not supported on this device.")
import sys
sys.exit(0)
# !!! IMPORTANT !!! Wrap the lr in a Tensor if we are pairing the
# optimizer with an LR Scheduler.
# Without this, torch.compile will recompile as the value of the LR
# changes.
opt = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.01))
sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)
@torch.compile(fullgraph=False)
def fn():
opt.step()
sched.step()
# Warmup runs to compile the function
for _ in range(5):
fn()
print(opt.param_groups[0]["lr"])
tensor(0.0047)
tensor(0.0060)
tensor(0.0073)
tensor(0.0087)
tensor(0.0100)
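Why does a tensor LR avoid recompiles? The LR stored in the optimizer's param group remains a tensor, and the scheduler updates its value in place, so the compiled function keeps seeing the same object. The check below is a small sketch to illustrate this and is not part of the original tutorial.
# Sketch (not in the original tutorial): the LR in the param group is the
# tensor we passed in; the scheduler writes new values into it instead of
# replacing it with a Python float, so torch.compile does not recompile.
lr_ref = opt.param_groups[0]["lr"]
print(type(lr_ref))   # <class 'torch.Tensor'>
print(lr_ref)         # tensor(0.0100) after the warmup iterations above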
Extension: What happens with a non-tensor LR?
For the curious, we will show how to peek at what happens with torch.compile
when we don't wrap the LR in a tensor.
# No longer wrap the LR in a tensor here
opt = torch.optim.Adam(model.parameters(), lr=0.01)
sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)
@torch.compile(fullgraph=False)
def fn():
opt.step()
sched.step()
# Setup logging to view recompiles
torch._logging.set_logs(recompiles=True)
# Warmup runs to compile the function
# We will now recompile on each iteration
# as the value of the lr is mutated.
for _ in range(5):
fn()
V0423 16:41:13.410000 660 torch/_dynamo/guards.py:2997] [33/1] [__recompiles] Recompiling function wrapper in /var/lib/ci-user/.local/lib/python3.10/site-packages/torch/optim/optimizer.py:465
V0423 16:41:13.410000 660 torch/_dynamo/guards.py:2997] [33/1] [__recompiles] triggered by the following guard failure(s):
V0423 16:41:13.410000 660 torch/_dynamo/guards.py:2997] [33/1] [__recompiles] - 33/0: Cache line invalidated because L['args'][0] got deallocated
V0423 16:41:13.424000 660 torch/_dynamo/guards.py:2997] [34/1] [__recompiles] Recompiling function step in /var/lib/ci-user/.local/lib/python3.10/site-packages/torch/optim/adam.py:212
V0423 16:41:13.424000 660 torch/_dynamo/guards.py:2997] [34/1] [__recompiles] triggered by the following guard failure(s):
V0423 16:41:13.424000 660 torch/_dynamo/guards.py:2997] [34/1] [__recompiles] - 34/0: Cache line invalidated because L['self'] got deallocated
V0423 16:41:16.228000 660 torch/_dynamo/guards.py:2997] [34/2] [__recompiles] Recompiling function step in /var/lib/ci-user/.local/lib/python3.10/site-packages/torch/optim/adam.py:212
V0423 16:41:16.228000 660 torch/_dynamo/guards.py:2997] [34/2] [__recompiles] triggered by the following guard failure(s):
V0423 16:41:16.228000 660 torch/_dynamo/guards.py:2997] [34/2] [__recompiles] - 34/1: ___as_tensor(self.param_groups[0]['lr']).item() == 0.003333333333333333 # (unknown source ___as_tensor(self.param_groups[0]['lr']).item(), please file a bug)
V0423 16:41:16.228000 660 torch/_dynamo/guards.py:2997] [34/2] [__recompiles] - 34/0: Cache line invalidated because L['self'] got deallocated
V0423 16:41:18.204000 660 torch/_dynamo/guards.py:2997] [34/3] [__recompiles] Recompiling function step in /var/lib/ci-user/.local/lib/python3.10/site-packages/torch/optim/adam.py:212
V0423 16:41:18.204000 660 torch/_dynamo/guards.py:2997] [34/3] [__recompiles] triggered by the following guard failure(s):
V0423 16:41:18.204000 660 torch/_dynamo/guards.py:2997] [34/3] [__recompiles] - 34/2: ___as_tensor(self.param_groups[0]['lr']).item() == 0.004666666666666667 # (unknown source ___as_tensor(self.param_groups[0]['lr']).item(), please file a bug)
V0423 16:41:18.204000 660 torch/_dynamo/guards.py:2997] [34/3] [__recompiles] - 34/1: ___as_tensor(self.param_groups[0]['lr']).item() == 0.003333333333333333 # (unknown source ___as_tensor(self.param_groups[0]['lr']).item(), please file a bug)
V0423 16:41:18.204000 660 torch/_dynamo/guards.py:2997] [34/3] [__recompiles] - 34/0: Cache line invalidated because L['self'] got deallocated
V0423 16:41:20.190000 660 torch/_dynamo/guards.py:2997] [34/4] [__recompiles] Recompiling function step in /var/lib/ci-user/.local/lib/python3.10/site-packages/torch/optim/adam.py:212
V0423 16:41:20.190000 660 torch/_dynamo/guards.py:2997] [34/4] [__recompiles] triggered by the following guard failure(s):
V0423 16:41:20.190000 660 torch/_dynamo/guards.py:2997] [34/4] [__recompiles] - 34/3: ___as_tensor(self.param_groups[0]['lr']).item() == 0.006000000000000001 # (unknown source ___as_tensor(self.param_groups[0]['lr']).item(), please file a bug)
V0423 16:41:20.190000 660 torch/_dynamo/guards.py:2997] [34/4] [__recompiles] - 34/2: ___as_tensor(self.param_groups[0]['lr']).item() == 0.004666666666666667 # (unknown source ___as_tensor(self.param_groups[0]['lr']).item(), please file a bug)
V0423 16:41:20.190000 660 torch/_dynamo/guards.py:2997] [34/4] [__recompiles] - 34/1: ___as_tensor(self.param_groups[0]['lr']).item() == 0.003333333333333333 # (unknown source ___as_tensor(self.param_groups[0]['lr']).item(), please file a bug)
V0423 16:41:20.190000 660 torch/_dynamo/guards.py:2997] [34/4] [__recompiles] - 34/0: Cache line invalidated because L['self'] got deallocated
V0423 16:41:22.170000 660 torch/_dynamo/guards.py:2997] [34/5] [__recompiles] Recompiling function step in /var/lib/ci-user/.local/lib/python3.10/site-packages/torch/optim/adam.py:212
V0423 16:41:22.170000 660 torch/_dynamo/guards.py:2997] [34/5] [__recompiles] triggered by the following guard failure(s):
V0423 16:41:22.170000 660 torch/_dynamo/guards.py:2997] [34/5] [__recompiles] - 34/4: ___as_tensor(self.param_groups[0]['lr']).item() == 0.007333333333333335 # (unknown source ___as_tensor(self.param_groups[0]['lr']).item(), please file a bug)
V0423 16:41:22.170000 660 torch/_dynamo/guards.py:2997] [34/5] [__recompiles] - 34/3: ___as_tensor(self.param_groups[0]['lr']).item() == 0.006000000000000001 # (unknown source ___as_tensor(self.param_groups[0]['lr']).item(), please file a bug)
V0423 16:41:22.170000 660 torch/_dynamo/guards.py:2997] [34/5] [__recompiles] - 34/2: ___as_tensor(self.param_groups[0]['lr']).item() == 0.004666666666666667 # (unknown source ___as_tensor(self.param_groups[0]['lr']).item(), please file a bug)
V0423 16:41:22.170000 660 torch/_dynamo/guards.py:2997] [34/5] [__recompiles] - 34/1: ___as_tensor(self.param_groups[0]['lr']).item() == 0.003333333333333333 # (unknown source ___as_tensor(self.param_groups[0]['lr']).item(), please file a bug)
V0423 16:41:22.170000 660 torch/_dynamo/guards.py:2997] [34/5] [__recompiles] - 34/0: Cache line invalidated because L['self'] got deallocated
With this example, we can see that we recompile the optimizer several times
due to the guard failure on the lr
in param_groups[0].
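As a point of comparison, the sketch below (not part of the original example) re-wraps the LR in a tensor, which removes the guard failures shown above; the warmup loop then compiles once and reuses the same graph.
# Sketch (not in the original example): wrapping the LR in a tensor again
# avoids the per-iteration recompiles logged above.
torch._logging.set_logs(recompiles=False)  # turn the recompile logging back off
opt = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.01))
sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)
@torch.compile(fullgraph=False)
def fn_tensor_lr():
    opt.step()
    sched.step()
for _ in range(5):
    fn_tensor_lr()  # compiled once; subsequent calls reuse the compiled graph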
Conclusion
In this tutorial we showed how to pair an optimizer compiled with torch.compile
with an LR scheduler to accelerate training convergence. We used a model consisting of a simple sequence of linear layers with the Adam optimizer paired with a LinearLR scheduler to demonstrate the LR changing across iterations.
See also
Compiled optimizer tutorial - an intro into the compiled optimizer.
Compiling the optimizer with PT2 - deeper technical details on the compiled optimizer.
Total running time of the script: (0 minutes 13.342 seconds)