torch.autograd.gradcheck.gradcheck¶
- torch.autograd.gradcheck.gradcheck(func, inputs, *, eps=1e-06, atol=1e-05, rtol=0.001, raise_exception=True, nondet_tol=0.0, check_undefined_grad=True, check_grad_dtypes=False, check_batched_grad=False, check_batched_forward_grad=False, check_forward_ad=False, check_backward_ad=True, fast_mode=False, masked=None)[source][source]¶
针对浮点型或复数类型且
requires_grad=True
的inputs
张量,检查通过小有限差分计算的梯度与解析梯度是否一致。数值梯度和解析梯度之间的检查使用
allclose()
。对于我们出于优化目的考虑的大多数复杂函数,不存在雅可比矩阵的概念。相反,gradcheck 验证 Wirtinger 导数和共轭 Wirtinger 导数的数值和解析值是否一致。由于梯度计算是在总体函数具有实值输出的假设下完成的,因此我们以特殊方式处理具有复数输出的函数。对于这些函数,gradcheck 应用于两个实值函数,这两个函数分别对应于取复数输出的实部和取复数输出的虚部。有关更多详细信息,请查看 复数自动微分。
注意
默认值是为双精度
input
设计的。如果input
的精度较低,例如FloatTensor
,则此检查可能会失败。注意
Gradcheck 在不可微点上评估时可能会失败,因为通过有限差分数值计算的梯度可能与解析计算的梯度不同(不一定是因为任何一个是错误的)。有关更多背景信息,请参阅 不可微函数的梯度。
警告
如果
input
中任何被检查的张量具有重叠内存,即指向同一内存地址的不同索引(例如,来自torch.Tensor.expand()
),则此检查可能会失败,因为通过此类索引处的点扰动数值计算的梯度将更改共享同一内存地址的所有其他索引处的值。- 参数
func (函数) – 一个 Python 函数,它接受张量输入并返回张量或张量元组
eps (float, 可选) – 有限差分的扰动
atol (float, 可选) – 绝对公差
rtol (float, 可选) – 相对公差
raise_exception (bool, 可选) – 指示检查失败时是否引发异常。异常会提供有关失败确切性质的更多信息。这在调试 gradcheck 时很有帮助。
nondet_tol (float, 可选) – 非确定性的公差。当通过微分运行相同的输入时,结果必须完全匹配(默认值,0.0)或在此公差范围内。
check_undefined_grad (bool, 可选) – 如果为
True
,则检查是否支持未定义的输出梯度并将其视为零,对于Tensor
输出。check_batched_grad (bool, 可选) – 如果为
True
,则检查我们是否可以使用原型 vmap 支持计算批处理梯度。默认为 False。check_batched_forward_grad (bool, 可选) – 如果为
True
,则检查我们是否可以使用前向 AD 和原型 vmap 支持计算批处理前向梯度。默认为False
。check_forward_ad (bool, 可选) – 如果为
True
,则检查使用前向模式 AD 计算的梯度是否与数值梯度匹配。默认为False
。check_backward_ad (bool, 可选) – 如果为
False
,则不执行任何依赖于向后模式 AD 实现的检查。默认为True
。fast_mode (bool, 可选) – gradcheck 和 gradgradcheck 的快速模式目前仅针对 R 到 R 函数实现。如果输入和输出都不是复数,则运行 gradcheck 的更快实现,该实现不再计算整个雅可比矩阵;否则,我们将回退到慢速实现。
masked (bool, 可选) – 如果为
True
,则忽略稀疏张量的未指定元素的梯度。默认为False
。
- 返回
如果所有差异都满足 allclose 条件,则返回
True
- 返回类型