带有 TorchScript 支持的分布式优化器¶
创建于:2021 年 4 月 26 日 | 最后更新:2024 年 12 月 02 日 | 最后验证:2024 年 11 月 05 日
警告
TorchScript 已不再积极开发。
在本食谱中,您将学习
带有 TorchScript 支持的分布式优化器的高级思想以及此功能带来的好处
如何编写自定义的分布式优化器,使其支持 TorchScript
要求¶
PyTorch 1.8+
什么是分布式优化器?¶
DistributedOptimizer 接受远程参数 (RRef) 列表,并在参数所在的 worker 上本地运行优化器,这通常与分布式 RPC/Autograd 一起用于进行模型并行训练。它可以使用任何本地优化器算法(torch.optim
中提供的预定义算法或自定义算法)来应用每个 worker 上的梯度。
什么是带有 TorchScript 支持的分布式优化器?¶
分布式优化器广泛用于分布式模型并行训练,在某些常见用例中,出于性能考虑和资源利用率(或至少部分多线程,即参数服务器托管部分模型和参数,并且新线程根据请求更新参数),训练需要在多线程方式下完成,而不是多进程方式下完成。PyTorch 本身并不原生支持多线程训练,因为它受到 Python 的全局解释器锁 (GIL) 的限制,但它可以利用 TorchScript 来摆脱 GIL,并以多线程方式运行模型。
对于关键的模型训练工作负载,提高训练性能是一个重要的课题。研究人员通常希望通过图表示(即通过运算符融合)实现不同的优化策略,或者实现自定义运算符内核,以加快训练速度。
带有 TorchScript 支持的分布式优化器可以帮助摆脱 GIL,从而提高 PyTorch 在多线程环境中的训练性能,它还释放了通过使用 TorchScript 提供的先进编译器技术(即 CPU/GPU 融合)进一步增强性能的潜力。
如何编写带有 TorchScript 支持的自定义分布式优化器?¶
下面的代码展示了如何在现有本地优化器实现的基础上编写自定义的分布式优化器,从而解锁 TorchScript 的优势,包括消除 GIL 和提高性能的机会。
假设您已经有一个本地优化器,目前在训练期间使用。在本例中,我们将使用 准双曲动量 (QHM) 作为示例,展示如何启用 TorchScript 支持,请注意,它也适用于从 torch.optim.Optimizer
继承的任何自定义优化器。
首先,我们需要从优化器实现中分离出计算和状态管理,这样我们就可以提取计算部分并使其成为一个自由函数,这对 TorchScript 友好。它有两个好处:1. 计算逻辑变得更容易检查,它允许我们快速将参数更新/计算部分转换为 TorchScript,并利用 TorchScript IR 进行进一步优化(运算符融合等)。2. 分布式优化器底层使用不同的机制来获取梯度和更新参数(我们单独存储梯度,而不是在反向传播期间直接填充 param.grad
字段)。分离计算使分布式优化器能够以多线程模式进行优化器更新,因为它消除了 param.grad
的可能竞争条件。
import torch
from torch import Tensor
from typing import List
def qhm_update(params: List[Tensor],
dp_list: List[Tensor],
momentum_buffer_list: List[Tensor],
lr: float,
nu: float,
weight_decay: float,
weight_decay_type: str,
momentum: float):
for p, d_p, momentum_buffer in zip(params, dp_list, momentum_buffer_list):
if weight_decay != 0:
if weight_decay_type == "grad":
d_p.add_(weight_decay, p)
elif weight_decay_type == "direct":
p.mul_(1.0 - lr * weight_decay)
else:
raise ValueError("Invalid weight decay type provided")
momentum_buffer.mul_(momentum).add_(1.0 - momentum, d_p)
p.data.add_(-lr * nu, momentum_buffer)
p.data.add_(-lr * (1.0 - nu), d_p)
接下来,我们将定义一个具有 TorchScript 兼容性的分布式函数优化器,以管理优化器状态并调用我们上面定义的 TorchScript 兼容更新函数。请注意,与普通自定义优化器相比,有几个约定不同:1. 我们不继承 torch.optim.Optimizer
,因为 TorchScript 不支持多态性 2. step
接受梯度列表而不是损失闭包。
import torch
from torch import Tensor
from typing import List, Optional, Dict
# define this as a TorchScript class
@torch.jit.script
class FunctionalQHM(object):
def __init__(self,
params: List[Tensor],
lr: float,
momentum: float,
nu: float,
weight_decay: float = 0.0,
weight_decay_type: str = "grad"):
if lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if momentum < 0.0:
raise ValueError("Invalid momentum value: {}".format(momentum))
if weight_decay < 0.0:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
if weight_decay_type not in ("grad", "direct"):
raise ValueError("Invalid weight_decay_type value: {}".format(weight_decay_type))
self.defaults = {
"lr": lr,
"momentum": momentum,
"nu": nu,
"weight_decay": weight_decay,
}
self.weight_decay_type = weight_decay_type
# NOTE: we only have one param_group here and don't allow user to add additional
# param group as it's not a common use case.
self.param_group = {"params": params}
self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
def step(self, gradients: List[Optional[Tensor]]):
params = self.param_group['params']
params_with_grad = []
grads = []
momentum_buffer_list: List[Tensor] = []
if len(params) != len(gradients):
raise ValueError(
"the gradients passed in does not equal to the size of the parameters!"
+ f"Params length: {len(params)}. "
+ f"Gradients length: {len(gradients)}"
)
for param, gradient in zip(self.param_group['params'], gradients):
if gradient is not None:
params_with_grad.append(param)
grads.append(gradient)
state = self.state[param]
state['momentum_buffer'] = torch.zeros_like(param, memory_format=torch.preserve_format)
momentum_buffer_list.append(state['momentum_buffer'])
# calls into the update function we just defined
with torch.no_grad():
qhm_update(params_with_grad,
grads,
momentum_buffer_list,
self.defaults['lr'],
self.defaults['nu'],
self.defaults['weight_decay'],
self.weight_decay_type,
self.defaults['momentum'])
最后,我们将新定义的分布式函数优化器注册到 functional_optim_map
中。这样,DistributedOptimizer
将尝试拾取我们的自定义实现,而不是预定义的默认实现。
from torch.distributed.optim import DistributedOptimizer
DistributedOptimizer.functional_optim_map[QHM] = FunctionalQHM
现在,您可以通过将 QHM
优化器传递给 DistributedOptimizer,像往常一样在分布式训练中使用它
...
remote_params_list = [...]
dist_optim = DistributedOptimizer(
QHM, remote_params_list, *args, **kwargs
)
DistributedOptimizer 将在后台自动将 QHM 优化器转换为 FunctionalQHM
,并启用 TorchScript 支持。这将释放多线程训练带来的性能提升,并为进一步改进提供更多潜力(即 TorchScript 融合等)。
请注意,大多数 PyTorch 内置优化器已经使用这种方法来加速分布式训练。如果您看到关于某些优化器尚未转换的警告,您可以按照此食谱编写自己的转换。