torch.utils.module_tracker¶
此工具可用于跟踪当前在 torch.nn.Module
层次结构中的位置。它可与其他跟踪工具结合使用,以便轻松地将测量量与用户友好的名称关联起来。目前,它特别用于 FlopCounterMode 中。
- class torch.utils.module_tracker.ModuleTracker[source][source]¶
ModuleTracker
是一个上下文管理器,用于在执行期间跟踪 nn.Module 层次结构,以便其他系统可以查询当前正在执行哪个 Module(或其反向传播)。您可以通过访问此上下文管理器的
parents
属性来获取当前正在执行的所有 Module 的集合,它们通过 fqn(完全限定名,也用作 state_dict 中的键)标识。您可以访问is_bw
属性来了解当前是否正在运行反向传播。注意,
parents
永远不为空,并且始终包含“Global”键。is_bw
标志在正向传播后将保持True
,直到执行另一个 Module。如果您需要它更精确,请提交一个 issue 来请求此功能。添加从 fqn 到模块实例的映射是可能的,但尚未完成,如果您需要此功能,请提交一个 issue 来请求。示例用法
mod = torch.nn.Linear(2, 2) with ModuleTracker() as tracker: # Access anything during the forward pass def my_linear(m1, m2, bias): print(f"Current modules: {tracker.parents}") return torch.mm(m1, m2.t()) + bias torch.nn.functional.linear = my_linear mod(torch.rand(2, 2))