PyTorch 2.0 NNModule 支持¶
**作者**:Will Constable
torch.compile 对 torch.nn.Module 对象有特殊处理,对其进行跟踪的方式与跟踪任意 Python 类不同,目的是通过对结构进行假设来生成更快的代码。
本文档描述了由于这种专门化而出现的一些权衡或边缘情况。
NNModule 钩子支持¶
以前,torch.compile 不支持 nn.Module 上的钩子,如果注册了钩子,它们将在编译后的程序中被简单地忽略。实际上,许多用户根本不使用 nn.Module 钩子,或者仅在调试工作流程中使用它们,但将 nn.Module 钩子与 torch.compile 组合使用也存在有效的用例。
通过 nn.Module.__call__ 实现协调的钩子包括 _forward_pre_hooks、forward_hooks、_backward_pre_hooks 和 _backward_hooks,并将被称为“调用钩子”。这些钩子由 torch.compile 部分支持,但存在以下限制。
另一类钩子包括 _state_dict_hooks 及其 pre 和 load_ 变体,并且仍不受 torch.compile 支持。
nn.Module.__call__ 钩子用法和限制¶
默认情况下,torch.compile 将跟踪 nn.Module.__call__ 的内容,这意味着它将遇到并运行前向/前向预钩子。如果您在调用 torch.compile 之前安装钩子,然后以后不删除或更改钩子,则您的用例应该默认受支持。
反向/反向预钩子通常也受支持,存在类似的注意事项:目前,当访问 backward_hooks 字典时,dynamo 中会发生图中断,这可能通过一些工作来避免。图中断也会影响反向钩子触发的时机,因为图段作为 autograd 函数运行,这些函数同时生成所有梯度。假设 dynamo 可以不在 backward-hooks 存在的情况下发生图中断,我们仍然预计一系列模块的反向钩子将在整个编译图的反向运行结束后一起触发。
“允许的模块”上的钩子 torch.compile 将处理常见的模块,如 torch.conv,以及难以跟踪的模块,特别是通过允许它们在 dynamo 图中不透明地调用而不是被 dynamo 跟踪。对于此类模块,钩子目前会触发图中断,以便受影响的模块在 dynamo 外运行。根据模型的不同,这可能会导致性能显着下降,需要进一步的工作来改进此支持。
skip_nnmodule_hook_guards 默认情况下,torch._dynamo.config.skip_nnmodule_hook_guards 设置为 True,这意味着不会在每个 nn.Module 钩子字典上安装防护措施,从而通过减少防护措施执行时间来提高运行时性能,但代价是无法注意到编译后是否有任何钩子字典发生更改。
如果您希望能够在编译后删除或修改钩子并让 torch.compile 做出相应的反应(通过重新编译),则需要将 skip_nnmodule_hook_guards=False 并预期添加的防护措施会带来运行时性能损失。
待办事项:确认反向/反向预钩子是否正常工作并相应地记录