快捷方式

PyTorch 2.0 NNModule 支持

作者: Will Constable

torch.compile 对 torch.nn.Module 对象有特殊处理,对其追踪方式与追踪任意 Python 类不同,目的是通过对结构进行假设来生成更快的代码。

本文档描述了这种特殊处理所带来的一些权衡和边缘情况。

NNModule 钩子支持

以前,torch.compile 不支持 nn.Module 上的钩子,如果注册了钩子,它们在编译后的程序中会被直接忽略。事实上,许多用户根本不使用 nn.Module 钩子,或仅将其用于调试工作流程,但将 nn.Module 钩子与 torch.compile 结合使用存在有效的使用场景。

通过 nn.Module.__call__ 实现编排的钩子包括 _forward_pre_hooksforward_hooks_backward_pre_hooks_backward_hooks,这些钩子将被称为“调用钩子”。torch.compile 部分支持这些钩子,但存在下述限制。

另一类钩子包括 _state_dict_hooks 及其 preload_ 变体,这些钩子仍不受 torch.compile 支持。

nn.Module.__call__ 钩子使用及限制

默认情况下,torch.compile 会追踪 nn.Module.__call__ 的内容,这意味着它会遇到并运行前向/前向预钩子。如果您在调用 torch.compile 之前安装钩子,并且之后不移除或更改这些钩子,则您的使用场景默认应受支持。

反向/反向预钩子通常也受支持,但也存在类似注意事项:当前在 Dynamo 中访问 backward_hooks 字典时会发生图中断 (graph-breaks),这可能通过一些工作可以避免。图中断还会影响反向钩子的触发时机,因为图段会作为自动微分函数 (autograd-functions) 运行,它们同时产生所有梯度 (grads)。假设 Dynamo 可以在存在反向钩子的情况下不发生图中断,我们仍然期望一系列模块的反向钩子在整个编译图的反向传播运行后一起触发。

“允许模块”(allowed modules) 上的钩子 torch.compile 会特殊处理常见模块(例如 torch.conv)以及难以追踪的模块,方法是允许它们在 Dynamo 图中以不透明方式调用,而不是被 Dynamo 追踪进入。对于此类模块,钩子目前会触发图中断 (graph-break),导致受影响的模块在 Dynamo 外部运行。根据模型不同,这可能会导致显著的性能下降 (performance regression),需要额外工作来改进此支持。

skip_nnmodule_hook_guards 默认情况下,torch._dynamo.config.skip_nnmodule_hook_guards 设置为 True,这意味着不会在每个 nn.Module 钩子字典上安装守卫 (guards),从而通过减少守卫执行时间来提高运行时性能,代价是编译后如果任何钩子字典发生更改,将不会被注意到。

如果您希望在编译后能够移除或修改钩子,并让 torch.compile 做出适当反应(通过重新编译),则需要将 skip_nnmodule_hook_guards=False,并预计因添加守卫而产生的运行时开销 (runtime penalty)。

TODO: 确认反向/反向预钩子是否工作,并相应地更新文档

state_dict 钩子

torch.compile 尚不支持 state dict 钩子。

TODO: 如果在钩子上发生图中断,则警告一次。如果存在钩子,则警告一次并指向本文档。

文档

查阅 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源