快捷方式

Autograd 机制

本笔记将概述 autograd 的工作原理以及如何记录操作。虽然不一定需要理解所有这些内容,但我们建议您熟悉它,因为它将帮助您编写更高效、更简洁的程序,并可以帮助您进行调试。

autograd 如何编码历史记录

Autograd 是一个反向自动微分系统。从概念上讲,autograd 会记录一个图,记录所有创建数据的操作,就像您执行操作一样,为您提供一个有向无环图,其叶节点是输入张量,根节点是输出张量。通过从根节点到叶节点追踪此图,您可以自动使用链式法则计算梯度。

在内部,autograd 将此图表示为 Function 对象(实际上是表达式)的图,可以 apply() 以计算评估图的结果。在计算前向传播时,autograd 同时执行请求的计算并构建一个图,该图表示计算梯度的函数(每个 torch.Tensor.grad_fn 属性是此图的入口点)。当前向传播完成后,我们会在反向传播中评估此图以计算梯度。

需要注意的重要一点是,图在每次迭代时都会从头开始重新创建,这正是允许使用任意 Python 控制流语句的原因,这些语句可以在每次迭代时更改图的整体形状和大小。您不必在启动训练之前编码所有可能的路径 - 您运行什么,就微分什么。

保存的张量

某些操作需要在前向传播期间保存中间结果,以便执行反向传播。例如,函数 xx2x\mapsto x^2 保存输入 xx 以计算梯度。

在定义自定义 Python Function 时,您可以使用 save_for_backward() 在前向传播期间保存张量,并使用 saved_tensors 在反向传播期间检索它们。有关更多信息,请参阅 扩展 PyTorch

对于 PyTorch 定义的操作(例如 torch.pow()),张量会根据需要自动保存。您可以探索(出于教育或调试目的)某个 grad_fn 保存了哪些张量,方法是查找其以 _saved 前缀开头的属性。

x = torch.randn(5, requires_grad=True)
y = x.pow(2)
print(x.equal(y.grad_fn._saved_self))  # True
print(x is y.grad_fn._saved_self)  # True

在之前的代码中,y.grad_fn._saved_self 指的是与 x 相同的 Tensor 对象。但这可能并不总是如此。例如

x = torch.randn(5, requires_grad=True)
y = x.exp()
print(y.equal(y.grad_fn._saved_result))  # True
print(y is y.grad_fn._saved_result)  # False

在底层,为了防止引用循环,PyTorch 在保存时打包了张量,并在读取时将其解包为不同的张量。在这里,您从访问 y.grad_fn._saved_result 获取的张量是与 y 不同的张量对象(但它们仍然共享相同的存储)。

张量是否会被打包到不同的张量对象中取决于它是否是其自身 grad_fn 的输出,这是一个实现细节,可能会发生变化,用户不应依赖它。

您可以使用 保存张量的钩子 来控制 PyTorch 如何进行打包/解包。

不可微函数的梯度

仅当所使用的每个基本函数都是可微的时,使用自动微分计算梯度才是有效的。不幸的是,我们在实践中使用的许多函数都不具备此属性(例如,relusqrt0 处)。为了尽量减少不可微函数的影响,我们通过按顺序应用以下规则来定义基本操作的梯度

  1. 如果函数是可微的,因此在当前点存在梯度,则使用它。

  2. 如果函数是凸函数(至少在局部是凸函数),则使用最小范数的次梯度(它是最速下降方向)。

  3. 如果函数是凹函数(至少在局部是凹函数),则使用最小范数的超梯度(考虑 -f(x) 并应用上一点)。

  4. 如果函数已定义,则通过连续性在当前点定义梯度(请注意,这里可能存在 inf,例如对于 sqrt(0))。如果存在多个可能的值,则任意选择一个。

  5. 如果函数未定义(例如,sqrt(-1)log(-1) 或大多数输入为 NaN 时的函数),则用作梯度的值是任意的(我们也可能会引发错误,但这不能保证)。大多数函数将使用 NaN 作为梯度,但出于性能原因,某些函数将使用其他值(例如,log(-1))。

  6. 如果函数不是确定性映射(即,它不是 数学函数),则它将被标记为不可微。如果用于需要 grad 的张量且在 no_grad 环境之外,这将使其在反向传播中报错。

局部禁用梯度计算

有几种机制可以从 Python 局部禁用梯度计算

要禁用整个代码块的梯度,可以使用上下文管理器,如 no-grad 模式 和 推理模式。对于从梯度计算中更细粒度地排除子图,可以设置张量的 requires_grad 字段。

下面,除了讨论上述机制外,我们还将介绍评估模式 (nn.Module.eval()),这是一种不用于禁用梯度计算的方法,但由于其名称,常常与前三种方法混淆。

设置 requires_grad

requires_grad 是一个标志,默认值为 false,除非包装在 nn.Parameter 中,它允许从梯度计算中细粒度地排除子图。它在前向传播和反向传播中都生效

在前向传播期间,仅当其至少一个输入张量需要 grad 时,操作才会被记录在反向图中。在反向传播期间 (.backward()),只有 requires_grad=True 的叶张量才会将其梯度累积到其 .grad 字段中。

重要的是要注意,即使每个张量都有此标志,设置它也仅对叶张量(没有 grad_fn 的张量,例如 nn.Module 的参数)才有意义。非叶张量(具有 grad_fn 的张量)是具有与其关联的反向图的张量。因此,它们的梯度将作为中间结果,用于计算需要 grad 的叶张量的梯度。从这个定义可以清楚地看出,所有非叶张量都将自动具有 require_grad=True

设置 requires_grad 应该是您控制模型的哪些部分参与梯度计算的主要方法,例如,如果您需要在模型微调期间冻结预训练模型的某些部分。

要冻结模型的某些部分,只需将 .requires_grad_(False) 应用于您不想更新的参数。如上所述,由于将这些参数用作输入的计算不会记录在前向传播中,因此它们不会在反向传播中更新其 .grad 字段,因为它们从一开始就不会成为反向图的一部分,正如预期的那样。

由于这是一个非常常见的模式,因此也可以使用 nn.Module.requires_grad_() 在模块级别设置 requires_grad。当应用于模块时,.requires_grad_() 对模块的所有参数生效(默认情况下这些参数具有 requires_grad=True)。

Grad 模式

除了设置 requires_grad 之外,还有三种 grad 模式可以从 Python 中选择,它们会影响 autograd 在内部处理 PyTorch 中的计算的方式:默认模式(grad 模式)、no-grad 模式 和 推理模式,所有这些模式都可以通过上下文管理器和装饰器进行切换。

模式

将操作从记录在反向图中排除

跳过额外的 autograd 跟踪开销

在启用模式时创建的张量稍后可以在 grad 模式中使用

示例

默认

前向传播

no-grad

优化器更新

推理

数据处理,模型评估

默认模式(Grad 模式)

“默认模式”是当我们没有启用其他模式(如 no-grad 和 推理模式)时,我们隐式处于的模式。为了与 “no-grad 模式” 区分,“默认模式” 有时也称为 “grad 模式”。

关于默认模式,最重要的是要知道它是 requires_grad 生效的唯一模式。在其他两种模式下,requires_grad 始终被覆盖为 False

No-grad 模式

no-grad 模式下的计算行为就像没有输入需要 grad 一样。换句话说,即使存在 require_grad=True 的输入,no-grad 模式下的计算也永远不会记录在反向图中。

当您需要执行不应由 autograd 记录的操作时,但您仍然希望稍后在 grad 模式中使用这些计算的输出时,启用 no-grad 模式。此上下文管理器可以方便地禁用代码块或函数的梯度,而无需临时将张量设置为 requires_grad=False,然后再设置回 True

例如,no-grad 模式在编写优化器时可能很有用:在执行训练更新时,您希望就地更新参数,而无需更新被 autograd 记录。您还打算在下一次前向传播中将更新后的参数用于 grad 模式下的计算。

torch.nn.init 中的实现也依赖 no-grad 模式,以便在初始化参数时避免在就地更新初始化的参数时进行 autograd 跟踪。

推理模式

推理模式是 no-grad 模式的极端版本。与 no-grad 模式一样,推理模式下的计算不会记录在反向图中,但启用推理模式将使 PyTorch 能够更快地加速您的模型。这种更好的运行时性能是有代价的:在推理模式下创建的张量将无法在退出推理模式后在要由 autograd 记录的计算中使用。

当您执行不与 autograd 交互的计算,并且不计划在稍后要由 autograd 记录的任何计算中使用在推理模式下创建的张量时,启用推理模式。

建议您在代码中不需要 autograd 跟踪的部分(例如,数据处理和模型评估)中尝试推理模式。如果它开箱即用地适用于您的用例,那么这是一个免费的性能提升。如果在启用推理模式后遇到错误,请检查您是否在退出推理模式后,在由 autograd 记录的计算中使用了在推理模式下创建的张量。如果您的用例中无法避免此类使用,您可以随时切换回 no-grad 模式。

有关推理模式的详细信息,请参阅 推理模式

有关推理模式的实现细节,请参阅 RFC-0011-InferenceMode

评估模式 (nn.Module.eval())

评估模式不是一种局部禁用梯度计算的机制。尽管如此,此处仍然包含它,因为它有时会被误认为是这样一种机制。

从功能上讲,module.eval()(或等效的 module.train(False))与 no-grad 模式和推理模式完全正交。model.eval() 如何影响您的模型完全取决于模型中使用的特定模块以及它们是否定义了任何训练模式特定的行为。

您负责调用 model.eval()model.train(),如果您的模型依赖于诸如 torch.nn.Dropouttorch.nn.BatchNorm2d 之类的模块,这些模块可能会根据训练模式表现出不同的行为,例如,避免在验证数据上更新 BatchNorm 运行统计信息。

建议您始终在训练时使用 model.train(),在评估模型(验证/测试)时使用 model.eval(),即使您不确定您的模型是否具有训练模式特定的行为,因为您正在使用的模块可能会更新为在训练和评估模式下表现出不同的行为。

使用 autograd 的就地操作

在 autograd 中支持就地操作是一个难题,我们不鼓励在大多数情况下使用它们。Autograd 的激进缓冲区释放和重用使其非常高效,并且很少有就地操作能显着降低内存使用量的情况。除非您在重内存压力下运行,否则您可能永远不需要使用它们。

有两个主要原因限制了就地操作的适用性

  1. 就地操作可能会覆盖计算梯度所需的值。

  2. 每个就地操作都需要实现来重写计算图。异地版本只是分配新对象并保留对旧图的引用,而就地操作需要更改表示此操作的 Function 的所有输入的创建者。这可能很棘手,特别是当有许多 Tensor 引用相同的存储时(例如,通过索引或转置创建),如果修改后的输入的存储被任何其他 Tensor 引用,就地函数将引发错误。

就地正确性检查

每个张量都维护一个版本计数器,每次在任何操作中将其标记为脏时,该计数器都会递增。当 Function 保存任何张量以进行反向传播时,也会保存其包含的 Tensor 的版本计数器。一旦您访问 self.saved_tensors,就会进行检查,如果它大于保存的值,则会引发错误。这确保了如果您正在使用就地函数并且没有看到任何错误,您可以确信计算出的梯度是正确的。

多线程 Autograd

autograd 引擎负责运行计算反向传播所需的所有反向操作。本节将描述所有可以帮助您在多线程环境中充分利用它的细节。(这仅与 PyTorch 1.6+ 相关,因为以前版本的行为有所不同。)

用户可以使用多线程代码(例如 Hogwild 训练)训练他们的模型,并且不会阻塞并发的反向计算,示例代码可能如下所示

# Define a train function to be used in different threads
def train_fn():
    x = torch.ones(5, 5, requires_grad=True)
    # forward
    y = (x + 3) * (x + 4) * 0.5
    # backward
    y.sum().backward()
    # potential optimizer update


# User write their own threading code to drive the train_fn
threads = []
for _ in range(10):
    p = threading.Thread(target=train_fn, args=())
    p.start()
    threads.append(p)

for p in threads:
    p.join()

请注意,用户应该注意以下一些行为

CPU 上的并发

当您在 CPU 上的多个线程中通过 python 或 C++ API 运行 backward()grad() 时,您期望看到额外的并发性,而不是在执行期间按特定顺序串行化所有反向调用(PyTorch 1.6 之前的行为)。

非确定性

如果您从多个线程并发调用 backward() 并且有共享输入(即 Hogwild CPU 训练),那么应该预期会出现非确定性。发生这种情况的原因是参数在线程之间自动共享,因此,多个线程可能访问并尝试在梯度累积期间累积相同的 .grad 属性。从技术上讲,这不安全,并且可能会导致竞争条件,并且结果可能无效。

开发具有共享参数的多线程模型的用户应该牢记线程模型,并且应该理解上述问题。

可以使用函数式 API torch.autograd.grad() 来计算梯度,而不是 backward(),以避免非确定性。

图保留

如果 autograd 图的一部分在线程之间共享,即先在单线程中运行第一部分前向传播,然后在多线程中运行第二部分,则图的第一部分是共享的。在这种情况下,不同的线程在同一图上执行 grad()backward() 可能会出现在一个线程的运行中销毁图的问题,而另一个线程在这种情况下会崩溃。Autograd 将向用户报错,类似于调用 backward() 两次且没有 retain_graph=True 的情况,并告知用户他们应该使用 retain_graph=True

Autograd 节点上的线程安全

由于 Autograd 允许调用者线程驱动其反向执行以实现潜在的并行性,因此重要的是我们确保 CPU 上的线程安全,以应对共享部分/全部 GraphTask 的并行 backward() 调用。

自定义 Python autograd.Function 由于 GIL 而自动是线程安全的。对于内置的 C++ Autograd 节点(例如 AccumulateGrad、CopySlices)和自定义 autograd::Function,Autograd 引擎使用线程互斥锁来确保可能具有状态写入/读取的 autograd 节点上的线程安全。

C++ 钩子上没有线程安全

Autograd 依赖用户编写线程安全的 C++ 钩子。如果您希望钩子在多线程环境中正确应用,您将需要编写适当的线程锁定代码以确保钩子是线程安全的。

复数的 Autograd

简短版本

  • 当您使用 PyTorch 对任何函数 f(z)f(z) 进行微分时,且该函数具有复数域和/或复数共域,则梯度的计算基于以下假设:该函数是更大的实值损失函数 g(input)=Lg(input)=L 的一部分。计算出的梯度为 Lz\frac{\partial L}{\partial z^*} (请注意 z 的共轭),其负数恰好是梯度下降算法中使用的最速下降方向。因此,在使现有优化器与复数参数一起开箱即用方面,存在一条可行的路径。

  • 此约定与 TensorFlow 的复数微分约定相匹配,但与 JAX 不同(JAX 计算的是 Lz\frac{\partial L}{\partial z})。

  • 如果您有一个实数到实数的函数,该函数在内部使用复数运算,那么这里的约定无关紧要:您将始终获得与仅使用实数运算实现时相同的结果。

如果您对数学细节感到好奇,或者想知道如何在 PyTorch 中定义复数导数,请继续阅读。

什么是复数导数?

复数可微性的数学定义采用了导数的极限定义,并将其推广到复数运算。考虑一个函数 f:CCf: ℂ → ℂ

f(z=x+yj)=u(x,y)+v(x,y)jf(z=x+yj) = u(x, y) + v(x, y)j

其中 uuvv 是两个变量的实值函数,而 jj 是虚数单位。

使用导数定义,我们可以写出

f(z)=limh0,hCf(z+h)f(z)hf'(z) = \lim_{h \to 0, h \in C} \frac{f(z+h) - f(z)}{h}

为了使此极限存在,不仅 uuvv 必须是实数可微的,而且 ff 还必须满足柯西-黎曼 方程。换句话说:对于实数和虚数步长(hh)计算的极限必须相等。这是一个更严格的条件。

复数可微函数通常被称为全纯函数。它们表现良好,具有您从实数可微函数中看到的所有良好属性,但在优化领域几乎没有用处。对于优化问题,研究界仅使用实值目标函数,因为复数不属于任何有序域,因此具有复数值损失没有多大意义。

事实证明,没有有趣的实值目标函数满足柯西-黎曼方程。因此,全纯函数理论不能用于优化,因此大多数人使用 Wirtinger 演算。

Wirtinger 演算开始发挥作用……

因此,我们有这个很棒的复数可微性和全纯函数理论,但我们根本无法使用它,因为许多常用函数都不是全纯的。可怜的数学家该怎么办?嗯,Wirtinger 观察到,即使 f(z)f(z) 不是全纯的,也可以将其重写为双变量函数 f(z,z)f(z, z*),它始终是全纯的。这是因为 zz 的分量的实部和虚部可以用 zzzz^* 表示为

Re(z)=z+z2Im(z)=zz2j\begin{aligned} \mathrm{Re}(z) &= \frac {z + z^*}{2} \\ \mathrm{Im}(z) &= \frac {z - z^*}{2j} \end{aligned}

Wirtinger 演算建议研究 f(z,z)f(z, z^*),如果 ff 是实数可微的,则保证它是全纯的(另一种考虑方式是坐标系的更改,从 f(x,y)f(x, y)f(z,z)f(z, z^*)。)此函数具有偏导数 z\frac{\partial }{\partial z}z\frac{\partial}{\partial z^{*}}。我们可以使用链式法则来建立这些偏导数与 zz 的实部和虚部相关的偏导数之间的关系。

x=zxz+zxz=z+zy=zyz+zyz=1j(zz)\begin{aligned} \frac{\partial }{\partial x} &= \frac{\partial z}{\partial x} * \frac{\partial }{\partial z} + \frac{\partial z^*}{\partial x} * \frac{\partial }{\partial z^*} \\ &= \frac{\partial }{\partial z} + \frac{\partial }{\partial z^*} \\ \\ \frac{\partial }{\partial y} &= \frac{\partial z}{\partial y} * \frac{\partial }{\partial z} + \frac{\partial z^*}{\partial y} * \frac{\partial }{\partial z^*} \\ &= 1j * \left(\frac{\partial }{\partial z} - \frac{\partial }{\partial z^*}\right) \end{aligned}

从以上公式,我们得到

z=1/2(x1jy)z=1/2(x+1jy)\begin{aligned} \frac{\partial }{\partial z} &= 1/2 * \left(\frac{\partial }{\partial x} - 1j * \frac{\partial }{\partial y}\right) \\ \frac{\partial }{\partial z^*} &= 1/2 * \left(\frac{\partial }{\partial x} + 1j * \frac{\partial }{\partial y}\right) \end{aligned}

这是维尔廷格 calculus 的经典定义,您可以在维基百科上找到。

这个改变有很多美妙的结果。

  • 首先,柯西-黎曼方程可以简单地转化为 fz=0\frac{\partial f}{\partial z^*} = 0 (也就是说,函数 ff 可以完全用 zz 表示,而无需参考 zz^*)。

  • 另一个重要(且有些违反直觉)的结果,正如我们稍后将看到的,是当我们对实值损失进行优化时,变量更新时应采取的步骤由 Lossz\frac{\partial Loss}{\partial z^*} 给出(不是 Lossz\frac{\partial Loss}{\partial z})。

更多阅读材料,请查看: https://arxiv.org/pdf/0906.4835.pdf

维尔廷格 Calculus 在优化中有什么用?

音频和其他领域的研究人员更常用梯度下降法来优化具有复变量的实值损失函数。通常,这些人将实部和虚部视为可以单独更新的通道。对于步长 α/2\alpha/2 和损失 LL,我们可以用 R2ℝ^2 写出以下公式

xn+1=xn(α/2)Lxyn+1=yn(α/2)Ly\begin{aligned} x_{n+1} &= x_n - (\alpha/2) * \frac{\partial L}{\partial x} \\ y_{n+1} &= y_n - (\alpha/2) * \frac{\partial L}{\partial y} \end{aligned}

这些公式如何转换到复数空间 C

zn+1=xn(α/2)Lx+1j(yn(α/2)Ly)=znα1/2(Lx+jLy)=znαLz\begin{aligned} z_{n+1} &= x_n - (\alpha/2) * \frac{\partial L}{\partial x} + 1j * (y_n - (\alpha/2) * \frac{\partial L}{\partial y}) \\ &= z_n - \alpha * 1/2 * \left(\frac{\partial L}{\partial x} + j \frac{\partial L}{\partial y}\right) \\ &= z_n - \alpha * \frac{\partial L}{\partial z^*} \end{aligned}

非常有趣的事情发生了:Wirtinger 演算告诉我们,我们可以简化上面的复变量更新公式,使其仅参考共轭 Wirtinger 导数 Lz\frac{\partial L}{\partial z^*},这正是我们在优化中采取的步骤。

因为共轭 Wirtinger 导数对于实值损失函数给出了完全正确的步骤,所以当您对具有实值损失的函数进行微分时,PyTorch 会提供此导数。

PyTorch 如何计算共轭 Wirtinger 导数?

通常,我们的导数公式将 grad_output 作为输入,表示我们已经计算出的传入的向量-雅可比乘积,也称为 Ls\frac{\partial L}{\partial s^*},其中 LL 是整个计算的损失(产生实值损失),而 ss 是我们函数的输出。 这里的目标是计算 Lz\frac{\partial L}{\partial z^*},其中 zz 是函数的输入。 事实证明,在实值损失的情况下,我们可以计算 Ls\frac{\partial L}{\partial s^*},即使链式法则意味着我们还需要访问 Ls\frac{\partial L}{\partial s}。 如果您想跳过此推导,请查看本节中的最后一个公式,然后跳到下一节。

让我们继续使用定义为 f:CCf: ℂ → ℂ 的函数 f(z)=f(x+yj)=u(x,y)+v(x,y)jf(z) = f(x+yj) = u(x, y) + v(x, y)j。 如上所述,autograd 的梯度约定以实值损失函数的优化为中心,因此让我们假设 ff 是更大的实值损失函数 gg 的一部分。 使用链式法则,我们可以写出

(1)Lz=Luuz+Lvvz\frac{\partial L}{\partial z^*} = \frac{\partial L}{\partial u} * \frac{\partial u}{\partial z^*} + \frac{\partial L}{\partial v} * \frac{\partial v}{\partial z^*}

现在使用 Wirtinger 导数定义,我们可以写出

Ls=1/2(LuLvj)Ls=1/2(Lu+Lvj)\begin{aligned} \frac{\partial L}{\partial s} = 1/2 * \left(\frac{\partial L}{\partial u} - \frac{\partial L}{\partial v} j\right) \\ \frac{\partial L}{\partial s^*} = 1/2 * \left(\frac{\partial L}{\partial u} + \frac{\partial L}{\partial v} j\right) \end{aligned}

这里应该注意,因为 uuvv 是实函数,并且 LL 根据我们假设的 ff 是实值函数的一部分,我们有

(2)(Ls)=Ls\left( \frac{\partial L}{\partial s} \right)^* = \frac{\partial L}{\partial s^*}

即,Ls\frac{\partial L}{\partial s} 等于 grad_outputgrad\_output^*

求解上述方程得到 Lu\frac{\partial L}{\partial u}Lv\frac{\partial L}{\partial v},我们得到

(3)Lu=Ls+LsLv=1j(LsLs)\begin{aligned} \frac{\partial L}{\partial u} = \frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*} \\ \frac{\partial L}{\partial v} = 1j * \left(\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}\right) \end{aligned}

(3) 代入 (1),我们得到

Lz=(Ls+Ls)uz+1j(LsLs)vz=Ls(uz+vzj)+Ls(uzvzj)=Ls(u+vj)z+Ls(u+vj)z=Lssz+Lssz\begin{aligned} \frac{\partial L}{\partial z^*} &= \left(\frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*}\right) * \frac{\partial u}{\partial z^*} + 1j * \left(\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}\right) * \frac{\partial v}{\partial z^*} \\ &= \frac{\partial L}{\partial s} * \left(\frac{\partial u}{\partial z^*} + \frac{\partial v}{\partial z^*} j\right) + \frac{\partial L}{\partial s^*} * \left(\frac{\partial u}{\partial z^*} - \frac{\partial v}{\partial z^*} j\right) \\ &= \frac{\partial L}{\partial s} * \frac{\partial (u + vj)}{\partial z^*} + \frac{\partial L}{\partial s^*} * \frac{\partial (u + vj)^*}{\partial z^*} \\ &= \frac{\partial L}{\partial s} * \frac{\partial s}{\partial z^*} + \frac{\partial L}{\partial s^*} * \frac{\partial s^*}{\partial z^*} \\ \end{aligned}

使用 公式 (2),我们得到

(4)Lz=(Ls)sz+Ls(sz)=(grad_output)sz+grad_output(sz)\begin{aligned} \frac{\partial L}{\partial z^*} &= \left(\frac{\partial L}{\partial s^*}\right)^* * \frac{\partial s}{\partial z^*} + \frac{\partial L}{\partial s^*} * \left(\frac{\partial s}{\partial z}\right)^* \\ &= \boxed{ (grad\_output)^* * \frac{\partial s}{\partial z^*} + grad\_output * \left(\frac{\partial s}{\partial z}\right)^* } \\ \end{aligned}

这个最后的公式对于编写你自己的梯度非常重要,因为它将我们的导数公式分解为一个更简单的公式,这个公式很容易手动计算。

如何为复函数编写自己的导数公式?

上面框起来的公式给出了复函数上所有导数的通用公式。然而,我们仍然需要计算 sz\frac{\partial s}{\partial z}sz\frac{\partial s}{\partial z^*}。 你可以通过两种方式做到这一点

  • 第一种方法是直接使用 Wirtinger 导数的定义,并计算 sz\frac{\partial s}{\partial z}sz\frac{\partial s}{\partial z^*},通过使用 sx\frac{\partial s}{\partial x}sy\frac{\partial s}{\partial y} (你可以用通常的方式计算)。

  • 第二种方法是使用变量替换技巧,将 f(z)f(z) 重写为二元函数 f(z,z)f(z, z^*),并通过将 zzzz^* 视为独立变量来计算共轭 Wirtinger 导数。 这通常更容易;例如,如果所讨论的函数是全纯函数,则只会使用 zz (并且 sz\frac{\partial s}{\partial z^*} 将为零)。

让我们考虑函数 f(z=x+yj)=cz=c(x+yj)f(z = x + yj) = c * z = c * (x+yj) 作为一个例子,其中 cRc \in ℝ

使用第一种方法计算 Wirtinger 导数,我们得到。

sz=1/2(sxsyj)=1/2(c(c1j)1j)=csz=1/2(sx+syj)=1/2(c+(c1j)1j)=0\begin{aligned} \frac{\partial s}{\partial z} &= 1/2 * \left(\frac{\partial s}{\partial x} - \frac{\partial s}{\partial y} j\right) \\ &= 1/2 * (c - (c * 1j) * 1j) \\ &= c \\ \\ \\ \frac{\partial s}{\partial z^*} &= 1/2 * \left(\frac{\partial s}{\partial x} + \frac{\partial s}{\partial y} j\right) \\ &= 1/2 * (c + (c * 1j) * 1j) \\ &= 0 \\ \end{aligned}

使用 (4),以及 grad_output = 1.0 (这是在 PyTorch 中对标量输出调用 backward() 时使用的默认 grad 输出值),我们得到

Lz=10+1c=c\frac{\partial L}{\partial z^*} = 1 * 0 + 1 * c = c

使用第二种计算 Wirtinger 导数的方法,我们直接得到

sz=(cz)z=csz=(cz)z=0\begin{aligned} \frac{\partial s}{\partial z} &= \frac{\partial (c*z)}{\partial z} \\ &= c \\ \frac{\partial s}{\partial z^*} &= \frac{\partial (c*z)}{\partial z^*} \\ &= 0 \end{aligned}

再次使用 (4),我们得到 Lz=c\frac{\partial L}{\partial z^*} = c。 如你所见,第二种方法涉及较少的计算,并且对于更快的计算更方便。

跨域函数呢?

一些函数将复数输入映射到实数输出,反之亦然。 这些函数构成了 (4) 的一个特例,我们可以使用链式法则推导出来

  • 对于 f:CRf: ℂ → ℝ,我们得到

    Lz=2grad_outputsz\frac{\partial L}{\partial z^*} = 2 * grad\_output * \frac{\partial s}{\partial z^{*}}
  • 对于 f:RCf: ℝ → ℂ,我们得到

    Lz=2Re(grad_outputsz)\frac{\partial L}{\partial z^*} = 2 * \mathrm{Re}(grad\_output^* * \frac{\partial s}{\partial z^{*}})

保存张量的钩子

您可以通过定义一对 pack_hook / unpack_hook 钩子来控制如何打包/解包保存的张量pack_hook 函数应接受一个张量作为其唯一参数,但可以返回任何 python 对象(例如,另一个张量、元组,甚至包含文件名的字符串)。unpack_hook 函数接受 pack_hook 的输出作为其唯一参数,并且应返回一个张量,用于反向传播。unpack_hook 返回的张量只需要与作为 pack_hook 输入传递的张量具有相同的内容即可。特别地,任何与 autograd 相关的元数据都可以忽略,因为它们将在解包期间被覆盖。

这样一个钩子对的例子是

class SelfDeletingTempFile():
    def __init__(self):
        self.name = os.path.join(tmp_dir, str(uuid.uuid4()))

    def __del__(self):
        os.remove(self.name)

def pack_hook(tensor):
    temp_file = SelfDeletingTempFile()
    torch.save(tensor, temp_file.name)
    return temp_file

def unpack_hook(temp_file):
    return torch.load(temp_file.name)

请注意,unpack_hook 不应删除临时文件,因为它可能会被多次调用:临时文件应在返回的 SelfDeletingTempFile 对象存活期间保持活动状态。在上面的示例中,我们通过在不再需要临时文件时(在删除 SelfDeletingTempFile 对象时)关闭它来防止临时文件泄漏。

注意

我们保证 pack_hook 只会被调用一次,但 unpack_hook 可以根据反向传播的需要被调用多次,并且我们期望它每次都返回相同的数据。

警告

禁止对任何函数的输入执行原地操作,因为它们可能会导致意外的副作用。如果原地修改了 pack 钩子的输入,PyTorch 将抛出错误,但不会捕获原地修改 unpack 钩子的输入的情况。

为保存的张量注册钩子

您可以通过在 SavedTensor 对象上调用 register_hooks() 方法,在保存的张量上注册一对钩子。这些对象作为 grad_fn 的属性公开,并以 _raw_saved_ 前缀开头。

x = torch.randn(5, requires_grad=True)
y = x.pow(2)
y.grad_fn._raw_saved_self.register_hooks(pack_hook, unpack_hook)

pack_hook 方法在钩子对注册后立即调用。unpack_hook 方法在每次需要访问保存的张量时调用,无论是通过 y.grad_fn._saved_self 还是在反向传播期间。

警告

如果您在保存的张量被释放后(即在反向传播被调用后)仍然保持对 SavedTensor 的引用,则禁止调用其 register_hooks()。PyTorch 大多数时候会抛出错误,但在某些情况下可能无法做到这一点,并且可能会出现未定义的行为。

为保存的张量注册默认钩子

或者,您可以使用上下文管理器 saved_tensors_hooks 来注册一对钩子,这些钩子将应用于在该上下文中创建的所有保存的张量。

例子

# Only save on disk tensors that have size >= 1000
SAVE_ON_DISK_THRESHOLD = 1000

def pack_hook(x):
    if x.numel() < SAVE_ON_DISK_THRESHOLD:
        return x
    temp_file = SelfDeletingTempFile()
    torch.save(tensor, temp_file.name)
    return temp_file

def unpack_hook(tensor_or_sctf):
    if isinstance(tensor_or_sctf, torch.Tensor):
        return tensor_or_sctf
    return torch.load(tensor_or_sctf.name)

class Model(nn.Module):
    def forward(self, x):
        with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
          # ... compute output
          output = x
        return output

model = Model()
net = nn.DataParallel(model)

使用此上下文管理器定义的钩子是线程局部的。因此,以下代码不会产生预期的效果,因为钩子不会通过 DataParallel

# Example what NOT to do

net = nn.DataParallel(model)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    output = net(input)

请注意,使用这些钩子会禁用所有用于减少张量对象创建的优化。例如

with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
    x = torch.randn(5, requires_grad=True)
    y = x * x

在没有钩子的情况下,xy.grad_fn._saved_selfy.grad_fn._saved_other 都引用同一个张量对象。使用钩子,PyTorch 将 x 打包和解包为两个新的张量对象,这两个对象与原始 x 共享相同的存储(不执行复制)。

反向钩子的执行

本节将讨论不同的钩子何时触发或不触发。然后它将讨论它们被触发的顺序。将涵盖的钩子包括:通过 torch.Tensor.register_hook() 注册到张量的反向钩子,通过 torch.Tensor.register_post_accumulate_grad_hook() 注册到张量的后累积梯度钩子,通过 torch.autograd.graph.Node.register_hook() 注册到节点的后钩子,以及通过 torch.autograd.graph.Node.register_prehook() 注册到节点的前钩子。

特定钩子是否会被触发

通过 torch.Tensor.register_hook() 注册到张量的钩子在为该张量计算梯度时执行。(请注意,这不需要执行张量的 grad_fn。例如,如果张量作为 torch.autograd.grad()inputs 参数的一部分传递,则张量的 grad_fn 可能不会执行,但注册到该张量的钩子将始终执行。)

通过 torch.Tensor.register_post_accumulate_grad_hook() 注册到张量的钩子在该张量的梯度累积后执行,这意味着张量的 grad 字段已被设置。而通过 torch.Tensor.register_hook() 注册的钩子在计算梯度时运行,通过 torch.Tensor.register_post_accumulate_grad_hook() 注册的钩子仅在反向传播结束时张量的 grad 字段被 autograd 更新后触发。因此,后累积梯度钩子只能为叶张量注册。即使您调用 backward(retain_graph=True),在非叶张量上通过 torch.Tensor.register_post_accumulate_grad_hook() 注册钩子也会报错。

使用 torch.autograd.graph.Node.register_hook()torch.autograd.graph.Node.register_prehook() 注册到 torch.autograd.graph.Node 的钩子仅在该钩子注册到的节点被执行时触发。

特定节点是否被执行可能取决于反向传播是使用 torch.autograd.grad() 还是 torch.autograd.backward() 调用的。具体来说,当您在对应于您传递给 torch.autograd.grad()torch.autograd.backward() 的张量的节点上注册钩子作为 inputs 参数的一部分时,您应该注意这些差异。

如果您使用 torch.autograd.backward(),则所有上述提到的钩子都将被执行,无论您是否指定了 inputs 参数。这是因为 .backward() 执行所有节点,即使它们对应于作为输入指定的张量。(请注意,执行这个额外的节点(对应于作为 inputs 传递的张量)通常是不必要的,但仍然会执行。此行为可能会发生变化;您不应依赖它。)

另一方面,如果您使用 torch.autograd.grad(),则注册到对应于传递给 input 的张量的节点的反向钩子可能不会被执行,因为除非有另一个输入依赖于此节点的梯度结果,否则这些节点将不会被执行。

不同钩子被触发的顺序

事件发生的顺序是

  1. 注册到张量的钩子被执行

  2. 注册到节点的前钩子被执行(如果节点被执行)。

  3. 对于 retain_grad 的张量,.grad 字段被更新

  4. 节点被执行(服从上述规则)

  5. 对于累积了 .grad 的叶张量,后累积梯度钩子被执行

  6. 注册到节点的后钩子被执行(如果节点被执行)

如果在同一个张量或节点上注册了多个相同类型的钩子,它们将按照注册的顺序执行。稍后执行的钩子可以观察到早期钩子对梯度所做的修改。

特殊钩子

torch.autograd.graph.register_multi_grad_hook() 是使用注册到张量的钩子实现的。每个单独的张量钩子都按照上面定义的张量钩子顺序触发,并且当最后一个张量梯度被计算时,注册的多梯度钩子被调用。

torch.nn.modules.module.register_module_full_backward_hook() 是使用注册到节点的钩子实现的。当计算前向传播时,钩子被注册到对应于模块的输入和输出的 grad_fn。由于一个模块可能接受多个输入并返回多个输出,因此在返回前向传播的输出之前,首先将一个虚拟的自定义 autograd 函数应用于模块的输入,然后再应用于模块的输出,以确保这些张量共享一个 grad_fn,然后我们可以将我们的钩子附加到该 grad_fn。

当张量被原地修改时张量钩子的行为

通常,注册到张量的钩子接收输出相对于该张量的梯度,其中张量的值被认为是反向传播计算时的值。

但是,如果您向张量注册钩子,然后原地修改该张量,则在原地修改之前注册的钩子类似地接收输出相对于该张量的梯度,但张量的值被认为是原地修改之前的值。

如果您更喜欢前一种情况下的行为,则应在对张量进行所有原地修改之后再向其注册钩子。例如

t = torch.tensor(1., requires_grad=True).sin()
t.cos_()
t.register_hook(fn)
t.backward()

此外,了解在底层,当钩子注册到张量时,它们实际上永久绑定到该张量的 grad_fn,因此如果该张量随后被原地修改,即使张量现在有一个新的 grad_fn,但在原地修改之前注册的钩子将继续与旧的 grad_fn 关联,例如,当 autograd 引擎在图中到达该张量的旧 grad_fn 时,它们将被触发,这可能会有所帮助。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的问题解答

查看资源