Autograd 机制¶
本文档将概述 Autograd 的工作原理以及如何记录操作。虽然不强制要求完全理解这些内容,但我们建议您熟悉它,因为它将帮助您编写更高效、更清晰的程序,并有助于调试。
Autograd 如何编码历史¶
Autograd 是一个反向自动微分系统。从概念上讲,Autograd 会记录一个图,记录您执行操作时创建数据的所有操作,从而生成一个有向无环图,其叶节点是输入张量,根节点是输出张量。通过从根节点追溯到叶节点,您可以使用链式法则自动计算梯度。
在内部,Autograd 将此图表示为 Function 对象(实际上是表达式)的图,可以对其执行 apply() 来计算评估图的结果。在计算前向传播时,Autograd 同时执行请求的计算并构建一个表示计算梯度的函数的图(每个 torch.Tensor 的 .grad_fn 属性是进入此图的入口点)。当前向传播完成后,我们在反向传播中评估此图以计算梯度。
需要注意的一点是,该图在每次迭代时都会从头开始重新创建,这正是允许使用任意 Python 控制流语句的原因,这些语句可以在每次迭代时改变图的整体形状和大小。您无需在开始训练之前编码所有可能的路径 - 您运行的内容就是您要进行微分的内容。
保存的张量¶
有些操作需要在前向传播期间保存中间结果,以便执行反向传播。例如,函数 会保存输入 以计算梯度。
定义自定义 Python Function 时,您可以使用 save_for_backward() 在前向传播期间保存张量,并使用 saved_tensors 在反向传播期间检索它们。有关详细信息,请参阅扩展 PyTorch。
对于 PyTorch 定义的操作(例如 torch.pow()),会根据需要自动保存张量。您可以(出于教育或调试目的)通过查找以 _saved 为前缀的属性,来查看特定 grad_fn 保存了哪些张量。
x = torch.randn(5, requires_grad=True)
y = x.pow(2)
print(x.equal(y.grad_fn._saved_self)) # True
print(x is y.grad_fn._saved_self) # True
在前面的代码中,y.grad_fn._saved_self 指的是与 x 相同的张量对象。但并非总是如此。例如
x = torch.randn(5, requires_grad=True)
y = x.exp()
print(y.equal(y.grad_fn._saved_result)) # True
print(y is y.grad_fn._saved_result) # False
在底层,为了防止引用循环,PyTorch 在保存时会打包张量,并在读取时将其解包到不同的张量中。在这里,您通过访问 y.grad_fn._saved_result 获得的张量是与 y 不同的张量对象(但它们仍然共享相同的存储空间)。
张量是否会被打包到不同的张量对象中取决于它是否是其自身 grad_fn 的输出,这是一个实现细节,可能会发生变化,用户不应依赖于此。
您可以使用保存张量的 Hooks 来控制 PyTorch 如何进行打包/解包。
不可微分函数的梯度¶
使用自动微分计算梯度仅在使用的每个基本函数是可微分时有效。不幸的是,我们在实践中使用的许多函数不具备此属性(例如 relu 或 sqrt 在 0 处)。为了尝试减少不可微分函数的影响,我们通过按顺序应用以下规则来定义基本操作的梯度
如果函数是可微分的,并且当前点存在梯度,则使用它。
如果函数是凸函数(至少局部),则使用范数最小的次梯度(它是最陡的下降方向)。
如果函数是凹函数(至少局部),则使用范数最小的超梯度(考虑 -f(x) 并应用上一点)。
如果函数已定义,则通过连续性定义当前点的梯度(注意此处可能出现
inf,例如对于sqrt(0))。如果可能存在多个值,则任意选择一个。如果函数未定义(例如
sqrt(-1)、log(-1)或当输入为NaN时的大多数函数),则用作梯度的值是任意的(我们也可能引发错误,但这不保证)。大多数函数会使用NaN作为梯度,但出于性能原因,某些函数会使用其他值(例如log(-1))。如果函数不是确定性映射(即它不是数学函数),则会被标记为不可微分。如果在
no_grad环境之外对需要梯度的张量使用它,这将导致反向传播出错。
局部禁用梯度计算¶
Python 中提供了几种机制来局部禁用梯度计算
要在整个代码块中禁用梯度,可以使用 no-grad 模式和推理模式等上下文管理器。为了更精细地将子图排除在梯度计算之外,可以设置张量的 requires_grad 字段。
下面,除了讨论上述机制外,我们还将介绍评估模式(nn.Module.eval()),这种方法不是用来禁用梯度计算的,但由于其名称,经常与前三种机制混淆。
设置 requires_grad¶
requires_grad 是一个标志,默认为 false,除非封装在 nn.Parameter 中,它允许将子图精细地排除在梯度计算之外。它在前向传播和反向传播中都生效
在前向传播期间,只有当至少一个输入张量需要梯度时,操作才会被记录在反向图 (backward graph) 中。在反向传播 (.backward()) 期间,只有 requires_grad=True 的叶张量 (leaf tensors) 的梯度才会被累积到其 .grad 字段中。
需要注意的是,尽管每个张量都有这个标志,但设置它只对叶张量(没有 grad_fn 的张量,例如 nn.Module 的参数)有意义。非叶张量(有 grad_fn 的张量)是与其关联了反向图的张量。因此,它们的梯度将作为中间结果用于计算需要梯度的叶张量的梯度。从这个定义可以清楚地看出,所有非叶张量将自动具有 require_grad=True。
设置 requires_grad 应该是控制模型哪些部分参与梯度计算的主要方式,例如,如果您需要在模型微调期间冻结预训练模型的部分参数。
要冻结模型的部分参数,只需对您不想更新的参数应用 .requires_grad_(False)。如上所述,由于使用这些参数作为输入的计算不会被记录在前向传播中,因此它们的 .grad 字段在反向传播中不会被更新,因为它们根本不会成为反向图的一部分,这正是我们所期望的。
由于这是一种非常常见的模式,requires_grad 也可以通过 nn.Module.requires_grad_() 在模块级别设置。当应用于模块时,.requires_grad_() 会影响模块的所有参数(这些参数默认具有 requires_grad=True)。
梯度模式¶
除了设置 requires_grad 之外,Python 中还有三种可选的梯度模式,它们会影响 Autograd 在内部处理 PyTorch 计算的方式:默认模式(grad 模式)、no-grad 模式和推理模式,所有这些模式都可以通过上下文管理器和装饰器进行切换。
模式 |
排除操作不记录在反向图中 |
跳过额外的 Autograd 跟踪开销 |
在此模式下创建的张量稍后可在 grad 模式中使用 |
示例 |
|---|---|---|---|---|
默认 |
✓ |
前向传播 |
||
no-grad |
✓ |
✓ |
优化器更新 |
|
推理 |
✓ |
✓ |
数据处理,模型评估 |
默认模式 (Grad 模式)¶
“默认模式”是当没有启用 no-grad 模式和 inference 模式等其他模式时,我们隐含所处的模式。与“no-grad 模式”相对比,默认模式有时也称为“grad 模式”。
关于默认模式最重要的一点是,它是唯一使 requires_grad 生效的模式。requires_grad 在其他两种模式下总是被覆盖为 False。
No-grad 模式¶
在 no-grad 模式下的计算行为就好像所有输入都不需要梯度一样。换句话说,即使存在 require_grad=True 的输入,no-grad 模式下的计算也永远不会被记录在反向图中。
当您需要执行不应由 Autograd 记录的操作,但您仍然希望稍后在 grad 模式下使用这些计算的输出时,请启用 no-grad 模式。这个上下文管理器方便您为一个代码块或函数禁用梯度,而无需临时将张量设置为 requires_grad=False,然后再设置回 True。
例如,在编写优化器时,no-grad 模式可能很有用:在执行训练更新时,您希望原地更新参数,而无需 Autograd 记录此更新。您还打算在下一个前向传播中将更新后的参数用于 grad 模式下的计算。
torch.nn.init 中的实现也依赖于 no-grad 模式来初始化参数,以避免在原地更新已初始化参数时进行 Autograd 跟踪。
推理模式¶
推理模式是 no-grad 模式的极端版本。就像 no-grad 模式一样,推理模式下的计算不会被记录在反向图中,但启用推理模式可以让 PyTorch 进一步加速您的模型。这种更好的运行时性能伴随着一个缺点:在推理模式下创建的张量在退出推理模式后将无法用于需要 Autograd 记录的计算。
当您执行与 Autograd 没有交互的计算,并且不打算在稍后需要 Autograd 记录的任何计算中使用在推理模式下创建的张量时,请启用推理模式。
建议您在代码中不需要 Autograd 跟踪的部分(例如,数据处理和模型评估)尝试使用推理模式。如果它对您的用例直接适用,那将是免费的性能提升。如果在启用推理模式后遇到错误,请检查您是否在退出推理模式后,在需要 Autograd 记录的计算中使用了在推理模式下创建的张量。如果您的用例无法避免这种情况,您可以随时切换回 no-grad 模式。
有关推理模式的详细信息,请参阅推理模式。
有关推理模式的实现细节,请参阅RFC-0011-InferenceMode。
评估模式 (nn.Module.eval())¶
评估模式不是局部禁用梯度计算的机制。尽管如此,此处仍包含它,因为它有时会被误认为是此类机制。
从功能上讲,module.eval()(或等效的 module.train(False))与 no-grad 模式和推理模式完全正交。 model.eval() 如何影响您的模型完全取决于您模型中使用的特定模块以及它们是否定义了任何训练模式特有的行为。
如果您的模型依赖于 torch.nn.Dropout 和 torch.nn.BatchNorm2d 等可能因训练模式而表现不同的模块,例如,为了避免在验证数据上更新 BatchNorm 的运行统计信息,您有责任调用 model.eval() 和 model.train()。
建议您在训练模型时始终使用 model.train(),在评估模型时(验证/测试)使用 model.eval(),即使您不确定您的模型是否具有训练模式特有的行为,因为您使用的模块可能会更新以在训练和评估模式下表现不同。
Autograd 中的原地操作¶
在 Autograd 中支持原地操作是一个困难的问题,并且在大多数情况下我们不鼓励使用它们。Autograd 积极的缓冲区释放和重用使其非常高效,原地操作能够显著降低内存使用量的情况非常少。除非您面临严重的内存压力,否则您可能根本不需要使用它们。
限制原地操作适用性的主要原因有两个
原地操作可能会覆盖计算梯度所需的值。
每个原地操作都需要实现重写计算图。非原地版本只是简单地分配新对象并保留对旧图的引用,而原地操作需要更改表示此操作的
Function的所有输入的创建者。这可能很棘手,特别是如果有很多张量引用同一存储空间(例如通过索引或转置创建的),并且如果修改后的输入的存储空间被任何其他Tensor引用,原地函数将引发错误。
原地操作正确性检查¶
每个张量都维护一个版本计数器,每次在任何操作中被标记为脏时都会递增。当 Function 为反向传播保存任何张量时,其包含的 Tensor 的版本计数器也会被保存。一旦访问 self.saved_tensors,就会进行检查,如果它大于保存的值,则会引发错误。这确保了如果您使用原地函数且没有看到任何错误,则可以确信计算出的梯度是正确的。
多线程 Autograd¶
Autograd 引擎负责运行计算反向传播所需的所有反向操作。本节将详细介绍如何在多线程环境中充分利用 Autograd。(这仅与 PyTorch 1.6+ 相关,因为之前版本的行为不同。)
用户可以使用多线程代码(例如 Hogwild 训练)训练模型,并且不会阻塞并发的反向计算,示例代码如下
# Define a train function to be used in different threads
def train_fn():
x = torch.ones(5, 5, requires_grad=True)
# forward
y = (x + 3) * (x + 4) * 0.5
# backward
y.sum().backward()
# potential optimizer update
# User write their own threading code to drive the train_fn
threads = []
for _ in range(10):
p = threading.Thread(target=train_fn, args=())
p.start()
threads.append(p)
for p in threads:
p.join()
请注意用户应了解的一些行为
CPU 上的并发¶
当您在 CPU 上通过 Python 或 C++ API 在多个线程中运行 backward() 或 grad() 时,您期望看到额外的并发性,而不是在执行期间按特定顺序序列化所有反向调用(PyTorch 1.6 之前的行为)。
非确定性¶
如果您从多个线程并发调用 backward() 并且有共享输入(即 Hogwild CPU 训练),则应预料到非确定性。这可能发生,因为参数会在线程之间自动共享,因此多个线程可能会在梯度累积期间访问并尝试累积相同的 .grad 属性。这在技术上是不安全的,并可能导致竞态条件,结果可能无法使用。
开发具有共享参数的多线程模型的用户应该考虑到线程模型,并理解上述问题。
可以使用函数式 API torch.autograd.grad() 来计算梯度,而不是使用 backward(),以避免非确定性。
图保留¶
如果 Autograd 图的一部分在线程之间共享,即前向传播的第一部分在单个线程中运行,然后第二部分在多个线程中运行,则图的第一部分是共享的。在这种情况下,不同的线程在同一个图上执行 grad() 或 backward() 可能会出现一个线程在执行过程中销毁图的问题,另一个线程在这种情况下会崩溃。Autograd 会向用户报错,类似于在没有 retain_graph=True 的情况下调用 backward() 两次,并告知用户应该使用 retain_graph=True。
Autograd 节点上的线程安全¶
由于 Autograd 允许调用线程驱动其反向执行以实现潜在的并行性,因此确保在 CPU 上处理共享部分/全部 GraphTask 的并行 backward() 调用时的线程安全非常重要。
自定义 Python autograd.Function 由于 GIL(全局解释器锁)的存在而自动是线程安全的。对于内置的 C++ Autograd 节点(例如 AccumulateGrad、CopySlices)和自定义的 autograd::Function,Autograd 引擎使用线程互斥锁来确保可能涉及状态写入/读取的 Autograd 节点的线程安全。
C++ hooks 不提供线程安全¶
Autograd 依赖用户编写线程安全的 C++ hooks。如果您希望 hook 在多线程环境中正确应用,您需要编写适当的线程锁定代码来确保 hook 的线程安全。
复数的 Autograd¶
简短版本
当您在 PyTorch 中对具有复数定义域和/或值域的函数 求导时,梯度是在该函数是更大的实值损失函数 的一部分的假设下计算的。计算出的梯度为 (注意 z 的共轭形式),其负数正是梯度下降算法中最陡峭下降的方向。因此,为使现有优化器能直接处理复数参数提供了一条可行途径。
这一约定与 TensorFlow 在复数求导上的约定一致,但与 JAX 不同(JAX 计算 )。
如果您有一个内部使用复数运算的实数到实数的函数,这里的约定就不重要了:您将总是得到与只使用实数运算实现时相同的结果。
如果您对数学细节感到好奇,或者想知道如何在 PyTorch 中定义复数导数,请继续阅读。
什么是复数导数?¶
复数可微性的数学定义采用了导数的极限定义,并将其推广到对复数进行运算。考虑一个函数 ,
其中 和 是两个自变量的实值函数, 是虚数单位。
使用导数定义,我们可以写出
为了使此极限存在,不仅 和 必须是实可微的,而且 还必须满足柯西-黎曼方程。换句话说:关于实部和虚部步长 () 的计算极限必须相等。这是一个更严格的条件。
复数可微函数通常被称为全纯函数。它们性质良好,拥有你在实可微函数中看到的所有良好性质,但在优化领域几乎没有用处。对于优化问题,研究界只使用实值目标函数,因为复数不属于任何有序域,因此拥有复数值的损失函数没有太大意义。
事实也表明,没有有趣的实值目标函数满足柯西-黎曼方程。因此,全纯函数的理论不能用于优化,大多数人因此使用 Wirtinger 演算。
Wirtinger 演算登场了……¶
因此,我们拥有关于复数可微性和全纯函数的伟大理论,但我们完全无法使用它,因为许多常用函数不是全纯的。可怜的数学家该怎么办?嗯,Wirtinger 观察到,即使 不是全纯的,也可以将其重写为一个双变量函数 ,它总是全纯的。这是因为 z 的实部和虚部可以使用 z 和 z* 来表达:
Wirtinger 演算建议转而研究 ,如果 是实可微的,则 保证是全纯的(另一种理解方式是将其视为坐标系的改变,从 到 )。这个函数具有偏导数 和 。我们可以使用链式法则来建立这些偏导数与关于 z 的实部和虚部的偏导数之间的关系。
从上面的方程,我们得到
这是 Wirtinger 导数的经典定义,你可以在维基百科上找到。
这个改变有很多美妙的结果。
首先,Cauchy-Riemann 方程可以简单地表述为 (也就是说,函数 可以完全用 来表示,而无需涉及 )。
另一个重要的(并且有点反直觉的)结果,正如我们稍后将看到,是当我们对一个实值的损失函数进行优化时,更新变量时应该采取的步骤由 给出(而不是 )。
更多阅读材料请参考:https://arxiv.org/pdf/0906.4835.pdf
Wirtinger 微分在优化中有何用途?¶
音频及其他领域的研究人员更常见地使用梯度下降法来优化带有复数变量的实值损失函数。通常情况下,他们将实部和虚部值作为独立的、可以被更新的通道来处理。对于步长 和损失 ,我们可以在 中写出以下方程:
这些方程如何转化为复空间 ?
发生了一件非常有趣的事情:Wirtinger 微积分告诉我们,可以将上面的复变量更新公式简化为仅引用共轭 Wirtinger 导数 ,这正是我们在优化中迈出的步骤。
因为共轭 Wirtinger 导数对于实值损失函数给出了恰好正确的步骤,所以在对具有实值损失的函数进行微分时,PyTorch 会提供这个导数。
PyTorch 如何计算共轭 Wirtinger 导数?¶
通常,我们的导数公式将 grad_output 作为输入,表示我们已经计算出的传入向量-雅可比积(aka,),其中 是整个计算的损失(产生实值损失), 是我们函数的输出。这里的目标是计算 ,其中 是函数的输入。结果表明,在实值损失的情况下,即使链式法则意味着我们也需要访问 ,我们也可以通过仅计算 来完成。如果您想跳过这个推导,请查看本节的最后一个公式,然后跳到下一节。
让我们继续使用定义为 的函数 。如上所述,autograd 的梯度约定是围绕实值损失函数的优化展开的,因此我们假设 是一个更大的实值损失函数 的一部分。使用链式法则,我们可以写出
(1)¶
现在,使用 Wirtinger 导数的定义,我们可以写出
这里需要注意的是,由于 和 是实函数,并且根据我们的假设 是实值函数的一部分, 是实数,因此我们有
(2)¶
即, 等于 。
求解上述方程得到 和 :
(3)¶
利用 (2),我们得到
(4)¶
最后一个方程对于编写自己的梯度非常重要,因为它将我们的导数公式分解为一个更容易手动计算的简单形式。
如何为一个复函数编写我自己的导数公式?¶
上面方框中的方程给出了所有复函数导数的一般公式。然而,我们仍然需要计算 和 。有两种方法可以做到这一点
第一种方法是直接使用 Wirtinger 导数的定义,并通过 和 (你可以用常规方法计算它们)来计算 和 。
第二种方法是使用换元法技巧,将 重写为一个双变量函数 ,并通过将 和 视为自变量来计算共轭 Wirtinger 导数。这通常更容易;例如,如果所讨论的函数是全纯函数,则只会使用 (且 将为零)。
让我们以函数 为例,其中 。
使用第一种方法计算 Wirtinger 导数,我们得到:
使用 (4),以及 grad_output = 1.0(这是在 PyTorch 中对标量输出调用 backward() 时使用的默认梯度输出值),我们得到
使用第二种计算 Wirtinger 导数的方法,我们直接得到
再次使用 (4),我们得到 。正如你所见,第二种方法计算量更小,对于更快的计算非常方便。
保存张量的钩子¶
你可以通过定义一对 pack_hook / unpack_hook 钩子来控制 保存的张量如何打包/解包。pack_hook 函数应接受一个张量作为其唯一参数,但可以返回任何 Python 对象(例如另一个张量、一个元组,甚至包含文件名的字符串)。unpack_hook 函数接受 pack_hook 的输出作为其唯一参数,并且应该返回一个在反向传播中使用的张量。由 unpack_hook 返回的张量只需与传递给 pack_hook 的输入张量具有相同的内容即可。特别是,任何与自动微分相关的元数据都可以被忽略,因为它们将在解包时被覆盖。
一个这样的对例如下
class SelfDeletingTempFile():
def __init__(self):
self.name = os.path.join(tmp_dir, str(uuid.uuid4()))
def __del__(self):
os.remove(self.name)
def pack_hook(tensor):
temp_file = SelfDeletingTempFile()
torch.save(tensor, temp_file.name)
return temp_file
def unpack_hook(temp_file):
return torch.load(temp_file.name)
请注意,unpack_hook 不应删除临时文件,因为它可能会被多次调用:临时文件应该在返回的 SelfDeletingTempFile 对象存活期间一直存在。在上面的例子中,我们通过在不再需要时(即 SelfDeletingTempFile 对象被删除时)关闭临时文件来防止其泄露。
注意
我们保证 pack_hook 只会被调用一次,但 unpack_hook 可以根据反向传播的需要被调用多次,并且我们期望它每次都返回相同的数据。
警告
禁止对任何函数的输入执行原地操作,因为它们可能导致意外的副作用。如果对 pack 钩子的输入进行原地修改,PyTorch 将抛出错误,但不会捕获对 unpack 钩子的输入进行原地修改的情况。
注册保存张量的钩子¶
你可以通过调用 register_hooks() 方法在一个 SavedTensor 对象上注册一对钩子。这些对象作为 grad_fn 的属性暴露出来,并以 _raw_saved_ 前缀开头。
x = torch.randn(5, requires_grad=True)
y = x.pow(2)
y.grad_fn._raw_saved_self.register_hooks(pack_hook, unpack_hook)
pack_hook 方法在注册这对钩子后立即被调用。unpack_hook 方法在每次需要访问保存的张量时被调用,无论是通过 y.grad_fn._saved_self 还是在反向传播过程中。
警告
如果在保存的张量被释放后(即在调用 backward 后)你仍然持有对 SavedTensor 的引用,则禁止调用其 register_hooks()。PyTorch 在大多数情况下会抛出错误,但在某些情况下可能无法做到,并可能导致未定义的行为。
注册保存张量的默认钩子¶
或者,你可以使用上下文管理器 saved_tensors_hooks 来注册一对钩子,这对钩子将应用于在该上下文中创建的 所有 保存的张量。
示例
# Only save on disk tensors that have size >= 1000
SAVE_ON_DISK_THRESHOLD = 1000
def pack_hook(x):
if x.numel() < SAVE_ON_DISK_THRESHOLD:
return x
temp_file = SelfDeletingTempFile()
torch.save(tensor, temp_file.name)
return temp_file
def unpack_hook(tensor_or_sctf):
if isinstance(tensor_or_sctf, torch.Tensor):
return tensor_or_sctf
return torch.load(tensor_or_sctf.name)
class Model(nn.Module):
def forward(self, x):
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
# ... compute output
output = x
return output
model = Model()
net = nn.DataParallel(model)
使用此上下文管理器定义的钩子是线程本地的。因此,以下代码不会产生预期效果,因为钩子不经过 DataParallel。
# Example what NOT to do
net = nn.DataParallel(model)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
output = net(input)
请注意,使用这些钩子会禁用所有用于减少 Tensor 对象创建的现有优化。例如
with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
x = torch.randn(5, requires_grad=True)
y = x * x
如果没有钩子,x, y.grad_fn._saved_self 和 y.grad_fn._saved_other 都指向同一个张量对象。有了钩子,PyTorch 会将 x 打包并解包成两个新的张量对象,它们与原始的 x 共享相同的存储空间(没有进行复制)。
反向传播钩子执行¶
本节将讨论不同钩子何时触发以及何时不触发,然后讨论它们被触发的顺序。将涵盖的钩子包括:通过 torch.Tensor.register_hook() 注册到 Tensor 的反向传播钩子,通过 torch.Tensor.register_post_accumulate_grad_hook() 注册到 Tensor 的梯度累积后钩子,通过 torch.autograd.graph.Node.register_hook() 注册到 Node 的后钩子,以及通过 torch.autograd.graph.Node.register_prehook() 注册到 Node 的前钩子。
特定钩子是否会被触发¶
通过 torch.Tensor.register_hook() 注册到 Tensor 的钩子在该 Tensor 计算梯度时执行。(请注意,这不需要执行该 Tensor 的 grad_fn。例如,如果该 Tensor 作为 torch.autograd.grad() 的 inputs 参数的一部分传递,则该 Tensor 的 grad_fn 可能不会执行,但注册到该 Tensor 的钩子将始终执行。)
通过 torch.Tensor.register_post_accumulate_grad_hook() 注册到 Tensor 的钩子在该 Tensor 的梯度累积完成后执行,这意味着该 Tensor 的 grad 字段已被设置。通过 torch.Tensor.register_hook() 注册的钩子在计算梯度时运行,而通过 torch.Tensor.register_post_accumulate_grad_hook() 注册的钩子仅在反向传播结束时自动微分更新该 Tensor 的 grad 字段后触发。因此,梯度累积后钩子只能注册到叶张量上。即使你调用 backward(retain_graph=True),在非叶张量上通过 torch.Tensor.register_post_accumulate_grad_hook() 注册钩子也会出错。
使用 torch.autograd.graph.Node.register_hook() 或 torch.autograd.graph.Node.register_prehook() 注册到 torch.autograd.graph.Node 的钩子只有在该 Node 被执行时才会被触发。
特定的 Node 是否会被执行可能取决于反向传播是使用 torch.autograd.grad() 还是 torch.autograd.backward() 调用的。具体来说,当你在与作为 inputs 参数的一部分传递给 torch.autograd.grad() 或 torch.autograd.backward() 的 Tensor 相对应的 Node 上注册钩子时,应该注意这些区别。
如果你使用 torch.autograd.backward(),所有上述提到的钩子都将执行,无论你是否指定了 inputs 参数。这是因为 .backward() 会执行所有 Node,即使它们对应于指定为输入的 Tensor。(请注意,执行与作为 inputs 传递的 Tensor 相对应的这个额外 Node 通常是不必要的,但无论如何都会执行。此行为可能会更改;你不应依赖它。)
另一方面,如果你使用 torch.autograd.grad(),注册到与传递给 input 的 Tensors 相对应的 Node 的反向传播钩子可能不会被执行,因为除非有另一个输入依赖于此 Node 的梯度结果,否则这些 Node 将不会执行。
不同钩子被触发的顺序¶
事件发生的顺序是
注册到 Tensor 的钩子执行
注册到 Node 的前钩子执行(如果 Node 被执行)。
对于保留梯度的 Tensors,
.grad字段被更新Node 被执行(取决于上面的规则)
对于
.grad已累积的叶张量,梯度累积后钩子执行注册到 Node 的后钩子执行(如果 Node 被执行)
如果在同一个 Tensor 或 Node 上注册了多个相同类型的钩子,它们将按照注册的顺序执行。后执行的钩子可以观察到先执行的钩子对梯度所做的修改。
特殊钩子¶
torch.autograd.graph.register_multi_grad_hook() 是使用注册到 Tensors 的钩子实现的。每个单独的 Tensor 钩子按照上面定义的 Tensor 钩子顺序触发,并且注册的多梯度钩子在计算完最后一个 Tensor 梯度时被调用。
torch.nn.modules.module.register_module_full_backward_hook() 是使用注册到 Node 的钩子实现的。在计算前向传播时,钩子被注册到与模块的输入和输出相对应的 grad_fn。因为一个模块可能接受多个输入并返回多个输出,所以在前向传播之前先对模块的输入应用一个虚拟的自定义 autograd Function,并在返回前向传播的输出之前对模块的输出应用该函数,以确保这些 Tensor 共享一个单独的 grad_fn,然后我们可以将我们的钩子附加到该 grad_fn 上。
张量在原地修改时其钩子的行为¶
通常,注册到 Tensor 的钩子接收输出相对于该 Tensor 的梯度,其中 Tensor 的值被认为是计算反向传播时的值。
但是,如果你注册钩子到 Tensor,然后对该 Tensor 进行原地修改,则在原地修改之前注册的钩子同样接收输出相对于该 Tensor 的梯度,但 Tensor 的值被认为是其在原地修改之前的值。
如果你偏好前一种情况的行为,你应该在该 Tensor 的所有原地修改完成后再注册钩子。例如
t = torch.tensor(1., requires_grad=True).sin()
t.cos_()
t.register_hook(fn)
t.backward()
此外,了解幕后情况也有帮助,即当钩子注册到 Tensor 时,它们实际上永久绑定到该 Tensor 的 grad_fn,因此如果该 Tensor 随后被原地修改,即使 Tensor 现在有了新的 grad_fn,在原地修改之前注册的钩子仍将与旧的 grad_fn 相关联,例如,当自动微分引擎在图中到达该 Tensor 的旧 grad_fn 时,这些钩子就会触发。