快捷方式

自动微分机制

本说明将概述自动微分的工作原理以及如何记录操作。并非必须完全理解所有内容,但我们建议您熟悉它,因为它将帮助您编写更高效、更简洁的程序,并可以帮助您进行调试。

自动微分如何编码历史记录

自动微分是一个反向自动微分系统。从概念上讲,自动微分会在您执行操作时记录一个图形,该图形记录了创建数据的所有操作,为您提供一个有向无环图,其叶子节点是输入张量,根节点是输出张量。通过从根节点到叶子节点跟踪此图,您可以使用链式法则自动计算梯度。

在内部,自动微分将此图表示为一个由 Function 对象(实际上是表达式)组成的图,可以对其应用 apply() 来计算评估图的结果。在计算前向传递时,自动微分会同时执行请求的计算并构建一个表示计算梯度的函数的图(每个 torch.Tensor.grad_fn 属性都是此图中的一个入口点)。当前向传递完成后,我们在反向传递中评估此图以计算梯度。

需要注意的重要一点是,该图在每次迭代中都会从头开始重新创建,这正是允许使用任意 Python 控制流语句的原因,这些语句可以在每次迭代中更改图的整体形状和大小。您不必在开始训练之前就编码所有可能的路径 - 您运行的内容就是您要微分的内容。

已保存的张量

某些操作需要在前向传递期间保存中间结果,以便执行反向传递。例如,函数 xx2x\mapsto x^2 保存输入 xx 以计算梯度。

在定义自定义 Python Function 时,您可以使用 save_for_backward() 在前向传递期间保存张量,并使用 saved_tensors 在反向传递期间检索它们。有关详细信息,请参阅扩展 PyTorch

对于 PyTorch 定义的操作(例如 torch.pow()),会根据需要自动保存张量。您可以通过查找以 _saved 前缀开头的属性来探索(用于教育或调试目的)某个 grad_fn 保存了哪些张量。

x = torch.randn(5, requires_grad=True)
y = x.pow(2)
print(x.equal(y.grad_fn._saved_self))  # True
print(x is y.grad_fn._saved_self)  # True

在前面的代码中,y.grad_fn._saved_self 指的是与 x 相同的张量对象。但这并非总是如此。例如:

x = torch.randn(5, requires_grad=True)
y = x.exp()
print(y.equal(y.grad_fn._saved_result))  # True
print(y is y.grad_fn._saved_result)  # False

在底层,为了防止循环引用,PyTorch 在保存时会对张量进行*打包*,并在读取时将其*解包*到不同的张量中。在这里,您从访问 y.grad_fn._saved_result 获得的张量与 y 是不同的张量对象(但它们仍然共享相同的存储)。

张量是否会被打包到不同的张量对象中取决于它是否是其自身 grad_fn 的输出,这是一个实现细节,可能会发生变化,用户不应依赖它。

您可以使用已保存张量的钩子来控制 PyTorch 如何进行打包/解包。

不可微函数的梯度

使用自动微分计算梯度只在使用的每个基本函数都可微的情况下才有效。不幸的是,我们在实践中使用的许多函数都不具有此属性(例如,relusqrt0 处)。为了尝试减少不可微函数的影响,我们通过按顺序应用以下规则来定义基本运算的梯度

  1. 如果该函数是可微的,因此在当前点存在梯度,则使用它。

  2. 如果该函数是凸函数(至少是局部凸函数),则使用最小范数的次梯度(它是最大下降方向)。

  3. 如果该函数是凹函数(至少是局部凹函数),则使用最小范数的超梯度(考虑 -f(x) 并应用前一点)。

  4. 如果函数已定义,则通过连续性在当前点定义梯度(请注意,此处可能出现 inf,例如 sqrt(0))。如果有多个值,则任意选择一个。

  5. 如果函数未定义(例如,sqrt(-1)log(-1) 或大多数函数在输入为 NaN 时),则用作梯度的值是任意的(我们也可能会引发错误,但这不能保证)。大多数函数将使用 NaN 作为梯度,但出于性能原因,某些函数将使用其他值(例如 log(-1))。

  6. 如果该函数不是确定性映射(即它不是数学函数),则它将被标记为不可微。如果在需要梯度的张量上使用它,这将使其在反向传播中出错,除非是在 no_grad 环境中。

局部禁用梯度计算

Python 提供了几种机制来局部禁用梯度计算

要禁用整个代码块的梯度,可以使用上下文管理器,如无梯度模式和推理模式。为了更精细地从梯度计算中排除子图,可以设置张量的 requires_grad 字段。

下面,除了讨论上述机制外,我们还将描述评估模式(nn.Module.eval()),该方法不用于禁用梯度计算,但由于其名称,经常与这三种方法混淆。

设置 requires_grad

requires_grad 是一个标志,默认为 false,除非包装在 nn.Parameter 中,它允许从梯度计算中精细地排除子图。它在正向和反向传播中都有效

在正向传播期间,只有当至少一个输入张量需要梯度时,才会在反向图中记录操作。在反向传播(.backward())期间,只有 requires_grad=True 的叶张量才会将其梯度累积到其 .grad 字段中。

需要注意的是,即使每个张量都有此标志,设置它也只对叶张量有意义(没有 grad_fn 的张量,例如 nn.Module 的参数)。非叶张量(具有 grad_fn 的张量)是具有与其关联的反向图的张量。因此,它们的梯度将作为中间结果来计算需要梯度的叶张量的梯度。根据此定义,很明显,所有非叶张量都将自动具有 require_grad=True

设置 requires_grad 应该是控制模型哪些部分参与梯度计算的主要方法,例如,如果您需要在模型微调期间冻结预训练模型的某些部分。

要冻结模型的某些部分,只需将 .requires_grad_(False) 应用于您不想更新的参数。如上所述,由于在正向传播中不会记录使用这些参数作为输入的计算,因此在反向传播中不会更新它们的 .grad 字段,因为它们最初就不属于反向图的一部分,这正是我们想要的。

因为这是一种非常常见的模式,所以也可以在模块级别使用 nn.Module.requires_grad_() 设置 requires_grad。当应用于模块时,.requires_grad_() 对模块的所有参数(默认情况下 requires_grad=True)都有效。

梯度模式

除了设置 requires_grad 之外,还可以从 Python 中选择三种梯度模式,这些模式会影响 PyTorch 中的计算如何由 autograd 在内部处理:默认模式(梯度模式)、无梯度模式和推理模式,所有这些模式都可以通过上下文管理器和装饰器进行切换。

模式

从反向图中排除记录的操作

跳过额外的自动梯度跟踪开销

在启用该模式时创建的张量可以在以后的梯度模式中使用

示例

默认

正向传播

无梯度

优化器更新

推理

数据处理、模型评估

默认模式(梯度模式)

“默认模式”是指在没有启用其他模式(如无梯度模式和推理模式)时,我们隐式处于的模式。为了与“无梯度模式”形成对比,默认模式有时也称为“梯度模式”。

关于默认模式,最重要的是要知道它是唯一一种 requires_grad 生效的模式。requires_grad 在其他两种模式下始终被覆盖为 False

无梯度模式

无梯度模式下的计算表现得好像没有输入需要梯度一样。换句话说,即使存在 require_grad=True 的输入,无梯度模式下的计算也永远不会记录在反向图中。

当您需要执行不应该由 autograd 记录的操作,但您仍然希望稍后在梯度模式下使用这些计算的输出时,请启用无梯度模式。此上下文管理器可以方便地禁用代码块或函数的梯度,而无需临时将张量设置为 requires_grad=False,然后又设置为 True

例如,在编写优化器时,无梯度模式可能很有用:在执行训练更新时,您希望就地更新参数,而无需 autograd 记录更新。您还打算在下一次正向传播中将更新后的参数用于梯度模式下的计算。

torch.nn.init 中的实现也在初始化参数时依赖于无梯度模式,以避免在就地更新初始化参数时进行自动梯度跟踪。

推理模式

推理模式是无梯度模式的极端版本。就像在无梯度模式下一样,推理模式下的计算不会记录在反向图中,但启用推理模式将允许 PyTorch 进一步加速您的模型。这种更好的运行时伴随着一个缺点:在退出推理模式后,无法在要由 autograd 记录的计算中使用在推理模式下创建的张量。

当您执行不需要在反向图中记录的计算时,并且您不打算在退出推理模式后在任何要由 autograd 记录的计算中使用在推理模式下创建的张量时,请启用推理模式。

建议您在代码中不需要自动梯度跟踪的部分(例如,数据处理和模型评估)中尝试使用推理模式。如果它在您的用例中开箱即用,那么您将获得免费的性能提升。如果在启用推理模式后遇到错误,请检查您是否在退出推理模式后在由 autograd 记录的计算中使用了在推理模式下创建的张量。如果在您的情况下无法避免这种使用,则始终可以切换回无梯度模式。

有关推理模式的详细信息,请参阅推理模式

有关推理模式的实现细节,请参阅RFC-0011-InferenceMode

评估模式(nn.Module.eval()

评估模式不是一种局部禁用梯度计算的机制。无论如何,它都包含在这里,因为它有时会被误认为是这样一种机制。

在功能上,module.eval()(或等效的 module.train(False))与无梯度模式和推理模式完全正交。model.eval() 如何影响您的模型完全取决于模型中使用的特定模块以及它们是否定义了任何特定于训练模式的行为。

如果您模型依赖于可能根据训练模式表现不同的模块(例如,torch.nn.Dropouttorch.nn.BatchNorm2d),则您有责任调用 model.eval()model.train(),例如,为了避免在验证数据上更新 BatchNorm 运行统计信息。

建议您在训练时始终使用 model.train(),在评估模型(验证/测试)时始终使用 model.eval(),即使您不确定您的模型是否具有特定于训练模式的行为,因为您正在使用的模块可能会更新为在训练和评估模式下表现不同。

使用 Autograd 进行原地操作

在 Autograd 中支持原地操作是一件困难的事情,我们不建议在大多数情况下使用它们。Autograd 积极的缓冲区释放和重用使其非常高效,并且在极少数情况下,原地操作会显着降低内存使用量。除非您在内存压力很大的情况下操作,否则您可能永远不需要使用它们。

有两个主要原因限制了原地操作的适用性

  1. 原地操作可能会覆盖计算梯度所需的值。

  2. 每个原地操作都需要实现来重写计算图。非原地版本只是分配新对象并保留对旧图的引用,而原地操作则需要将所有输入的创建者更改为表示此操作的 Function。这可能很棘手,特别是当有很多张量引用同一个存储时(例如,通过索引或转置创建),并且如果修改后的输入的存储被任何其他 Tensor 引用,则原地函数将引发错误。

原地正确性检查

每个张量都保留一个版本计数器,每次在任何操作中将其标记为脏时,该计数器都会递增。当一个函数为反向传播保存任何张量时,也会保存它们包含的张量的版本计数器。一旦您访问 self.saved_tensors,就会对其进行检查,如果它大于保存的值,则会引发错误。这确保了如果您正在使用原地函数并且没有看到任何错误,则可以确保计算出的梯度是正确的。

多线程 Autograd

Autograd 引擎负责运行计算反向传播所需的所有反向操作。本节将描述所有细节,可以帮助您在多线程环境中充分利用它。(这仅与 PyTorch 1.6+ 相关,因为先前版本中的行为有所不同。)

用户可以使用多线程代码训练他们的模型(例如 Hogwild 训练),并且不会阻塞并发反向计算,示例代码可以是

# Define a train function to be used in different threads
def train_fn():
    x = torch.ones(5, 5, requires_grad=True)
    # forward
    y = (x + 3) * (x + 4) * 0.5
    # backward
    y.sum().backward()
    # potential optimizer update


# User write their own threading code to drive the train_fn
threads = []
for _ in range(10):
    p = threading.Thread(target=train_fn, args=())
    p.start()
    threads.append(p)

for p in threads:
    p.join()

请注意,用户应该注意的一些行为

CPU 上的并发

当您在 CPU 上的多个线程中通过 Python 或 C++ API 运行 backward()grad() 时,您期望看到额外的并发性,而不是在执行期间按特定顺序序列化所有反向调用(PyTorch 1.6 之前的行为)。

不确定性

如果您正在从多个线程并发调用 backward() 并且具有共享输入(即 Hogwild CPU 训练),那么应该预料到不确定性。这可能会发生,因为参数在线程之间自动共享,因此,多个线程可能会在梯度累积期间访问并尝试累积相同的 .grad 属性。从技术上讲,这是不安全的,并且可能会导致竞争条件,并且结果可能无效。

开发具有共享参数的多线程模型的用户应该牢记线程模型,并且应该理解上面描述的问题。

可以使用函数式 API torch.autograd.grad() 来计算梯度,而不是 backward(),以避免不确定性。

图保留

如果 Autograd 图的一部分在线程之间共享,即在单个线程中运行前向传递的第一部分,然后在多个线程中运行第二部分,则图的第一部分是共享的。在这种情况下,不同的线程在同一个图上执行 grad()backward() 可能会出现一个线程正在动态销毁图的问题,而另一个线程在这种情况下会崩溃。Autograd 将向用户抛出类似于在没有 retain_graph=True 的情况下调用两次 backward() 的错误,并让用户知道他们应该使用 retain_graph=True

Autograd 节点上的线程安全

由于 Autograd 允许调用方线程驱动其反向执行以实现潜在的并行性,因此重要的是,我们必须确保在 CPU 上使用共享部分/全部 GraphTask 的并行 backward() 调用时的线程安全。

由于 GIL 的存在,自定义 Python autograd.Function 会自动线程安全。对于内置 C++ Autograd 节点(例如 AccumulateGrad、CopySlices)和自定义 autograd::Function,Autograd 引擎使用线程互斥锁来确保可能具有状态写入/读取的 Autograd 节点上的线程安全。

C++ 钩子上没有线程安全

Autograd 依赖用户编写线程安全的 C++ 钩子。如果您希望钩子在多线程环境中正确应用,则需要编写适当的线程锁定代码以确保钩子是线程安全的。

复数的 Autograd

简短版本

  • 当您使用 PyTorch 对任何具有复数域和/或复数值域的函数 f(z)f(z) 进行微分时,梯度是在假设该函数是更大的实值损失函数 g(input)=Lg(input)=L 的一部分的情况下计算的。计算的梯度是 Lz\frac{\partial L}{\partial z^*}(注意 z 的共轭),其负数恰好是梯度下降算法中使用的最速下降方向。因此,存在一条可行的路径,可以使现有的优化器开箱即用地使用复数参数。

  • 此约定与 TensorFlow 的复数微分约定相匹配,但与 JAX 不同(JAX 计算 Lz\frac{\partial L}{\partial z})。

  • 如果您有一个内部使用复数运算的实数到实数函数,则此处的约定无关紧要:您将始终获得与仅使用实数运算实现时相同的结果。

如果您对数学细节感到好奇,或者想知道如何在 PyTorch 中定义复数导数,请继续阅读。

什么是复数导数?

复数可微性的数学定义采用导数的极限定义,并将其推广到对复数进行运算。考虑一个函数 f:CCf: ℂ → ℂ

f(z=x+yj)=u(x,y)+v(x,y)jf(z=x+yj) = u(x, y) + v(x, y)j

其中 uuvv 是两个实值二元函数,而 jj 是虚数单位。

使用导数定义,我们可以写成

f(z)=limh0,hCf(z+h)f(z)hf'(z) = \lim_{h \to 0, h \in C} \frac{f(z+h) - f(z)}{h}

为了使此极限存在,不仅 uuvv 必须是实可微的,而且 ff 还必须满足柯西-黎曼方程。换句话说:对实部和虚部步长(hh)计算的极限必须相等。这是一个更严格的条件。

复可微函数通常被称为全纯函数。它们表现良好,具有你在实可微函数中看到的所有优良性质,但在优化领域却毫无用处。对于优化问题,研究界只使用实值目标函数,因为复数不属于任何有序域,因此具有复值损失没有多大意义。

事实还证明,没有任何有趣的实值目标函数能满足柯西-黎曼方程。因此,全纯函数理论不能用于优化,因此大多数人使用Wirtinger微积分。

Wirtinger微积分应运而生…

所以,我们有这个伟大的复可微性和全纯函数理论,但我们根本无法使用它,因为许多常用的函数都不是全纯的。可怜的数学家该怎么办?好吧,Wirtinger观察到,即使 f(z)f(z) 不是全纯的,我们也可以把它改写成一个二元函数 f(z,z)f(z, z*),它总是全纯的。这是因为 zz 的实部和虚部可以用 zzzz^* 来表示,如下所示:

Re(z)=z+z2Im(z)=zz2j\begin{aligned} \mathrm{Re}(z) &= \frac {z + z^*}{2} \\ \mathrm{Im}(z) &= \frac {z - z^*}{2j} \end{aligned}

维丁格微积分建议研究 f(z,z)f(z, z^*),如果 ff 是实可微的,则保证它是全纯的(另一种理解方式是将其视为坐标系的变换,从 f(x,y)f(x, y)f(z,z)f(z, z^*).) 该函数的偏导数为 z\frac{\partial }{\partial z}z\frac{\partial}{\partial z^{*}}。我们可以使用链式法则来建立这些偏导数与 zz 的实部和虚部的偏导数之间的关系。

x=zxz+zxz=z+zy=zyz+zyz=1j(zz)\begin{aligned} \frac{\partial }{\partial x} &= \frac{\partial z}{\partial x} * \frac{\partial }{\partial z} + \frac{\partial z^*}{\partial x} * \frac{\partial }{\partial z^*} \\ &= \frac{\partial }{\partial z} + \frac{\partial }{\partial z^*} \\ \\ \frac{\partial }{\partial y} &= \frac{\partial z}{\partial y} * \frac{\partial }{\partial z} + \frac{\partial z^*}{\partial y} * \frac{\partial }{\partial z^*} \\ &= 1j * \left(\frac{\partial }{\partial z} - \frac{\partial }{\partial z^*}\right) \end{aligned}

从上面的等式中,我们得到

z=1/2(x1jy)z=1/2(x+1jy)\begin{aligned} \frac{\partial }{\partial z} &= 1/2 * \left(\frac{\partial }{\partial x} - 1j * \frac{\partial }{\partial y}\right) \\ \frac{\partial }{\partial z^*} &= 1/2 * \left(\frac{\partial }{\partial x} + 1j * \frac{\partial }{\partial y}\right) \end{aligned}

这就是你在维基百科上可以找到的 Wirtinger 微积分的经典定义。

这种变化有很多美好的结果。

  • 首先,Cauchy-Riemann 方程可以简单地解释为 fz=0\frac{\partial f}{\partial z^*} = 0 (也就是说,函数 ff 可以完全用 zz 来表示,而不必参考 zz^*)。

  • 另一个重要的(并且有点违反直觉的)结果是,正如我们稍后将看到的,当我们对实值损失进行优化时,我们在进行变量更新时应该采取的步骤由 Lossz\frac{\partial Loss}{\partial z^*} 给出(而不是 Lossz\frac{\partial Loss}{\partial z})。

有关更多信息,请查看:https://arxiv.org/pdf/0906.4835.pdf

Wirtinger 微积分在优化中有什么用?

音频和其他领域的研究人员更常使用梯度下降来优化具有复变量的实值损失函数。通常,这些人将实部和虚部视为可以更新的单独通道。对于步长 α/2\alpha/2 和损失 LL,我们可以写出以下在 R2ℝ^2 中的方程

xn+1=xn(α/2)Lxyn+1=yn(α/2)Ly\begin{aligned} x_{n+1} &= x_n - (\alpha/2) * \frac{\partial L}{\partial x} \\ y_{n+1} &= y_n - (\alpha/2) * \frac{\partial L}{\partial y} \end{aligned}

这些方程如何转换为复数空间 C

zn+1=xn(α/2)Lx+1j(yn(α/2)Ly)=znα1/2(Lx+jLy)=znαLz\begin{aligned} z_{n+1} &= x_n - (\alpha/2) * \frac{\partial L}{\partial x} + 1j * (y_n - (\alpha/2) * \frac{\partial L}{\partial y}) \\ &= z_n - \alpha * 1/2 * \left(\frac{\partial L}{\partial x} + j \frac{\partial L}{\partial y}\right) \\ &= z_n - \alpha * \frac{\partial L}{\partial z^*} \end{aligned}

非常有趣的事情发生了:Wirtinger 微积分告诉我们,我们可以简化上述复变量更新公式,使其仅涉及共轭 Wirtinger 导数Lz\frac{\partial L}{\partial z^*},从而得出我们在优化中采取的步骤。

因为共轭 Wirtinger 导数为实值损失函数提供了完全正确的步骤,所以当您对具有实值损失的函数进行微分时,PyTorch 会为您提供此导数。

PyTorch 如何计算共轭 Wirtinger 导数?

通常,我们的导数公式将 grad_output 作为输入,表示我们已经计算出的传入向量-雅可比积,也称为 Ls\frac{\partial L}{\partial s^*},其中 LL 是整个计算的损失(产生一个实值损失),ss 是我们函数的输出。这里的目标是计算 Lz\frac{\partial L}{\partial z^*},其中 zz 是函数的输入。事实证明,在实值损失的情况下,我们可以_只_计算 Ls\frac{\partial L}{\partial s^*},即使链式法则意味着我们还需要访问 Ls\frac{\partial L}{\partial s}。如果您想跳过此推导,请查看本节中的最后一个等式,然后跳至下一节。

让我们继续使用 f:CCf: ℂ → ℂ 定义为 f(z)=f(x+yj)=u(x,y)+v(x,y)jf(z) = f(x+yj) = u(x, y) + v(x, y)j。如上所述,autograd 的梯度约定以实值损失函数的优化为中心,因此让我们假设 ff 是更大的实值损失函数 gg 的一部分。使用链式法则,我们可以写成

(1)Lz=Luuz+Lvvz\frac{\partial L}{\partial z^*} = \frac{\partial L}{\partial u} * \frac{\partial u}{\partial z^*} + \frac{\partial L}{\partial v} * \frac{\partial v}{\partial z^*}

现在,使用 Wirtinger 导数定义,我们可以写成:

Ls=1/2(LuLvj)Ls=1/2(Lu+Lvj)\begin{aligned} \frac{\partial L}{\partial s} = 1/2 * \left(\frac{\partial L}{\partial u} - \frac{\partial L}{\partial v} j\right) \\ \frac{\partial L}{\partial s^*} = 1/2 * \left(\frac{\partial L}{\partial u} + \frac{\partial L}{\partial v} j\right) \end{aligned} ...

这里需要注意的是,由于 uuvv 是实函数,并且根据我们的假设,ff 是实值函数的一部分,因此 LL 也是实数,我们有:

(2)(Ls)=Ls\left( \frac{\partial L}{\partial s} \right)^* = \frac{\partial L}{\partial s^*}

也就是说,Ls\frac{\partial L}{\partial s} 等于 grad_outputgrad\_output^*

求解上述关于 Lu\frac{\partial L}{\partial u}Lv\frac{\partial L}{\partial v} 的方程式,我们得到:

(3)Lu=Ls+LsLv=1j(LsLs)\begin{aligned} \frac{\partial L}{\partial u} = \frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*} \\ \frac{\partial L}{\partial v} = 1j * \left(\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}\right) \end{aligned}

(3)代入(1),我们得到

Lz=(Ls+Ls)uz+1j(LsLs)vz=Ls(uz+vzj)+Ls(uzvzj)=Ls(u+vj)z+Ls(u+vj)z=Lssz+Lssz\begin{aligned} \frac{\partial L}{\partial z^*} &= \left(\frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*}\right) * \frac{\partial u}{\partial z^*} + 1j * \left(\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}\right) * \frac{\partial v}{\partial z^*} \\ &= \frac{\partial L}{\partial s} * \left(\frac{\partial u}{\partial z^*} + \frac{\partial v}{\partial z^*} j\right) + \frac{\partial L}{\partial s^*} * \left(\frac{\partial u}{\partial z^*} - \frac{\partial v}{\partial z^*} j\right) \\ &= \frac{\partial L}{\partial s} * \frac{\partial (u + vj)}{\partial z^*} + \frac{\partial L}{\partial s^*} * \frac{\partial (u + vj)^*}{\partial z^*} \\ &= \frac{\partial L}{\partial s} * \frac{\partial s}{\partial z^*} + \frac{\partial L}{\partial s^*} * \frac{\partial s^*}{\partial z^*} \\ \end{aligned}

使用(2),我们得到

(4)Lz=(Ls)sz+Ls(sz)=(grad_output)sz+grad_output(sz)\begin{aligned} \frac{\partial L}{\partial z^*} &= \left(\frac{\partial L}{\partial s^*}\right)^* * \frac{\partial s}{\partial z^*} + \frac{\partial L}{\partial s^*} * \left(\frac{\partial s}{\partial z}\right)^* \\ &= \boxed{ (grad\_output)^* * \frac{\partial s}{\partial z^*} + grad\_output * \left(\frac{\partial s}{\partial z}\right)^* } \\ \end{aligned}

最后一个等式对于编写您自己的梯度非常重要,因为它将我们的导数公式分解为一个更简单的公式,可以很容易地手动计算。

如何为复函数编写自己的导数公式?

上述带框方程式为我们提供了复函数所有导数的一般公式。然而,我们仍然需要计算 sz\frac{\partial s}{\partial z}sz\frac{\partial s}{\partial z^*}。 你可以通过两种方式做到这一点。

  • 第一种方法是直接使用 Wirtinger 导数的定义,并使用 sx\frac{\partial s}{\partial x}sy\frac{\partial s}{\partial y} (你可以用通常的方式计算) 来计算 sz\frac{\partial s}{\partial z}sz\frac{\partial s}{\partial z^*}

  • 第二种方法是使用变量替换技巧,将 f(z)f(z) 重写为二元函数 f(z,z)f(z, z^*),并将 zzzz^* 视为自变量来计算共轭 Wirtinger 导数。这通常更容易; 例如,如果所讨论的函数是全纯的,则只使用 zz(并且 sz\frac{\partial s}{\partial z^*} 将为零)。

我们以函数 f(z=x+yj)=cz=c(x+yj)f(z = x + yj) = c * z = c * (x+yj) 为例,其中 cRc \in ℝ

使用第一种方法计算 Wirtinger 导数,我们有:

sz=1/2(sxsyj)=1/2(c(c1j)1j)=csz=1/2(sx+syj)=1/2(c+(c1j)1j)=0\begin{aligned} \frac{\partial s}{\partial z} &= 1/2 * \left(\frac{\partial s}{\partial x} - \frac{\partial s}{\partial y} j\right) \\ &= 1/2 * (c - (c * 1j) * 1j) \\ &= c \\ \\ \\ \frac{\partial s}{\partial z^*} &= 1/2 * \left(\frac{\partial s}{\partial x} + \frac{\partial s}{\partial y} j\right) \\ &= 1/2 * (c + (c * 1j) * 1j) \\ &= 0 \\ \end{aligned}

使用 (4)grad_output = 1.0(这是在 PyTorch 中对标量输出调用 backward() 时使用的默认梯度输出值),我们得到

Lz=10+1c=c\frac{\partial L}{\partial z^*} = 1 * 0 + 1 * c = c

使用第二种计算 Wirtinger 导数的方法,我们直接得到

sz=(cz)z=csz=(cz)z=0\begin{aligned} \frac{\partial s}{\partial z} &= \frac{\partial (c*z)}{\partial z} \\ &= c \\ \frac{\partial s}{\partial z^*} &= \frac{\partial (c*z)}{\partial z^*} \\ &= 0 \end{aligned}

再次使用 (4),我们得到 Lz=c\frac{\partial L}{\partial z^*} = c。如您所见,第二种方法涉及的计算量更少,并且更便于更快地进行计算。

跨域函数呢?

一些函数从复数输入映射到实数输出,反之亦然。这些函数构成了 (4) 的一个特例,我们可以使用链式法则推导出它

  • 对于 f:CRf: ℂ → ℝ,我们得到

    Lz=2grad_outputsz\frac{\partial L}{\partial z^*} = 2 * grad\_output * \frac{\partial s}{\partial z^{*}}
  • 对于 f:RCf: ℝ → ℂ,我们得到

    Lz=2Re(grad_outputsz)\frac{\partial L}{\partial z^*} = 2 * \mathrm{Re}(grad\_output^* * \frac{\partial s}{\partial z^{*}})

保存张量的钩子

您可以通过定义一对 pack_hook / unpack_hook 钩子来控制 保存张量的打包/解包方式pack_hook 函数应该将张量作为其唯一参数,但可以返回任何 Python 对象(例如,另一个张量、元组,甚至包含文件名的字符串)。unpack_hook 函数将 pack_hook 的输出作为其唯一参数,并且应该返回一个要在反向传播中使用的张量。 unpack_hook 返回的张量只需要与作为输入传递给 pack_hook 的张量具有相同的内容。 特别是,任何与自动求导相关的元数据都可以忽略,因为它们将在解包过程中被覆盖。

此类对的示例如下:

class SelfDeletingTempFile():
    def __init__(self):
        self.name = os.path.join(tmp_dir, str(uuid.uuid4()))

    def __del__(self):
        os.remove(self.name)

def pack_hook(tensor):
    temp_file = SelfDeletingTempFile()
    torch.save(tensor, temp_file.name)
    return temp_file

def unpack_hook(temp_file):
    return torch.load(temp_file.name)

请注意,unpack_hook 不应删除临时文件,因为它可能会被多次调用:只要返回的 SelfDeletingTempFile 对象存在,临时文件就应该存在。 在上面的示例中,我们通过在不再需要时关闭临时文件(删除 SelfDeletingTempFile 对象时)来防止泄漏临时文件。

注意

我们保证 pack_hook 只会被调用一次,但 unpack_hook 可以根据反向传播的需要调用多次,并且我们希望它每次都返回相同的数据。

警告

禁止对任何函数的输入执行就地操作,因为它们可能会导致意外的副作用。 如果对 pack 钩子的输入进行就地修改,PyTorch 将会抛出一个错误,但不会捕获对 unpack 钩子的输入进行就地修改的情况。

为保存的张量注册钩子

您可以通过在 SavedTensor 对象上调用 register_hooks() 方法,在保存的张量上注册一对钩子。 这些对象作为 grad_fn 的属性公开,并以 _raw_saved_ 前缀开头。

x = torch.randn(5, requires_grad=True)
y = x.pow(2)
y.grad_fn._raw_saved_self.register_hooks(pack_hook, unpack_hook)

注册该对后,将立即调用 pack_hook 方法。 每当需要访问保存的张量时,无论是通过 y.grad_fn._saved_self 还是在反向传播期间,都会调用 unpack_hook 方法。

警告

如果您在释放保存的张量后(即调用 backward 后)保留对 SavedTensor 的引用,则禁止调用其 register_hooks()。 PyTorch 在大多数情况下会抛出一个错误,但在某些情况下可能会失败,并可能出现未定义的行为。

为保存的张量注册默认钩子

或者,您可以使用上下文管理器 saved_tensors_hooks 来注册一对钩子,这些钩子将应用于在该上下文中创建的*所有*保存的张量。

示例

# Only save on disk tensors that have size >= 1000
SAVE_ON_DISK_THRESHOLD = 1000

def pack_hook(x):
    if x.numel() < SAVE_ON_DISK_THRESHOLD:
        return x
    temp_file = SelfDeletingTempFile()
    torch.save(tensor, temp_file.name)
    return temp_file

def unpack_hook(tensor_or_sctf):
    if isinstance(tensor_or_sctf, torch.Tensor):
        return tensor_or_sctf
    return torch.load(tensor_or_sctf.name)

class Model(nn.Module):
    def forward(self, x):
        with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
          # ... compute output
          output = x
        return output

model = Model()
net = nn.DataParallel(model)

使用此上下文管理器定义的钩子是线程本地的。 因此,以下代码不会产生预期的效果,因为钩子不会通过 DataParallel

# Example what NOT to do

net = nn.DataParallel(model)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    output = net(input)

请注意,使用这些钩子会禁用所有用于减少张量对象创建的优化。 例如

with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
    x = torch.randn(5, requires_grad=True)
    y = x * x

如果没有钩子,xy.grad_fn._saved_selfy.grad_fn._saved_other 都引用同一个张量对象。 有了钩子,PyTorch 会将 x 打包并解包到两个新的张量对象中,这些对象与原始的 x 共享相同的存储空间(不执行复制)。

反向钩子执行

本节将讨论不同钩子触发或不触发的时间。然后将讨论它们的触发顺序。将涵盖以下钩子:通过 torch.Tensor.register_hook() 注册到 Tensor 的反向钩子、通过 torch.Tensor.register_post_accumulate_grad_hook() 注册到 Tensor 的累积梯度后钩子、通过 torch.autograd.graph.Node.register_hook() 注册到 Node 的后置钩子,以及通过 torch.autograd.graph.Node.register_prehook() 注册到 Node 的前置钩子。

特定钩子是否会被触发

当正在为 Tensor 计算梯度时,将执行通过 torch.Tensor.register_hook() 注册到该 Tensor 的钩子。(请注意,这并不要求执行 Tensor 的 grad_fn。例如,如果将 Tensor 作为 inputs 参数的一部分传递给 torch.autograd.grad(),则可能不会执行 Tensor 的 grad_fn,但始终会执行注册到该 Tensor 的钩子。)

在为 Tensor 累积梯度后(即已设置 Tensor 的 grad 字段),将执行通过 torch.Tensor.register_post_accumulate_grad_hook() 注册到该 Tensor 的钩子。通过 torch.Tensor.register_hook() 注册的钩子在计算梯度时运行,而通过 torch.Tensor.register_post_accumulate_grad_hook() 注册的钩子仅在反向传递结束时由 autograd 更新 Tensor 的 grad 字段后才会触发。因此,累积梯度后钩子只能为叶子 Tensor 注册。即使您调用 backward(retain_graph=True),在非叶子 Tensor 上通过 torch.Tensor.register_post_accumulate_grad_hook() 注册钩子也会出错。

仅当执行使用 torch.autograd.graph.Node.register_hook()torch.autograd.graph.Node.register_prehook() 注册到 torch.autograd.graph.Node 的钩子所注册到的 Node 时,才会触发这些钩子。

特定 Node 是否被执行可能取决于反向传递是使用 torch.autograd.grad() 还是 torch.autograd.backward() 调用的。具体来说,当您在与您要传递给 torch.autograd.grad()torch.autograd.backward() 的 Tensor 对应的 Node 上注册钩子作为 inputs 参数的一部分时,您应该注意这些差异。

如果您正在使用 torch.autograd.backward(),则无论您是否指定了 inputs 参数,上述所有钩子都将被执行。这是因为 .backward() 会执行所有 Node,即使它们对应于指定为输入的 Tensor。(请注意,执行与作为 inputs 传递的 Tensor 对应的此附加 Node 通常是不必要的,但无论如何都会执行。此行为可能会发生变化;您不应该依赖它。)

另一方面,如果您正在使用 torch.autograd.grad(),则可能不会执行注册到与传递给 input 的 Tensor 对应的 Node 的反向钩子,因为除非有另一个输入依赖于此 Node 的梯度结果,否则不会执行这些 Node。

不同钩子的触发顺序

事件发生的顺序如下:

  1. 执行注册到 Tensor 的钩子

  2. 执行注册到 Node 的前置钩子(如果 Node 被执行)。

  3. 为保留梯度的 Tensor 更新 .grad 字段

  4. 执行 Node(取决于上述规则)

  5. 对于已累积 .grad 的叶子 Tensor,执行累积梯度后钩子

  6. 执行注册到 Node 的后置钩子(如果 Node 被执行)

如果在同一个 Tensor 或 Node 上注册了多个相同类型的钩子,则将按照注册顺序执行它们。稍后执行的钩子可以观察到先前钩子对梯度所做的修改。

特殊钩子

torch.autograd.graph.register_multi_grad_hook() 是使用注册到 Tensor 的钩子实现的。每个单独的 Tensor 钩子都按照上面定义的 Tensor 钩子顺序触发,并且在计算最后一个 Tensor 梯度时调用注册的多梯度钩子。

torch.nn.modules.module.register_module_full_backward_hook() 是使用注册到 Node 的钩子实现的。在计算前向传递时,钩子被注册到对应于模块的输入和输出的 grad_fn。因为一个模块可以接受多个输入并返回多个输出,所以在前向传递之前首先对模块的输入应用一个虚拟的自定义自动求导函数,并在前向传递的输出返回之前对模块的输出应用一个虚拟的自定义自动求导函数,以确保这些张量共享一个单一的 grad_fn,然后我们可以将我们的钩子附加到该 grad_fn 上。

Tensor 原地修改时 Tensor 钩子的行为

通常,注册到 Tensor 的钩子会接收输出相对于该 Tensor 的梯度,其中 Tensor 的值取为计算反向传递时的值。

但是,如果您将钩子注册到 Tensor,然后原地修改该 Tensor,则在原地修改之前注册的钩子同样会接收输出相对于该 Tensor 的梯度,但 Tensor 的值取为原地修改之前的值。

如果您希望采用前一种情况的行为,则应在对 Tensor 进行所有原地修改后将它们注册到该 Tensor。例如

t = torch.tensor(1., requires_grad=True).sin()
t.cos_()
t.register_hook(fn)
t.backward()

此外,了解到在幕后,当将钩子注册到 Tensor 时,它们实际上会永久绑定到该 Tensor 的 grad_fn,这将很有帮助,因此,如果随后原地修改了该 Tensor,即使该 Tensor 现在有了一个新的 grad_fn,在原地修改之前注册的钩子仍将继续与旧的 grad_fn 关联,例如,当自动求导引擎在图中到达该 Tensor 的旧 grad_fn 时,它们将被触发。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得您的问题的答案

查看资源