快捷方式

torch.autograd.function.FunctionCtx.mark_non_differentiable

FunctionCtx.mark_non_differentiable(*args)[source][source]

将输出标记为不可微分。

这个方法最多只能调用一次,可以在 setup_context()forward() 方法中调用,并且所有参数都应该是 tensor 输出。

这将把输出标记为不需要梯度,从而提高反向计算的效率。你仍然需要在 backward() 方法中为每个输出接受一个梯度,但这始终是一个零张量,其形状与对应输出的形状相同。

这用于例如从排序返回的索引。参见示例:
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         sorted, idx = x.sort()
>>>         ctx.mark_non_differentiable(idx)
>>>         ctx.save_for_backward(x, idx)
>>>         return sorted, idx
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):  # still need to accept g2
>>>         x, idx = ctx.saved_tensors
>>>         grad_input = torch.zeros_like(x)
>>>         grad_input.index_add_(0, idx, g1)
>>>         return grad_input

文档

查阅 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源