PyTorch 2.0 NNModule 支持¶
作者: Will Constable
torch.compile 对 torch.nn.Module 对象有特殊处理,与它跟踪任意 Python 类的方式不同,它跟踪这些对象,目的是通过对结构进行假设来生成更快的代码。
本文档描述了由于这种专门化而出现的一些权衡或边缘情况。
NNModule 钩子支持¶
以前,torch.compile 不支持 nn.Module 上的钩子,如果注册了钩子,它们将在编译后的程序中被忽略。实际上,许多用户根本不使用 nn.Module 钩子,或者只在调试工作流程中使用它们,但将 nn.Module 钩子与 torch.compile 组合起来是有效的用例。
通过 nn.Module.__call__ 实现来协调的钩子包括 _forward_pre_hooks、forward_hooks、_backward_pre_hooks 和 _backward_hooks,并将被称为“调用钩子”。这些钩子部分受 torch.compile 支持,但存在以下限制。
另一类钩子包括 _state_dict_hooks 及其 pre 和 load_ 变体,这些钩子不受 torch.compile 支持。
nn.Module.__call__ 钩子使用和限制¶
默认情况下,torch.compile 将跟踪 nn.Module.__call__ 的内容,这意味着它会遇到并运行 forward/pre-forward 钩子。如果您在调用 torch.compile 之前安装钩子,并且之后没有删除或更改钩子,那么您的用例应默认受支持。
Backward/Pre-backward 钩子通常也受支持,但存在类似的注意事项:目前,当访问 backward_hooks 字典时,Dynamo 中会发生图断开,这可能可以通过一些工作来避免。图断开也会影响 backward 钩子的触发时间,因为图段作为 autograd 函数运行,这些函数同时生成所有梯度。假设 Dynamo 可以不因 backward 钩子的存在而断开图,我们仍然预计一系列模块的 backward 钩子将在整个编译后的图的 backward 运行结束后一起触发。
“允许的模块”上的钩子 torch.compile 通过允许以不透明的方式调用它们(而不是被 Dynamo 跟踪)来特殊对待常见的模块,例如 torch.conv,以及难以跟踪的模块。对于此类模块,钩子当前会触发图断开,以便受影响的模块在 Dynamo 之外运行。根据模型的不同,这可能会导致性能大幅下降,需要额外的工作才能改进此支持。
skip_nnmodule_hook_guards 默认情况下,torch._dynamo.config.skip_nnmodule_hook_guards 设置为 True,这意味着不会在每个 nn.Module 钩子字典上安装保护措施,通过减少保护执行时间来提高运行时效率,但代价是不能注意到编译后是否更改了任何钩子字典。
如果您想能够在编译后删除或修改钩子,并让 torch.compile 适当地做出反应(通过重新编译),那么您需要设置 skip_nnmodule_hook_guards=False,并预计由于添加了保护措施而导致运行时性能下降。
待办事项:确认 backward/pre_backward 钩子是否正常工作,并相应地记录。