快捷方式

PyTorch 2.0 NNModule 支持

作者: Will Constable

torch.compile 对 torch.nn.Module 对象有特殊处理,以不同于跟踪任意 Python 类的方式跟踪它们,目的是通过对结构进行假设来生成更快的代码。

本文档描述了由于这种专门化而出现的一些权衡或边缘情况。

NNModule 钩子支持

以前,torch.compile 不支持 nn.Modules 上的钩子,如果注册了钩子,它们将在编译程序中被忽略。事实上,许多用户根本不使用 nn.Module 钩子,或者只在调试工作流中使用它们,但将 nn.Module 钩子与 torch.compile 组合起来有有效的用例。

通过 nn.Module.__call__ 实现编排的钩子包括 _forward_pre_hooksforward_hooks_backward_pre_hooks_backward_hooks,并将被称为“调用钩子”。这些钩子部分支持 torch.compile,但存在以下限制。

另一类钩子包括 _state_dict_hooks 及其 preload_ 变体,目前 torch.compile 还不支持。

nn.Module.__call__ 钩子使用和限制

默认情况下,torch.compile 会跟踪 nn.Module.__call__ 的内容,这意味着它会遇到并运行前向/前前向钩子。如果您在调用 torch.compile 之前安装了钩子,并且之后没有删除或更改钩子,那么您的用例应该默认情况下得到支持。

反向/前反向钩子通常也受支持,但存在类似的注意事项:目前,dynamo 在访问 backward_hooks 字典时会发生图中断,这可能可以通过一些工作来避免。图中断也会影响反向钩子的触发时间,因为图段作为自动微分函数运行,这些函数会同时生成所有梯度。假设 dynamo 可以不因反向钩子的存在而发生图中断,我们仍然预计一系列模块的反向钩子将在整个编译图的反向运行完成后一起触发。

“允许的模块”上的钩子 torch.compile 对常见的模块(如 torch.conv)以及难以跟踪的模块进行特殊处理,允许它们在 dynamo 图中不透明地调用,而不是被 dynamo 跟踪。对于此类模块,钩子目前会触发图中断,以便受影响的模块在 dynamo 之外运行。根据模型的不同,这可能会导致性能显著下降,需要进一步的工作来改进这种支持。

skip_nnmodule_hook_guards 默认情况下,torch._dynamo.config.skip_nnmodule_hook_guards 设置为 True,这意味着不会在每个 nn.Module 钩子字典上安装保护,通过减少保护执行时间来提高运行时,但代价是无法注意到编译后是否有任何钩子字典被更改。

如果您希望能够在编译后删除或修改钩子,并让 torch.compile 适当地做出反应(通过重新编译),那么您需要将 skip_nnmodule_hook_guards=False,并预期由于添加了保护而导致运行时性能下降。

TODO:确认反向/前反向钩子是否有效,并相应地记录

状态字典钩子

torch.compile 尚未支持状态字典钩子。

TODO:如果在钩子上发生图中断,则发出一次警告。如果存在钩子,则发出一次警告以指向此文档。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源