快捷方式

Gradcheck 机制

本说明概述了 gradcheck()gradgradcheck() 函数的工作原理。

它将涵盖实值函数和复值函数的前向和反向模式自动微分,以及高阶导数。本说明还涵盖了 gradcheck 的默认行为,以及传递 fast_mode=True 参数的情况(以下称为快速 gradcheck)。

符号和背景信息

在整个说明中,我们将使用以下约定

  1. xxyyaabbvvuuururuiui 是实值向量,zz 是一个复值向量,可以用两个实值向量表示为 z=a+ibz = a + i b

  2. NNMM 是两个整数,我们将分别使用它们表示输入空间和输出空间的维度。

  3. f:RNRMf: \mathcal{R}^N \to \mathcal{R}^M 是我们的基本实数到实数函数,使得 y=f(x)y = f(x)

  4. g:CNRMg: \mathcal{C}^N \to \mathcal{R}^M 是我们的基本复数到实数函数,使得 y=g(z)y = g(z)

对于简单的实数到实数情况,我们用 JfJ_f 表示与 ff 相关的雅可比矩阵,大小为 M×NM \times N。这个矩阵包含所有偏导数,使得位置 (i,j)(i, j) 处的元素包含 yixj\frac{\partial y_i}{\partial x_j}。然后,对于给定的向量 vv(大小为 MM),反向模式自动微分计算的是 vTJfv^T J_f。另一方面,对于给定的向量 uu(大小为 NN),正向模式自动微分计算的是 JfuJ_f u

对于包含复数值的函数,情况要复杂得多。我们在此只提供要点,完整的描述可以在 复数的自动微分 中找到。

对于所有实值损失函数来说,满足复可微性(柯西-黎曼方程)的约束条件过于严格,因此我们选择使用Wirtinger微积分。在Wirtinger微积分的基本设置中,链式法则要求同时访问Wirtinger导数(以下称为 WW)和共轭Wirtinger导数(以下称为 CWCW)。WWCWCW 都需要被传播,因为一般来说,尽管它们的名字相似,但其中一个并不是另一个的复共轭。

为了避免传播这两个值,对于反向模式自动微分,我们始终假设正在计算其导数的函数是实值函数,或者是更大实值函数的一部分。这个假设意味着我们在反向传递过程中计算的所有中间梯度也与实值函数相关联。在实践中,这个假设在进行优化时并没有限制,因为这类问题需要实值目标函数(因为复数没有自然序)。

在这个假设下,使用 WWCWCW 的定义,我们可以证明 W=CWW = CW^*(我们使用 * 表示复共轭),因此实际上只需要将这两个值中的一个“反向传播通过计算图”,因为另一个值可以很容易地恢复。为了简化内部计算,PyTorch使用 2CW2 * CW 作为其反向传播的值,并在用户请求梯度时返回该值。与实数情况类似,当输出实际上在 RM\mathcal{R}^M 中时,反向模式自动微分不会计算 2CW2 * CW 而只计算 vT(2CW)v^T (2 * CW),其中向量 vRMv \in \mathcal{R}^M

对于正向模式 AD,我们使用类似的逻辑,在这种情况下,假设函数是更大函数的一部分,其输入在 R\mathcal{R} 中。在这种假设下,我们可以做出类似的声明,即每个中间结果都对应于一个函数,其输入在 R\mathcal{R} 中,并且在这种情况下,使用 WWCWCW 定义,我们可以证明对于中间函数,W=CWW = CW。为了确保在最简单的一维函数情况下正向模式和反向模式计算出相同的量,正向模式还计算 2CW2 * CW。与实数情况类似,当输入实际上在 RN\mathcal{R}^N 中时,正向模式 AD 不会计算 2CW2 * CW,而只计算给定向量 uRNu \in \mathcal{R}^N(2CW)u(2 * CW) u

默认反向模式 gradcheck 行为

实数到实数函数

为了测试函数 f:RNRM,xyf: \mathcal{R}^N \to \mathcal{R}^M, x \to y,我们以两种方式重构大小为 M×NM \times N 的完整雅可比矩阵 JfJ_f:解析版本和数值版本。解析版本使用我们的反向模式 AD,而数值版本使用有限差分。然后逐元素比较两个重构的雅可比矩阵是否相等。

默认实数输入数值评估

如果我们考虑一维函数 (N=M=1N = M = 1) 的基本情况,那么我们可以使用维基百科文章中的基本有限差分公式。为了获得更好的数值属性,我们使用“中心差分”

yxf(x+eps)f(xeps)2eps\frac{\partial y}{\partial x} \approx \frac{f(x + eps) - f(x - eps)}{2 * eps}

这个公式可以很容易地推广到多个输出(M>1M \gt 1),方法是将 yx\frac{\partial y}{\partial x} 设为大小为 M×1M \times 1 的列向量,如 f(x+eps)f(x + eps)。在这种情况下,上述公式可以按原样重复使用,并且仅通过两次用户函数评估(即 f(x+eps)f(x + eps)f(xeps)f(x - eps))来近似完整的雅可比矩阵。

处理多个输入(N>1N \gt 1)的情况计算成本更高。在这种情况下,我们依次循环遍历所有输入,并对 xx 的每个元素依次应用 epseps 扰动。这允许我们逐列地重建 JfJ_f 矩阵。

默认实数输入解析求值

对于解析求值,我们使用上面描述的事实,即反向模式 AD 计算 vTJfv^T J_f。对于具有单个输出的函数,我们只需使用 v=1v = 1 就可以通过一次反向传递来恢复完整的雅可比矩阵。

对于具有多个输出的函数,我们使用 for 循环来迭代输出,其中每个 vv 都是对应于每个输出的 one-hot 向量。这允许逐行地重建 JfJ_f 矩阵。

复数到实数的函数

为了测试函数 g:CNRM,zyg: \mathcal{C}^N \to \mathcal{R}^M, z \to y,其中 z=a+ibz = a + i b,我们重构包含 2CW2 * CW 的(复值)矩阵。

默认复数输入数值计算

首先考虑 N=M=1N = M = 1 的基本情况。我们从(这篇研究论文 的第 3 章)中知道

CW:=yz=12(ya+iyb)CW := \frac{\partial y}{\partial z^*} = \frac{1}{2} * (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b})

请注意,上述等式中的 ya\frac{\partial y}{\partial a}yb\frac{\partial y}{\partial b}RR\mathcal{R} \to \mathcal{R} 导数。为了对其进行数值计算,我们对实数到实数的情况使用上述方法。这允许我们计算 CWCW 矩阵,然后将其乘以 22

请注意,在撰写本文时,代码以一种稍微复杂的方式计算此值

# Code from https://github.com/pytorch/pytorch/blob/58eb23378f2a376565a66ac32c93a316c45b6131/torch/autograd/gradcheck.py#L99-L105
# Notation changes in this code block:
# s here is y above
# x, y here are a, b above

ds_dx = compute_gradient(eps)
ds_dy = compute_gradient(eps * 1j)
# conjugate wirtinger derivative
conj_w_d = 0.5 * (ds_dx + ds_dy * 1j)
# wirtinger derivative
w_d = 0.5 * (ds_dx - ds_dy * 1j)
d[d_idx] = grad_out.conjugate() * conj_w_d + grad_out * w_d.conj()

# Since grad_out is always 1, and W and CW are complex conjugate of each other, the last line ends up computing exactly `conj_w_d + w_d.conj() = conj_w_d + conj_w_d = 2 * conj_w_d`.

默认复数输入解析计算

由于反向模式 AD 已经精确计算了两倍的 CWCW 导数,因此我们在这里只需使用与实数到实数情况相同的技巧,并在存在多个实数输出时逐行重构矩阵。

具有复数输出的函数

在这种情况下,用户提供的函数不符合自动微分的假设,即我们计算反向自动微分的函数是实值函数。这意味着直接在这个函数上使用自动微分是没有定义的。为了解决这个问题,我们将用两个函数: hrhrhihi 来替换对函数 h:PNCMh: \mathcal{P}^N \to \mathcal{C}^M (其中 P\mathcal{P} 可以是 R\mathcal{R}C\mathcal{C}) 的测试,使得:

hr(q):=real(f(q))hi(q):=imag(f(q))\begin{aligned} hr(q) &:= real(f(q)) \\ hi(q) &:= imag(f(q)) \end{aligned}

其中 qPq \in \mathcal{P}。然后,我们使用上面描述的实数到实数或复数到实数的情况(取决于 P\mathcal{P}),对 hrhrhihi 都进行基本的梯度检查。

请注意,在撰写本文时,代码不会显式创建这些函数,而是通过将 grad_out\text{grad\_out} 参数传递给不同的函数,手动使用 realrealimagimag 函数执行链式法则。当 grad_out=1\text{grad\_out} = 1 时,我们考虑的是 hrhr。当 grad_out=1j\text{grad\_out} = 1j 时,我们考虑的是 hihi

快速反向模式梯度检查

虽然上述梯度检查的公式很好,既能确保正确性,又能提高可调试性,但它非常慢,因为它要重建完整的雅可比矩阵。本节介绍一种在不影响正确性的情况下更快地执行梯度检查的方法。可调试性可以通过在检测到错误时添加特殊逻辑来恢复。在这种情况下,我们可以运行重建完整矩阵的默认版本,向用户提供完整的详细信息。

这里的高级策略是找到一个标量,它可以通过数值方法和解析方法有效地计算出来,并且能够很好地表示慢速梯度检查计算出的完整矩阵,以确保它能够捕捉到雅可比矩阵中的任何差异。

实数到实数函数的快速梯度检查

我们想要计算的标量是 vTJfuv^T J_f u,其中 vRMv \in \mathcal{R}^M 为给定随机向量,uRNu \in \mathcal{R}^N 为随机单位向量。

对于数值计算,我们可以高效地计算

Jfuf(x+ueps)f(xueps)2eps.J_f u \approx \frac{f(x + u * eps) - f(x - u * eps)}{2 * eps}.

然后,我们将该向量与 vv 进行点积,以获得我们感兴趣的标量值。

对于解析版本,我们可以使用反向模式自动微分直接计算 vTJfv^T J_f。然后,我们与 uu 进行点积,以获得期望值。

复数到实数函数的快速梯度检查

与实数到实数的情况类似,我们希望对完整矩阵执行降维。但是 2CW2 * CW 矩阵是复数值矩阵,因此在这种情况下,我们将与复数标量进行比较。

由于在数值情况下我们可以有效计算的内容受到一些限制,并且为了尽量减少数值计算的次数,我们计算以下(尽管令人惊讶)标量值

s:=2vT(real(CW)ur+iimag(CW)ui)s := 2 * v^T (real(CW) ur + i * imag(CW) ui)

其中 vRMv \in \mathcal{R}^MurRNur \in \mathcal{R}^N 并且 uiRNui \in \mathcal{R}^N

快速复数输入数值评估

我们首先考虑如何使用数值方法计算 ss。为此,请记住我们正在考虑 g:CNRM,zyg: \mathcal{C}^N \to \mathcal{R}^M, z \to y,其中 z=a+ibz = a + i b,并且 CW=12(ya+iyb)CW = \frac{1}{2} * (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b}),我们将其改写如下:

s=2vT(real(CW)ur+iimag(CW)ui)=2vT(12yaur+i12ybui)=vT(yaur+iybui)=vT((yaur)+i(ybui))\begin{aligned} s &= 2 * v^T (real(CW) ur + i * imag(CW) ui) \\ &= 2 * v^T (\frac{1}{2} * \frac{\partial y}{\partial a} ur + i * \frac{1}{2} * \frac{\partial y}{\partial b} ui) \\ &= v^T (\frac{\partial y}{\partial a} ur + i * \frac{\partial y}{\partial b} ui) \\ &= v^T ((\frac{\partial y}{\partial a} ur) + i * (\frac{\partial y}{\partial b} ui)) \end{aligned}

在这个公式中,我们可以看到 yaur\frac{\partial y}{\partial a} urybui\frac{\partial y}{\partial b} ui 可以像实数到实数情况的快速版本一样进行评估。一旦计算出这些实数值,我们就可以重建右侧的复数向量,并与实值 vv 向量进行点积。

快速复数输入解析评估

对于解析情况,事情更简单,我们将公式改写为:

s=2vT(real(CW)ur+iimag(CW)ui)=vTreal(2CW)ur+ivTimag(2CW)ui)=real(vT(2CW))ur+iimag(vT(2CW))ui\begin{aligned} s &= 2 * v^T (real(CW) ur + i * imag(CW) ui) \\ &= v^T real(2 * CW) ur + i * v^T imag(2 * CW) ui) \\ &= real(v^T (2 * CW)) ur + i * imag(v^T (2 * CW)) ui \end{aligned}

因此,我们可以利用反向模式 AD 为我们提供了一种计算 vT(2CW)v^T (2 * CW) 的有效方法,然后将其实部与 urur 进行点积,将虚部与 uiui 进行点积,最后重构最终的复数标量 ss

为什么不使用复数 uu

在这一点上,您可能想知道为什么我们没有选择一个复杂的 uu 并直接进行约简 2vTCWu2 * v^T CW u'。为了深入探讨这一点,在本段中,我们将使用 uu 的复数形式,记为 u=ur+iuiu' = ur' + i ui'。使用这样的复数 uu',问题是在进行数值评估时,我们需要计算

2CWu=(ya+iyb)(ur+iui)=yaur+iyaui+iyburybui\begin{aligned} 2*CW u' &= (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b})(ur' + i ui') \\ &= \frac{\partial y}{\partial a} ur' + i \frac{\partial y}{\partial a} ui' + i \frac{\partial y}{\partial b} ur' - \frac{\partial y}{\partial b} ui' \end{aligned}

这将需要对实数到实数的有限差分进行四次评估(与上述方法相比,评估次数增加了一倍)。由于此方法没有更多自由度(实值变量的数量相同),并且我们尝试在此处获得尽可能快的评估速度,因此我们使用上述另一种公式。

复数输出函数的快速梯度检查

就像在慢速情况下一样,我们考虑两个实值函数,并对每个函数使用上述相应的规则。

梯度梯度检查实现

PyTorch 还提供了一个实用程序来验证二阶梯度。此处的目标是确保反向实现也是可微分的,并且计算结果正确。

此功能是通过考虑函数 F:x,vvTJfF: x, v \to v^T J_f 并对该函数使用上面定义的梯度检查来实现的。请注意,在这种情况下,vv 只是一个与 f(x)f(x) 类型相同的随机向量。

gradgradcheck 的快速版本是通过在相同函数 FF 上使用 gradcheck 的快速版本来实现的。

文档

访问 PyTorch 全面的开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得您的问题的答案

查看资源