torch.autograd.graph.increment_version¶
- torch.autograd.graph.increment_version(tensor)[源代码]¶
更新 autograd 元数据以跟踪给定张量是否已就地修改。
这样做是为了在 autograd 引擎中启用更准确的错误检查。它已由 PyTorch 函数以及在 custom Function 中调用 mark_dirty() 时自动完成,因此您只需要在以 PyTorch 不知道的方式对张量数据进行就地操作时才需要显式调用它。例如,自定义内核会读取张量 data_ptr 并根据该指针就地修改内存。可以接受张量或张量列表。
请注意,对单个就地操作多次递增版本计数器没有问题。
请注意,如果您传入在 torch.inference_mode() 下构建的张量,我们将不会增加其版本计数器(因为您的张量没有版本计数器)。