PyTorch 2.0 NNModule 支持¶
作者: Will Constable
torch.compile 对 torch.nn.Module 对象有特殊的处理,与跟踪任意 python 类的方式不同,旨在通过对结构进行假设来生成更快的代码。
本文档介绍了一些由于这种特殊化而出现的权衡或边缘情况。
NNModule 钩子支持¶
以前,torch.compile 不支持 nn.Module 上的钩子,如果注册了钩子,它们在编译后的程序中会被直接忽略。实际上,许多用户根本不使用 nn.Module 钩子,或者仅将其用于调试工作流程,但将 nn.Module 钩子与 torch.compile 结合使用存在有效的用例。
通过 nn.Module.__call__ 实现编排的钩子包括 _forward_pre_hooks、forward_hooks、_backward_pre_hooks 和 _backward_hooks,并将被称为“调用钩子”。torch.compile 部分支持这些钩子,但存在以下描述的限制。
另一类钩子包括 _state_dict_hooks 及其 pre 和 load_ 变体,torch.compile 仍然不支持这些钩子。
nn.Module.__call__ 钩子的用法和限制¶
默认情况下,torch.compile 将跟踪 nn.Module.__call__ 的内容,这意味着它将遇到并运行前向/预前向钩子。如果您在调用 torch.compile 之前安装了钩子,然后在之后不删除或更改钩子,则默认情况下应支持您的用例。
通常也支持后向/预后向钩子,但具有类似的注意事项:目前,当访问 backward_hooks 字典时,dynamo 中会发生图中断,这可能可以通过一些工作来避免。图中断也会影响触发后向钩子的时序,因为图段作为 autograd 函数运行,这些函数同时生成所有梯度。假设 dynamo 可能不会因为存在后向钩子而发生图中断,我们仍然预计一系列模块的后向钩子将在整个编译图的后向运行后一起触发。
“允许模块”上的钩子 torch.compile 特别对待常见模块(如 torch.conv)以及难以跟踪的模块,方法是允许它们在 dynamo 图中不透明地调用,而不是被 dynamo 跟踪到其中。对于此类模块,钩子当前会触发图中断,以便受影响的模块在 dynamo 之外运行。根据模型的不同,这可能会导致明显的性能下降,需要进行额外的工作来改进此支持。
skip_nnmodule_hook_guards 默认情况下,torch._dynamo.config.skip_nnmodule_hook_guards 设置为 True,这意味着不会在每个 nn.Module 钩子字典上安装保护,从而通过减少保护执行时间来提高运行时性能,但代价是如果编译后任何钩子字典被更改,则不会注意到。
如果您希望能够在编译后删除或修改钩子,并让 torch.compile 做出适当的反应(通过重新编译),则需要设置 skip_nnmodule_hook_guards=False,并预期添加保护会带来运行时性能损失。
TODO:确认后向/pre_backward 钩子是否正常工作,并据此记录
state_dict 钩子¶
torch.compile 尚不支持 State dict 钩子。
TODO:如果钩子导致图中断,则 warn_once。如果存在钩子,则 warn_once 指向此文档。