自动微分机制¶
本笔记将概述自动微分的工作原理以及如何记录运算。了解所有这些内容并非严格必要,但我们建议您熟悉它们,因为它们将帮助您编写更高效、更清晰的程序,并有助于您进行调试。
自动微分如何编码历史记录¶
自动微分是一种反向自动微分系统。从概念上讲,自动微分会在您执行运算时记录一个记录所有创建数据的运算的图,从而为您提供一个有向无环图,其叶子是输入张量,根是输出张量。通过从根到叶跟踪此图,您可以使用链式法则自动计算梯度。
在内部,自动微分将此图表示为 Function
对象(实际上是表达式)的图,可以对这些对象进行 apply()
以计算评估图的结果。在计算正向传播时,自动微分会同时执行请求的计算,并构建一个代表计算梯度的函数的图(每个 torch.Tensor
的 .grad_fn
属性是此图的入口点)。在正向传播完成后,我们在反向传播中评估此图以计算梯度。
需要注意的是,图在每次迭代时都是从头开始重新创建的,这正是允许使用任意 Python 控制流语句的原因,这些语句可以在每次迭代时更改图的整体形状和大小。您不必在启动训练之前对所有可能的路径进行编码 - 您运行的内容就是您要区分的内容。
保存的张量¶
某些运算需要在正向传播期间保存中间结果才能执行反向传播。例如,函数 会保存输入 以计算梯度。
在定义自定义 Python Function
时,您可以使用 save_for_backward()
在正向传播期间保存张量,并使用 saved_tensors
在反向传播期间检索它们。有关更多信息,请参见 扩展 PyTorch。
对于 PyTorch 定义的运算(例如,torch.pow()
),张量将根据需要自动保存。您可以探索(出于教育或调试目的)某些 grad_fn
保存了哪些张量,方法是查看其以 _saved
为前缀的属性。
x = torch.randn(5, requires_grad=True)
y = x.pow(2)
print(x.equal(y.grad_fn._saved_self)) # True
print(x is y.grad_fn._saved_self) # True
在前面的代码中,y.grad_fn._saved_self
指的是与 x 相同的张量对象。但情况并非总是如此。例如
x = torch.randn(5, requires_grad=True)
y = x.exp()
print(y.equal(y.grad_fn._saved_result)) # True
print(y is y.grad_fn._saved_result) # False
在幕后,为了防止出现引用循环,PyTorch 在保存时对张量进行了打包,并在读取时将其解包到不同的张量中。在这里,您从访问 y.grad_fn._saved_result
获得的张量与 y
是不同的张量对象(但它们仍然共享相同的存储)。
是否将张量打包到不同的张量对象中取决于它是否是其自身 grad_fn 的输出,这是一种实现细节,可能会发生变化,并且用户不应该依赖它。
您可以使用 保存的张量挂钩 控制 PyTorch 如何进行打包/解包。
不可微函数的梯度¶
使用自动微分计算梯度仅在使用每个基本函数可微分时有效。不幸的是,我们在实践中使用的许多函数不具备此属性(例如,relu
或 sqrt
在 0
处)。为了尽量减少不可微分函数的影响,我们通过按以下顺序应用规则来定义基本运算的梯度。
如果函数可微分,并且在当前点存在梯度,则使用它。
如果函数是凸的(至少在局部),则使用最小范数的次梯度(它是最陡下降方向)。
如果函数是凹的(至少在局部),则使用最小范数的超梯度(考虑 -f(x) 并应用前一点)。
如果函数已定义,则通过连续性定义当前点的梯度(请注意,这里可能存在
inf
,例如对于sqrt(0)
)。如果有多个值可能,则任意选择一个。如果函数未定义(例如
sqrt(-1)
,log(-1)
或当输入为NaN
时的大多数函数),则用作梯度的值是任意的(我们也可能引发错误,但不能保证)。大多数函数将使用NaN
作为梯度,但出于性能原因,某些函数将使用其他值(例如log(-1)
)。如果函数不是确定性映射(即它不是 数学函数),它将被标记为不可微分。这将在反向传播过程中导致错误,如果它在需要梯度的张量上使用,而不是在
no_grad
环境中使用。
本地禁用梯度计算¶
Python 提供了多种机制来本地禁用梯度计算。
要禁用整个代码块的梯度,可以使用上下文管理器,例如 no-grad 模式和推理模式。对于更细粒度的从梯度计算中排除子图,可以使用设置张量的 requires_grad
字段。
下面,除了讨论上述机制之外,我们还将介绍评估模式(nn.Module.eval()
),这是一种不用于禁用梯度计算的方法,但由于其名称,经常与这三种方法混淆。
设置 requires_grad
¶
requires_grad
是一个标志,默认为 false,除非它被包装在 nn.Parameter
中,它允许从梯度计算中细粒度地排除子图。它在正向和反向传递中生效。
在正向传递期间,只有当至少一个输入张量需要梯度时,才会在反向图中记录运算。在反向传递(.backward()
)期间,只有具有 requires_grad=True
的叶张量才会将其梯度累积到其 .grad
字段中。
请务必注意,即使每个张量都有此标志,但设置它仅对叶张量有意义(没有 grad_fn
的张量,例如 nn.Module
的参数)。非叶张量(具有 grad_fn
的张量)是与它们关联的反向图的张量。因此,它们的梯度将作为中间结果,用于计算需要梯度的叶张量的梯度。从这个定义可以清楚地看出,所有非叶张量都将自动具有 require_grad=True
。
设置 requires_grad
应该是控制模型哪些部分是梯度计算的一部分的主要方式,例如,如果在模型微调过程中需要冻结预训练模型的某些部分。
要冻结模型的某些部分,只需对不想更新的参数应用 .requires_grad_(False)
。如上所述,由于使用这些参数作为输入的计算不会在正向传递中记录,因此它们的 .grad
字段在反向传递中不会更新,因为它们根本不会在反向图中,如预期的那样。
由于这是一种常见的模式,因此 requires_grad
也可以在模块级别使用 nn.Module.requires_grad_()
进行设置。当应用于模块时,.requires_grad_()
会对模块的所有参数生效(默认情况下,这些参数的 requires_grad=True
)。
梯度模式¶
除了设置 requires_grad
之外,还可以从 Python 中选择三种梯度模式,这些模式会影响 PyTorch 如何在内部处理自动微分计算:默认模式(梯度模式)、no-grad 模式和推理模式,所有这些模式都可以通过上下文管理器和装饰器进行切换。
模式 |
排除从反向图中记录运算 |
跳过额外的自动微分跟踪开销 |
在启用模式期间创建的张量可以在以后的梯度模式中使用 |
示例 |
---|---|---|---|---|
默认 |
✓ |
正向传递 |
||
no-grad |
✓ |
✓ |
优化器更新 |
|
推理 |
✓ |
✓ |
数据处理、模型评估 |
默认模式(梯度模式)¶
“默认模式”是我们隐式处于的模式,此时没有启用 no-grad 模式和推理模式等其他模式。与“no-grad 模式”形成对比,默认模式有时也称为“梯度模式”。
关于默认模式,最重要的是它是在 requires_grad
生效的唯一模式。在另外两种模式中,requires_grad
始终被覆盖为 False
。
No-grad 模式¶
no-grad 模式下的计算表现得好像没有输入需要梯度。换句话说,即使存在具有 require_grad=True
的输入,no-grad 模式下的计算也不会记录在反向图中。
当需要执行不应由自动微分记录的运算,但仍然希望在以后的梯度模式中使用这些计算的结果时,启用 no-grad 模式。此上下文管理器使禁用代码块或函数的梯度变得很方便,而无需临时将张量设置为 requires_grad=False
,然后恢复为 True
。
例如,在编写优化器时,no-grad 模式可能很有用:在执行训练更新时,希望就地更新参数,而无需自动微分记录更新。还打算在下一个正向传递中使用更新后的参数进行梯度模式计算。
在 torch.nn.init 中的实现也依赖于 no-grad 模式,在初始化参数时,为了避免在就地更新初始化参数时自动微分跟踪。
推理模式¶
推理模式是 no-grad 模式的极端版本。与 no-grad 模式一样,推理模式下的计算不会记录在反向图中,但启用推理模式将允许 PyTorch 进一步加速模型。这种更好的运行时性能有一个缺点:在推理模式下创建的张量将无法在退出推理模式后用于由自动微分记录的计算中。
当执行与自动微分没有交互的计算,并且不打算在以后的任何由自动微分记录的计算中使用在推理模式下创建的张量时,启用推理模式。
建议在不需要自动微分跟踪的代码部分(例如数据处理和模型评估)中尝试使用推理模式。如果它适用于您的用例,那么这是一次免费的性能提升。如果在启用推理模式后遇到错误,请检查您是否没有在退出推理模式后将推理模式下创建的张量用于由自动微分记录的计算中。如果您无法避免在您的用例中使用此类内容,则可以始终切换回 no-grad 模式。
有关推理模式的详细信息,请参阅 推理模式。
有关推理模式的实现详细信息,请参阅 RFC-0011-InferenceMode。
评估模式(nn.Module.eval()
)¶
评估模式不是本地禁用梯度计算的机制。它仍然包含在这里,因为它有时会被误认为是这种机制。
从功能上讲,module.eval()
(或等效地 module.train(False)
)完全独立于 no-grad 模式和推理模式。model.eval()
如何影响您的模型完全取决于模型中使用的特定模块,以及它们是否定义了任何特定于训练模式的行为。
如果您的模型依赖于诸如 torch.nn.Dropout
和 torch.nn.BatchNorm2d
等模块,这些模块可能会根据训练模式表现出不同的行为,例如为了避免在验证数据上更新 BatchNorm 运行统计信息,您需要调用 model.eval()
和 model.train()
。
建议在训练时始终使用 model.train()
,并在评估模型(验证/测试)时使用 model.eval()
,即使您不确定模型是否具有特定于训练模式的行为,因为您正在使用的模块可能会更新,以便在训练模式和评估模式中表现出不同的行为。
自动微分的就地运算¶
在 autograd 中支持就地操作是一件很困难的事情,我们不建议在大多数情况下使用它们。Autograd 的积极的缓冲区释放和重用机制使其非常高效,只有很少情况下就地操作会显著降低内存使用量。除非您在承受着巨大的内存压力,否则您可能永远不需要使用它们。
有两个主要原因限制了就地操作的适用性
就地操作可能会覆盖计算梯度所需的数值。
每个就地操作都需要实现重写计算图。非就地版本只是分配新的对象并保留对旧图的引用,而就地操作则需要更改所有输入的创建者,以指向表示此操作的
Function
。这可能很棘手,尤其是在有许多引用相同存储空间的张量(例如,通过索引或转置创建)的情况下,而且如果修改后的输入的存储空间被其他任何Tensor
引用,就地函数将引发错误。
就地正确性检查¶
每个张量都保留一个版本计数器,每次在任何操作中标记为脏数据时都会递增。当一个 Function 保存任何张量用于反向传播时,也会保存它们所包含的张量的版本计数器。一旦您访问 self.saved_tensors
,就会进行检查,如果它大于保存的值,就会引发错误。这确保了如果您使用就地函数并且没有看到任何错误,那么您可以确定计算出的梯度是正确的。
多线程 Autograd¶
autograd 引擎负责运行计算反向传播所需的所有反向操作。本节将描述所有可以帮助您在多线程环境中充分利用它的细节。(这仅与 PyTorch 1.6+ 相关,因为之前版本的行为有所不同。)
用户可以使用多线程代码训练他们的模型(例如 Hogwild 训练),并且不会阻塞并发反向传播计算,示例代码可能是
# Define a train function to be used in different threads
def train_fn():
x = torch.ones(5, 5, requires_grad=True)
# forward
y = (x + 3) * (x + 4) * 0.5
# backward
y.sum().backward()
# potential optimizer update
# User write their own threading code to drive the train_fn
threads = []
for _ in range(10):
p = threading.Thread(target=train_fn, args=())
p.start()
threads.append(p)
for p in threads:
p.join()
请注意,用户应该了解一些行为
CPU 上的并发性¶
当您通过 python 或 C++ API 在 CPU 上的多个线程中运行 backward()
或 grad()
时,您期望看到额外的并发性,而不是在执行期间按特定顺序序列化所有反向调用(PyTorch 1.6 之前的行为)。
非确定性¶
如果您从多个线程并发调用 backward()
并且共享了输入(即 Hogwild CPU 训练),那么就应该预期会出现非确定性。这可能是因为参数在各线程之间自动共享,因此,多个线程可能会在梯度累积期间访问并尝试累积相同的 .grad
属性。从技术上讲,这是不安全的,可能会导致竞争条件,结果可能无法使用。
开发具有共享参数的多线程模型的用户应该牢记线程模型,并理解上述问题。
可以使用函数式 API torch.autograd.grad()
计算梯度,而不是 backward()
,以避免非确定性。
图保留¶
如果 autograd 图的一部分在各线程之间共享,即在单个线程中运行正向传播的第一部分,然后在多个线程中运行第二部分,那么图的第一部分将被共享。在这种情况下,不同的线程在同一个图上执行 grad()
或 backward()
可能会出现问题,即一个线程会动态地销毁图,而另一个线程会在这种情况下崩溃。Autograd 会向用户报错,类似于两次调用 backward()
而不使用 retain_graph=True
,并让用户知道他们应该使用 retain_graph=True
。
Autograd 节点上的线程安全性¶
由于 Autograd 允许调用线程驱动其反向传播执行以实现潜在的并行性,因此在 CPU 上使用并行 backward()
调用(共享 GraphTask 的部分/全部内容)时,确保线程安全性非常重要。
自定义 Python autograd.Function
由于 GIL 的原因,自动线程安全。对于内置 C++ Autograd 节点(例如 AccumulateGrad、CopySlices)和自定义 autograd::Function
,Autograd 引擎使用线程互斥锁来确保对可能具有状态写入/读取的 Autograd 节点的线程安全性。
C++ 钩子上的线程安全性¶
Autograd 依赖于用户编写线程安全的 C++ 钩子。如果您希望在多线程环境中正确应用钩子,则需要编写适当的线程锁定代码来确保钩子的线程安全性。
用于复数的 Autograd¶
简而言之
当您使用 PyTorch 对任何具有复数域和/或陪域的函数 进行微分时,梯度是在假设该函数是更大的实值损失函数 的一部分的情况下计算的。计算出的梯度是 (请注意 z 的共轭),它的负值正是梯度下降算法中使用的最速下降方向。因此,在使现有优化器开箱即用地使用复数参数方面,存在一条可行的途径。
此约定与 TensorFlow 中的复数微分约定相匹配,但与 JAX 不同(JAX 计算 )。
如果您有一个实数到实数的函数,它在内部使用复数运算,那么这里的约定无关紧要:您将始终获得与使用仅实数运算实现时相同的结果。
如果您对数学细节感到好奇,或者想知道如何在 PyTorch 中定义复数导数,请继续阅读。
什么是复数导数?¶
复数可微性的数学定义采用导数的极限定义,并将其推广到复数运算。考虑函数 ,
其中 和 是两个变量的实值函数,而 是虚数单位。
利用导数定义,我们可以写成
为了使这个极限存在,不仅 和 必须是实可微的,而且 还必须满足柯西-黎曼方程。换句话说:对于实数和虚数步长计算的极限()必须相等。这是一个更严格的条件。
复可微函数通常被称为全纯函数。它们表现良好,具有从实可微函数中看到的所有良好性质,但在优化领域实际上没有用。对于优化问题,研究界只使用实值目标函数,因为复数不是任何有序域的一部分,因此使用复值损失没有多大意义。
事实证明,没有有趣的实值目标函数满足柯西-黎曼方程。因此,全纯函数理论不能用于优化,大多数人因此使用 Wirtinger 微积分。
Wirtinger 微积分登场 ...¶
所以,我们拥有了复可微性和全纯函数的伟大理论,我们却无法使用任何理论,因为许多常用的函数都不是全纯函数。一个可怜的数学家该怎么办?好吧,Wirtinger 观察到,即使 不是全纯函数,也可以将其改写为一个二元函数 ,该函数始终是全纯函数。这是因为 分量的实部和虚部可以用 和 表示为
Wirtinger 微积分建议研究 ,如果 是实可微的,那么它一定是一个全纯函数(换句话说,它是一个坐标系的改变,从 到 )。该函数有偏导数 和 . 我们可以使用链式法则建立这些偏导数与 的实部和虚部的偏导数之间的关系。
从上面的等式中,我们得到
这正是您在 维基百科 上看到的 Wirtinger 微积分的经典定义。
这种变化带来了很多美妙的结果。
其中一个结果是,柯西-黎曼方程简化为 (也就是说,函数 可以完全用 表示,而无需引用 )。
另一个重要的(也是有点违反直觉的)结果是,正如我们稍后将看到的,当我们在实值损失上进行优化时,我们应该在进行变量更新时采取的步骤由 (而不是 )。
要了解更多信息,请查看:https://arxiv.org/pdf/0906.4835.pdf
Wirtinger 微积分在优化中如何有用?¶
音频和其他领域的学者更常见地使用梯度下降来优化具有复数变量的实值损失函数。通常,这些人将实部和虚部视为可以更新的独立通道。对于步长 和损失 ,我们可以在 中写下以下方程式。
这些方程如何转换为复数空间 ?
有趣的是,Wirtinger 微积分告诉我们,我们可以将上述复变量更新公式简化为仅引用共轭 Wirtinger 导数 , 这恰好是我们在优化中采取的步骤。
由于共轭 Wirtinger 导数为我们提供了实值损失函数的正确优化方向,因此当您对具有实值损失的函数求导时,PyTorch 会为您提供此导数。
PyTorch 如何计算共轭 Wirtinger 导数?¶
通常,我们的导数公式以 grad_output 作为输入,表示我们已经计算出的传入向量-雅可比积,也称为 ,其中 是整个计算的损失(产生真实的损失),而 是我们函数的输出。这里的目标是计算 ,其中 是函数的输入。事实证明,在真实损失的情况下,我们只需要计算 ,即使链式法则意味着我们还需要访问 。如果您想跳过此推导,请查看本节中的最后一个等式,然后跳到下一节。
让我们继续使用定义为 的 。如上所述,autograd 的梯度约定以针对实值损失函数的优化为中心,因此让我们假设 是更大的实值损失函数 的一部分。使用链式法则,我们可以写成
(1)¶
这里需要注意的是,由于 和 是实函数,并且 根据我们假设 是一个实值函数的一部分,所以是实数,我们有
(2)¶
也就是说, 等于 .
解以上关于 和 的方程,我们得到
(3)¶
使用 (2),得到
(4)¶
这个最后一个等式对于编写你自己的梯度非常重要,因为它将我们的导数公式分解成一个更容易手算的简单公式。
我如何为一个复杂函数编写我自己的导数公式?¶
上面方框中的等式給出了所有复函数导数的一般公式。但是,我们仍然需要计算 和 。有两种方法可以做到这一点。
第一种方法是直接使用 Wirtinger 导数的定义,并计算 和 ,使用 和 (你可以用正常的方式计算)。
第二种方法是使用变量替换技巧,将 重写为一个二元函数 , 并通过将 和 视为独立变量来计算共轭 Wirtinger 导数。这通常更容易;例如,如果所讨论的函数是全纯的,那么只会使用 (并且 将为零)。
让我们考虑函数 作为一个例子,其中 .
使用第一种方法计算 Wirtinger 导数,我们有。
使用 (4) 和 grad_output = 1.0(当在 PyTorch 中对标量输出调用 backward()
时使用的默认梯度输出值),我们得到
使用第二种计算 Wirtinger 导数的方法,我们可以直接得到
再次使用 (4),我们得到 . 正如你所看到的,第二种方法涉及更少的计算,并且对于更快的计算更加方便。
保存张量的钩子¶
您可以通过定义一对 pack_hook
/ unpack_hook
钩子来控制 保存张量的打包/解包方式。 pack_hook
函数应将张量作为其唯一参数,但可以返回任何 Python 对象(例如另一个张量、元组,甚至包含文件名)的字符串)。 unpack_hook
函数将 pack_hook
的输出作为其唯一参数,并且应返回一个将在反向传递中使用的张量。 unpack_hook
返回的张量只需要与传递给 pack_hook
的张量具有相同的内容。特别是,任何与 autograd 相关的元数据都可以忽略,因为它们将在解包期间被覆盖。
这样一对的示例是
class SelfDeletingTempFile():
def __init__(self):
self.name = os.path.join(tmp_dir, str(uuid.uuid4()))
def __del__(self):
os.remove(self.name)
def pack_hook(tensor):
temp_file = SelfDeletingTempFile()
torch.save(tensor, temp_file.name)
return temp_file
def unpack_hook(temp_file):
return torch.load(temp_file.name)
请注意, unpack_hook
不应删除临时文件,因为它可能会被多次调用:临时文件应该在返回的 SelfDeletingTempFile 对象存活的时间内一直存活。在上面的例子中,我们通过在不再需要它时关闭它(在 SelfDeletingTempFile 对象删除时)来防止泄漏临时文件。
注意
我们保证 pack_hook
只会被调用一次,但 unpack_hook
可以根据反向传递的要求被调用多次,并且我们希望它每次返回相同的数据。
警告
对任何函数的输入执行就地操作是被禁止的,因为它们可能会导致意想不到的副作用。如果对打包钩子的输入进行就地修改,PyTorch 将抛出一个错误,但它不会捕获对解包钩子的输入进行就地修改的情况。
为保存的张量注册钩子¶
您可以通过调用 register_hooks()
方法在 SavedTensor
对象上为保存的张量注册一对钩子。这些对象作为 grad_fn
的属性公开,并以 _raw_saved_
前缀开头。
x = torch.randn(5, requires_grad=True)
y = x.pow(2)
y.grad_fn._raw_saved_self.register_hooks(pack_hook, unpack_hook)
pack_hook
方法在注册对时立即被调用。 unpack_hook
方法在每次需要访问保存的张量时被调用,无论是通过 y.grad_fn._saved_self
还是在反向传递期间。
警告
如果您在保存的张量被释放后(即在调用反向传递后)仍然保留对 SavedTensor
的引用,则禁止调用其 register_hooks()
。 PyTorch 通常会抛出错误,但在某些情况下它可能无法这样做,并且可能会出现未定义的行为。
为保存的张量注册默认钩子¶
或者,您可以使用上下文管理器 saved_tensors_hooks
注册一对钩子,它们将应用于在此上下文中创建的所有保存的张量。
示例
# Only save on disk tensors that have size >= 1000
SAVE_ON_DISK_THRESHOLD = 1000
def pack_hook(x):
if x.numel() < SAVE_ON_DISK_THRESHOLD:
return x
temp_file = SelfDeletingTempFile()
torch.save(tensor, temp_file.name)
return temp_file
def unpack_hook(tensor_or_sctf):
if isinstance(tensor_or_sctf, torch.Tensor):
return tensor_or_sctf
return torch.load(tensor_or_sctf.name)
class Model(nn.Module):
def forward(self, x):
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
# ... compute output
output = x
return output
model = Model()
net = nn.DataParallel(model)
使用此上下文管理器定义的钩子是线程本地的。因此,以下代码不会产生预期的效果,因为钩子不会通过 DataParallel 传递。
# Example what NOT to do
net = nn.DataParallel(model)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
output = net(input)
请注意,使用这些钩子会禁用所有优化到位以减少 Tensor 对象的创建。例如
with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
x = torch.randn(5, requires_grad=True)
y = x * x
如果没有钩子, x
、 y.grad_fn._saved_self
和 y.grad_fn._saved_other
都引用了相同的张量对象。使用钩子,PyTorch 将把 x 打包和解包到两个新的张量对象中,这两个对象与原始的 x 共享相同的存储(没有执行复制)。
反向钩子执行¶
本节将讨论不同钩子何时触发或不触发。然后将讨论它们的触发顺序。将涵盖的钩子包括:通过 torch.Tensor.register_hook()
注册到 Tensor 的反向钩子,通过 torch.Tensor.register_post_accumulate_grad_hook()
注册到 Tensor 的后累积梯度钩子,通过 torch.autograd.graph.Node.register_hook()
注册到 Node 的后钩子,以及通过 torch.autograd.graph.Node.register_prehook()
注册到 Node 的前钩子。
特定钩子是否会触发¶
通过 torch.Tensor.register_hook()
注册到 Tensor 的钩子在为该 Tensor 计算梯度时执行。(请注意,这并不需要执行 Tensor 的 grad_fn。例如,如果 Tensor 作为 inputs
参数传递给 torch.autograd.grad()
,则 Tensor 的 grad_fn 可能不会执行,但注册到该 Tensor 的钩子将始终执行。)
通过 torch.Tensor.register_post_accumulate_grad_hook()
注册到 Tensor 的钩子在为该 Tensor 累积梯度后执行,这意味着 Tensor 的 grad 字段已设置。而通过 torch.Tensor.register_hook()
注册的钩子在计算梯度时运行,通过 torch.Tensor.register_post_accumulate_grad_hook()
注册的钩子只有在反向传播结束时 Tensor 的 grad 字段被 autograd 更新后才会触发。因此,后累积梯度钩子只能为叶 Tensor 注册。即使调用 backward(retain_graph=True),在非叶 Tensor 上通过 torch.Tensor.register_post_accumulate_grad_hook()
注册钩子也会出错。
使用 torch.autograd.graph.Node.register_hook()
或 torch.autograd.graph.Node.register_prehook()
注册到 torch.autograd.graph.Node
的钩子仅在注册到的 Node 被执行时才会触发。
特定 Node 是否执行可能取决于反向传播是否使用 torch.autograd.grad()
或 torch.autograd.backward()
调用。具体来说,当您在与您传递给 torch.autograd.grad()
或 torch.autograd.backward()
作为 inputs
参数一部分的 Tensor 相对应的 Node 上注册钩子时,您应该注意这些区别。
如果您使用的是 torch.autograd.backward()
,则上面提到的所有钩子都会执行,无论您是否指定了 inputs
参数。这是因为 .backward() 执行所有 Node,即使它们对应于指定为输入的 Tensor。(请注意,与作为 inputs
传递的 Tensor 相对应的此附加 Node 的执行通常是不必要的,但无论如何都会执行。此行为可能会发生变化;您不应该依赖它。)
另一方面,如果您使用的是 torch.autograd.grad()
,则注册到与传递给 input
的 Tensor 相对应的 Node 的反向钩子可能不会执行,因为这些 Node 不会执行,除非存在另一个依赖于此 Node 梯度结果的输入。
不同钩子的触发顺序¶
事情发生的顺序是
注册到 Tensor 的钩子执行
注册到 Node 的前钩子执行(如果 Node 被执行)。
为保留梯度的 Tensor 更新
.grad
字段Node 执行(根据上述规则)
对于已累积
.grad
的叶 Tensor,执行后累积梯度钩子注册到 Node 的后钩子执行(如果 Node 被执行)
如果在同一个 Tensor 或 Node 上注册了多个相同类型的钩子,它们将按照注册顺序执行。较晚执行的钩子可以观察到较早钩子对梯度所做的修改。
特殊钩子¶
torch.autograd.graph.register_multi_grad_hook()
是使用注册到 Tensor 的钩子实现的。每个单独的 Tensor 钩子在上面定义的 Tensor 钩子排序后触发,并且在最后一个 Tensor 梯度计算完成后调用注册的多梯度钩子。
torch.nn.modules.module.register_module_full_backward_hook()
是使用注册到 Node 的钩子实现的。在计算前向传播时,会向与模块的输入和输出相对应的 grad_fn 注册钩子。因为一个模块可能接受多个输入并返回多个输出,所以在前向传播之前首先将一个虚拟自定义 autograd 函数应用于模块的输入,并在前向传播的输出返回之前将模块的输出应用于模块,以确保这些 Tensor 共享一个 grad_fn,然后我们可以将我们的钩子附加到它。
Tensor 被就地修改时 Tensor 钩子的行为¶
通常,注册到 Tensor 的钩子会接收输出相对于该 Tensor 的梯度,其中 Tensor 的值被认为是它在计算反向传播时的值。
但是,如果您向 Tensor 注册钩子,然后就地修改该 Tensor,则在就地修改之前注册的钩子也会接收输出相对于该 Tensor 的梯度,但 Tensor 的值被认为是在就地修改之前的值。
如果您更喜欢前一种情况下的行为,您应该在对其进行所有就地修改后向 Tensor 注册它们。例如
t = torch.tensor(1., requires_grad=True).sin()
t.cos_()
t.register_hook(fn)
t.backward()
此外,了解在幕后,当向 Tensor 注册钩子时,它们实际上会永久绑定到该 Tensor 的 grad_fn,因此如果该 Tensor 随后被就地修改,即使该 Tensor 现在具有一个新的 grad_fn,在就地修改之前注册的钩子将继续与旧的 grad_fn 相关联,例如,当 autograd 引擎在图中到达该 Tensor 的旧 grad_fn 时,它们会触发。