自动微分机制¶
本笔记将概述自动微分的工作原理以及如何记录操作。严格来说,了解所有这些内容并非必要,但我们建议您熟悉它,因为它将帮助您编写更高效、更简洁的程序,并有助于您进行调试。
自动微分如何编码历史记录¶
自动微分是一种反向自动微分系统。从概念上讲,自动微分记录一个图,该图记录了您执行操作时创建所有数据的操作,从而为您提供一个有向无环图,其叶子是输入张量,根是输出张量。通过从根到叶跟踪此图,您可以使用链式法则自动计算梯度。
在内部,自动微分将此图表示为 Function
对象(实际上是表达式)的图,这些对象可以 apply()
ed 来计算评估图的结果。在计算前向传递时,自动微分同时执行请求的计算并构建一个图,该图表示计算梯度的函数(每个 torch.Tensor
的 .grad_fn
属性是此图的入口点)。当完成前向传递时,我们在反向传递中评估此图以计算梯度。
需要注意的是,图在每次迭代时都会从头开始重新创建,这正是允许使用任意 Python 控制流语句的原因,这些语句可以在每次迭代时更改图的整体形状和大小。您不必在启动训练之前对所有可能的路径进行编码 - 您运行什么就微分什么。
保存的张量¶
某些操作需要在正向传播过程中保存中间结果,以便执行反向传播。例如,函数 保存输入 以便计算梯度。
在定义自定义 Python Function
时,可以使用 save_for_backward()
在正向传播过程中保存张量,并使用 saved_tensors
在反向传播过程中检索它们。有关更多信息,请参阅 扩展 PyTorch。
对于 PyTorch 定义的操作(例如 torch.pow()
),张量将根据需要自动保存。您可以探索(出于教育或调试目的)哪些张量被某个 grad_fn
保存,方法是查找其以 _saved
为前缀的属性。
x = torch.randn(5, requires_grad=True)
y = x.pow(2)
print(x.equal(y.grad_fn._saved_self)) # True
print(x is y.grad_fn._saved_self) # True
在前面的代码中,y.grad_fn._saved_self
指的是与 x 相同的张量对象。但这并不总是这样。例如
x = torch.randn(5, requires_grad=True)
y = x.exp()
print(y.equal(y.grad_fn._saved_result)) # True
print(y is y.grad_fn._saved_result) # False
为了防止引用循环,PyTorch 在保存时将张量打包,并在读取时将其解包到另一个张量中。这里,从访问 y.grad_fn._saved_result
获取的张量与 y
是不同的张量对象(但它们仍然共享相同的存储)。
张量是否会被打包到另一个张量对象中取决于它是否是其自身 grad_fn 的输出,这是一个实现细节,可能会发生变化,用户不应该依赖它。
您可以使用 保存张量的钩子 控制 PyTorch 的打包/解包方式。
不可微函数的梯度¶
使用自动微分计算梯度仅在使用每个基本函数都是可微的情况下才有效。不幸的是,我们在实践中使用的许多函数不具有此属性(例如,relu
或 sqrt
在 0
处)。为了尽量减少不可微函数的影响,我们通过按以下顺序应用规则来定义基本运算的梯度
如果函数是可微的,因此在当前点存在梯度,则使用它。
如果函数是凸的(至少在局部),则使用最小范数的次梯度(它是最陡下降方向)。
如果函数是凹的(至少在局部),则使用最小范数的超梯度(考虑 -f(x) 并应用前一点)。
如果函数已定义,则通过连续性定义当前点的梯度(注意,这里可能出现
inf
,例如对于sqrt(0)
)。如果多个值是可能的,则任意选择一个。如果函数未定义(例如,
sqrt(-1)
、log(-1)
或大多数函数在输入为NaN
时),则用作梯度的值是任意的(我们也可能引发错误,但不能保证)。大多数函数将使用NaN
作为梯度,但出于性能原因,某些函数将使用其他值(例如,log(-1)
)。如果函数不是确定性映射(即它不是一个数学函数),它将被标记为不可微分。这将导致它在反向传播中出错,如果它被用于需要梯度的张量,而不在
no_grad
环境中。
局部禁用梯度计算¶
Python 提供了几种机制来局部禁用梯度计算。
要禁用整个代码块的梯度,可以使用上下文管理器,例如无梯度模式和推理模式。对于更细粒度的排除子图进行梯度计算,可以使用设置张量的requires_grad
字段。
下面,除了讨论上述机制之外,我们还将描述评估模式(nn.Module.eval()
),这是一种不用于禁用梯度计算的方法,但由于其名称,经常与这三种方法混淆。
设置requires_grad
¶
requires_grad
是一个标志,默认值为 false,*除非它被包装在* nn.Parameter
中,它允许对子图进行细粒度的排除以进行梯度计算。它在正向和反向传播中都有效。
在正向传播期间,只有当至少一个输入张量需要梯度时,才会在反向图中记录操作。在反向传播期间(.backward()
),只有具有requires_grad=True
的叶张量才会将其梯度累积到它们的.grad
字段中。
需要注意的是,即使每个张量都有这个标志,*设置*它只有对叶张量(没有grad_fn
的张量,例如,nn.Module
的参数)才有意义。非叶张量(具有grad_fn
的张量)是与它们关联的反向图的张量。因此,它们的梯度将作为中间结果需要来计算需要梯度的叶张量的梯度。从这个定义可以清楚地看出,所有非叶张量将自动具有require_grad=True
。
设置 requires_grad
应该是控制模型哪些部分参与梯度计算的主要方式,例如,如果您需要在模型微调期间冻结预训练模型的某些部分。
要冻结模型的某些部分,只需对不想更新的参数应用 .requires_grad_(False)
。如上所述,由于使用这些参数作为输入的计算不会在正向传递中记录,因此它们的 .grad
字段不会在反向传递中更新,因为它们根本不会成为反向图的一部分,如预期的那样。
由于这是一种非常常见的模式,因此 requires_grad
也可以在模块级别使用 nn.Module.requires_grad_()
设置。当应用于模块时,.requires_grad_()
会对模块的所有参数(默认情况下 requires_grad=True
)生效。
梯度模式¶
除了设置 requires_grad
之外,还可以从 Python 中选择三种梯度模式,这些模式会影响 PyTorch 中的计算在内部如何由 autograd 处理:默认模式(梯度模式)、无梯度模式和推理模式,所有这些模式都可以通过上下文管理器和装饰器进行切换。
模式 |
将操作排除在反向图的记录之外 |
跳过额外的 autograd 跟踪开销 |
在启用模式期间创建的张量可以在以后的梯度模式中使用 |
示例 |
---|---|---|---|---|
默认 |
✓ |
正向传递 |
||
无梯度 |
✓ |
✓ |
优化器更新 |
|
推理 |
✓ |
✓ |
数据处理、模型评估 |
默认模式(梯度模式)¶
“默认模式”是我们隐式处于的模式,当没有启用其他模式(如无梯度模式和推理模式)时。与“无梯度模式”形成对比,默认模式有时也称为“梯度模式”。
关于默认模式,最重要的是,它是唯一一个 requires_grad
生效的模式。在另外两种模式中,requires_grad
始终被覆盖为 False
。
无梯度模式¶
在无梯度模式下进行的计算,其行为就好像没有输入需要梯度一样。换句话说,即使存在具有 require_grad=True
的输入,在无梯度模式下的计算也不会被记录在反向图中。
当您需要执行不应该被自动微分记录的操作,但您仍然希望在以后的梯度模式下使用这些计算的结果时,启用无梯度模式。此上下文管理器可以方便地为代码块或函数禁用梯度,而无需临时将张量设置为 requires_grad=False
,然后再恢复为 True
。
例如,在编写优化器时,无梯度模式可能很有用:在执行训练更新时,您希望就地更新参数,而无需将更新记录到自动微分中。您还打算在下一个前向传递中使用更新后的参数进行梯度模式下的计算。
中的实现 torch.nn.init 也依赖于无梯度模式,在初始化参数时,为了避免在就地更新初始化参数时进行自动微分跟踪。
推理模式¶
推理模式是无梯度模式的极端版本。与无梯度模式一样,推理模式下的计算不会被记录在反向图中,但启用推理模式将允许 PyTorch 进一步加速您的模型。这种更好的运行时性能有一个缺点:在推理模式下创建的张量将无法在退出推理模式后用于自动微分记录的计算中。
当您执行不需要在反向图中记录的计算,并且您不打算在以后的任何由自动微分记录的计算中使用在推理模式下创建的张量时,启用推理模式。
建议您在代码中不需要自动微分跟踪的部分(例如,数据处理和模型评估)尝试使用推理模式。如果它适用于您的用例,那么它就是一个免费的性能提升。如果您在启用推理模式后遇到错误,请检查您是否在退出推理模式后,在由自动微分记录的计算中使用了在推理模式下创建的张量。如果您无法避免在您的用例中使用此类张量,您始终可以切换回无梯度模式。
有关推理模式的详细信息,请参阅 推理模式。
有关推理模式的实现细节,请参阅 RFC-0011-InferenceMode。
评估模式 (nn.Module.eval()
)¶
评估模式不是本地禁用梯度计算的机制。它在这里被包含进来,因为有时会被误认为是这种机制。
从功能上讲,module.eval()
(或等效地 module.train(False)
)与无梯度模式和推理模式完全正交。model.eval()
如何影响您的模型完全取决于模型中使用的特定模块以及它们是否定义了任何训练模式特定的行为。
如果您模型依赖于诸如 torch.nn.Dropout
和 torch.nn.BatchNorm2d
这样的模块,这些模块可能在训练模式下表现不同,例如,为了避免在验证数据上更新您的 BatchNorm 运行统计信息,您有责任调用 model.eval()
和 model.train()
。
建议您在训练时始终使用 model.train()
,在评估模型(验证/测试)时始终使用 model.eval()
,即使您不确定您的模型是否具有训练模式特定的行为,因为您正在使用的模块可能会更新为在训练和评估模式下表现不同。
带自动梯度的就地操作¶
在自动梯度中支持就地操作是一件很困难的事情,我们不鼓励在大多数情况下使用它们。自动梯度的积极缓冲区释放和重用使其非常高效,只有在极少数情况下,就地操作才能显着降低内存使用量。除非您在承受巨大的内存压力,否则您可能永远不需要使用它们。
有两个主要原因限制了就地操作的适用性
就地操作可能会覆盖计算梯度所需的数值。
每个就地操作都需要实现重写计算图。非就地版本只需分配新的对象并保留对旧图的引用,而就地操作则需要更改所有输入的创建者,以指向表示此操作的
Function
。这可能很棘手,尤其是在存在许多引用相同存储的张量(例如,通过索引或转置创建)的情况下,如果修改后的输入的存储被任何其他Tensor
引用,则就地函数将引发错误。
就地正确性检查¶
每个张量都保留一个版本计数器,每次在任何操作中标记为脏时都会递增。当一个 Function 保存任何张量以进行反向传播时,也会保存其包含的张量的版本计数器。一旦你访问 self.saved_tensors
,就会进行检查,如果它大于保存的值,则会引发错误。这确保了如果你使用就地函数并且没有看到任何错误,你可以确信计算的梯度是正确的。
多线程自动微分¶
自动微分引擎负责运行计算反向传播所需的所有反向操作。本节将描述所有可以帮助你在多线程环境中充分利用它的细节。(这仅与 PyTorch 1.6+ 相关,因为之前版本的行为有所不同。)
用户可以使用多线程代码(例如 Hogwild 训练)训练他们的模型,并且不会阻塞并发反向计算,示例代码可能是
# Define a train function to be used in different threads
def train_fn():
x = torch.ones(5, 5, requires_grad=True)
# forward
y = (x + 3) * (x + 4) * 0.5
# backward
y.sum().backward()
# potential optimizer update
# User write their own threading code to drive the train_fn
threads = []
for _ in range(10):
p = threading.Thread(target=train_fn, args=())
p.start()
threads.append(p)
for p in threads:
p.join()
请注意,用户应该注意一些行为
CPU 上的并发¶
当您在 CPU 上的多个线程中通过 Python 或 C++ API 运行 backward()
或 grad()
时,您期望看到额外的并发性,而不是在执行期间按特定顺序序列化所有反向调用(PyTorch 1.6 之前的行为)。
非确定性¶
如果您从多个线程并发调用 backward()
并且共享输入(例如 Hogwild CPU 训练),那么应该预期非确定性。这可能发生是因为参数在线程之间自动共享,因此,多个线程可能会访问并尝试在梯度累积期间累积相同的 .grad
属性。从技术上讲,这并不安全,它可能会导致竞争条件,结果可能无效。
开发具有共享参数的多线程模型的用户应该牢记线程模型,并了解上述问题。
可以使用函数式 API torch.autograd.grad()
来计算梯度,而不是 backward()
,以避免非确定性。
图保留¶
如果自动梯度图的一部分在线程之间共享,例如,运行前向的第一部分单线程,然后在多个线程中运行第二部分,那么图的第一部分将被共享。在这种情况下,不同的线程在同一图上执行 grad()
或 backward()
可能会出现一个线程动态销毁图,而另一个线程在这种情况下会崩溃的问题。自动梯度会向用户发出错误,类似于调用 backward()
两次而没有 retain_graph=True
,并让用户知道他们应该使用 retain_graph=True
。
自动梯度节点上的线程安全性¶
由于自动梯度允许调用线程驱动其反向执行以实现潜在的并行性,因此在 CPU 上使用并行 backward()
调用共享部分/全部 GraphTask 时,确保线程安全性非常重要。
自定义 Python autograd.Function
由于 GIL 的存在,自动线程安全。对于内置的 C++ Autograd 节点(例如 AccumulateGrad、CopySlices)和自定义的 autograd::Function
,Autograd 引擎使用线程互斥锁来确保对可能存在状态写入/读取的 Autograd 节点的线程安全。
C++ 钩子没有线程安全¶
Autograd 依赖用户编写线程安全的 C++ 钩子。如果您希望钩子在多线程环境中正确应用,您需要编写适当的线程锁代码以确保钩子是线程安全的。
复数的 Autograd¶
简短版本
When you use PyTorch to differentiate any function with complex domain and/or codomain, the gradients are computed under the assumption that the function is a part of a larger real-valued loss function . The gradient computed is (note the conjugation of z), the negative of which is precisely the direction of steepest descent used in Gradient Descent algorithm. Thus, all the existing optimizers work out of the box with complex parameters.
此约定与 TensorFlow 对复数微分的约定一致,但与 JAX 不同(JAX 计算 ).
如果您有一个内部使用复数运算的实数到实数函数,这里的约定并不重要:您将始终获得与仅使用实数运算实现时相同的結果。
如果您对数学细节感到好奇,或者想知道如何在 PyTorch 中定义复数导数,请继续阅读。
什么是复数导数?¶
复数可微性的数学定义采用导数的极限定义,并将其推广到复数运算。考虑一个函数 ,
其中 和 是两个变量的实值函数,而 是虚数单位。
利用导数定义,我们可以写成
为了使这个极限存在,不仅 和 必须是实可微的,而且 也必须满足柯西-黎曼 方程。换句话说:对实部和虚部步长 () 计算的极限必须相等。这是一个更严格的条件。
复可微函数通常被称为全纯函数。它们表现良好,具有您从实可微函数中看到的良好性质,但在优化领域几乎没有用处。对于优化问题,研究界只使用实值目标函数,因为复数不属于任何有序域,因此使用复值损失没有多大意义。
事实证明,没有有趣的实值目标函数满足柯西-黎曼方程。因此,同态函数理论不能用于优化,大多数人因此使用维尔廷格微积分。
维尔廷格微积分出现了……¶
So, we have this great theory of complex differentiability and holomorphic functions, and we can’t use any of it at all, because many of the commonly used functions are not holomorphic. What’s a poor mathematician to do? Well, Wirtinger observed that even if isn’t holomorphic, one could rewrite it as a two variable function which is always holomorphic. This is because real and imaginary of the components of can be expressed in terms of and as:
Wirtinger calculus suggests to study instead, which is guaranteed to be holomorphic if was real differentiable (another way to think of it is as a change of coordinate system, from to .) This function has partial derivatives and . We can use the chain rule to establish a relationship between these partial derivatives and the partial derivatives w.r.t., the real and imaginary components of .
从上面的等式中,我们得到
这是你在维基百科上看到的Wirtinger微积分的经典定义。
这种改变带来了很多美丽的结论。
For one, the Cauchy-Riemann equations translate into simply saying that (that is to say, the function can be written entirely in terms of , without making reference to ).
Another important (and somewhat counterintuitive) result, as we’ll see later, is that when we do optimization on a real-valued loss, the step we should take while making variable update is given by (not ).
更多阅读,请查看:https://arxiv.org/pdf/0906.4835.pdf
Wirtinger微积分在优化中如何有用?¶
音频和其他领域的学者更常使用梯度下降来优化具有复数变量的实值损失函数。通常,这些人将实部和虚部视为可以更新的独立通道。对于步长和损失,我们可以在中写出以下等式
这些等式如何转化到复数空间?
发生了一些非常有趣的事情:Wirtinger 微积分告诉我们,我们可以简化上面的复变量更新公式,使其只引用共轭 Wirtinger 导数 , 这正是我们在优化中采取的步骤。
由于共轭 Wirtinger 导数为我们提供了实值损失函数的正确步骤,因此 PyTorch 在你对具有实值损失的函数进行微分时会提供此导数。
PyTorch 如何计算共轭 Wirtinger 导数?¶
Typically, our derivative formulas take in grad_output as an input, representing the incoming Vector-Jacobian product that we’ve already computed, aka, , where is the loss of the entire computation (producing a real loss) and is the output of our function. The goal here is to compute , where is the input of the function. It turns out that in the case of real loss, we can get away with only calculating , even though the chain rule implies that we also need to have access to . If you want to skip this derivation, look at the last equation in this section and then skip to the next section.
Let’s continue working with defined as . As discussed above, autograd’s gradient convention is centered around optimization for real valued loss functions, so let’s assume is a part of larger real valued loss function . Using chain rule, we can write:
(1)¶
现在使用 Wirtinger 导数定义,我们可以写成
这里需要注意的是,由于 和 是实函数,并且根据我们假设 是一个实值函数的一部分,所以 是实数。
(2)¶
即, 等于 .
对上述方程求解 和 , 我们得到
(3)¶
利用 (2), 我们得到
(4)¶
最后一个方程对于编写自己的梯度公式非常重要,因为它将我们的导数公式分解为一个更简单的公式,该公式易于手工计算。
如何为复杂函数编写自己的导数公式?¶
The above boxed equation gives us the general formula for all derivatives on complex functions. However, we still need to compute and . There are two ways you could do this:
The first way is to just use the definition of Wirtinger derivatives directly and calculate and by using and (which you can compute in the normal way).
The second way is to use the change of variables trick and rewrite as a two variable function , and compute the conjugate Wirtinger derivatives by treating and as independent variables. This is often easier; for example, if the function in question is holomorphic, only will be used (and will be zero).
Let’s consider the function as an example, where .
使用第一种方法计算 Wirtinger 导数,我们得到。
使用 (4) 和 grad_output = 1.0(这是在 PyTorch 中对标量输出调用 backward()
时使用的默认梯度输出值),我们得到
使用第二种计算 Wirtinger 导数的方法,我们直接得到
再次使用 (4),我们得到 . 如你所见,第二种方法涉及更少的计算,并且在快速计算中更加方便。
保存张量的钩子¶
您可以通过定义一对 pack_hook
/ unpack_hook
钩子来控制 保存的张量如何打包/解包。 pack_hook
函数应该将张量作为其唯一的参数,但可以返回任何 Python 对象(例如另一个张量、元组,甚至包含文件名字符串)。 unpack_hook
函数以 pack_hook
的输出作为其唯一参数,并且应该返回一个将在反向传播中使用的张量。 unpack_hook
返回的张量只需要与作为输入传递给 pack_hook
的张量具有相同的内容。特别是,任何与 autograd 相关的元数据都可以忽略,因为它们将在解包期间被覆盖。
这样一对的示例是
class SelfDeletingTempFile():
def __init__(self):
self.name = os.path.join(tmp_dir, str(uuid.uuid4()))
def __del__(self):
os.remove(self.name)
def pack_hook(tensor):
temp_file = SelfDeletingTempFile()
torch.save(tensor, temp_file.name)
return temp_file
def unpack_hook(temp_file):
return torch.load(temp_file.name)
请注意, unpack_hook
不应该删除临时文件,因为它可能被多次调用:临时文件应该在返回的 SelfDeletingTempFile 对象存活的整个时间内保持存活。在上面的示例中,我们通过在不再需要时关闭它(在 SelfDeletingTempFile 对象删除时)来防止临时文件泄漏。
注意
我们保证 pack_hook
只会被调用一次,但 unpack_hook
可以根据反向传播的要求被调用多次,我们期望它每次都返回相同的数据。
警告
禁止对任何函数的输入执行就地操作,因为它们可能会导致意外的副作用。如果对打包钩子的输入进行了就地修改,PyTorch 会抛出错误,但不会捕获对解包钩子的输入进行就地修改的情况。
为保存的张量注册钩子¶
您可以通过调用 register_hooks()
方法在保存的张量上注册一对钩子,该方法位于 SavedTensor
对象上。这些对象作为 grad_fn
的属性公开,并以 _raw_saved_
前缀开头。
x = torch.randn(5, requires_grad=True)
y = x.pow(2)
y.grad_fn._raw_saved_self.register_hooks(pack_hook, unpack_hook)
当配对注册后,pack_hook
方法会被立即调用。每次需要访问保存的张量时,unpack_hook
方法会被调用,无论是通过 y.grad_fn._saved_self
还是在反向传播过程中。
警告
如果你在保存的张量被释放后(即反向传播调用之后)仍然保留对 SavedTensor
的引用,那么调用它的 register_hooks()
方法是被禁止的。PyTorch 大部分情况下会抛出错误,但在某些情况下可能会失败,并可能导致未定义的行为。
为保存的张量注册默认钩子¶
或者,你可以使用上下文管理器 saved_tensors_hooks
来注册一对钩子,这些钩子将应用于该上下文中创建的所有保存的张量。
示例
# Only save on disk tensors that have size >= 1000
SAVE_ON_DISK_THRESHOLD = 1000
def pack_hook(x):
if x.numel() < SAVE_ON_DISK_THRESHOLD:
return x
temp_file = SelfDeletingTempFile()
torch.save(tensor, temp_file.name)
return temp_file
def unpack_hook(tensor_or_sctf):
if isinstance(tensor_or_sctf, torch.Tensor):
return tensor_or_sctf
return torch.load(tensor_or_sctf.name)
class Model(nn.Module):
def forward(self, x):
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
# ... compute output
output = x
return output
model = Model()
net = nn.DataParallel(model)
使用此上下文管理器定义的钩子是线程局部的。因此,以下代码不会产生预期效果,因为钩子不会通过 DataParallel。
# Example what NOT to do
net = nn.DataParallel(model)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
output = net(input)
请注意,使用这些钩子会禁用所有用于减少张量对象创建的优化。例如
with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
x = torch.randn(5, requires_grad=True)
y = x * x
在没有钩子的情况下,x
、y.grad_fn._saved_self
和 y.grad_fn._saved_other
都引用同一个张量对象。使用钩子后,PyTorch 会将 x 打包和解包到两个新的张量对象中,这两个对象与原始 x 共享相同的存储空间(没有执行复制)。
反向钩子执行¶
本节将讨论不同钩子何时触发或不触发,然后讨论它们的触发顺序。将涵盖的钩子包括:通过 torch.Tensor.register_hook()
注册到 Tensor 的反向钩子,通过 torch.Tensor.register_post_accumulate_grad_hook()
注册到 Tensor 的后累积梯度钩子,通过 torch.autograd.graph.Node.register_hook()
注册到 Node 的后钩子,以及通过 torch.autograd.graph.Node.register_prehook()
注册到 Node 的前钩子。
特定钩子是否会被触发¶
通过 torch.Tensor.register_hook()
注册到 Tensor 的钩子在为该 Tensor 计算梯度时执行。(请注意,这并不需要执行 Tensor 的 grad_fn。例如,如果 Tensor 作为 inputs
参数的一部分传递给 torch.autograd.grad()
,Tensor 的 grad_fn 可能不会执行,但注册到该 Tensor 的钩子将始终执行。)
通过 torch.Tensor.register_post_accumulate_grad_hook()
注册到 Tensor 的钩子在为该 Tensor 累积梯度后执行,这意味着 Tensor 的 grad 字段已设置。而通过 torch.Tensor.register_hook()
注册的钩子在计算梯度时运行,通过 torch.Tensor.register_post_accumulate_grad_hook()
注册的钩子仅在反向传播结束时 Tensor 的 grad 字段由 autograd 更新后触发。因此,后累积梯度钩子只能注册到叶子 Tensor。在非叶子 Tensor 上通过 torch.Tensor.register_post_accumulate_grad_hook()
注册钩子将导致错误,即使您调用了 backward(retain_graph=True)。
使用 torch.autograd.graph.Node.register_hook()
或 torch.autograd.graph.Node.register_prehook()
注册到 torch.autograd.graph.Node
的钩子,只有在注册到的节点被执行时才会触发。
特定节点是否被执行可能取决于反向传播是否使用 torch.autograd.grad()
或 torch.autograd.backward()
调用。具体来说,当您在对应于传递给 torch.autograd.grad()
或 torch.autograd.backward()
的 inputs
参数的张量上的节点注册钩子时,您应该注意这些差异。
如果您使用的是 torch.autograd.backward()
,则上述所有钩子都将被执行,无论您是否指定了 inputs
参数。这是因为 .backward() 会执行所有节点,即使它们对应于作为输入指定的张量。(请注意,执行对应于作为 inputs
传递的张量的此额外节点通常是不必要的,但无论如何都会执行。此行为可能会发生变化;您不应该依赖它。)
另一方面,如果您使用的是 torch.autograd.grad()
,则注册到对应于传递给 input
的张量的节点的反向钩子可能不会被执行,因为这些节点将不会被执行,除非有另一个依赖于此节点的梯度结果的输入。
不同钩子触发顺序¶
事件发生的顺序如下:
注册到 Tensor 的钩子被执行
注册到 Node 的预钩子被执行(如果 Node 被执行)。
对于保留梯度的 Tensor,其
.grad
字段被更新Node 被执行(受上述规则约束)
对于具有累积
.grad
的叶 Tensor,执行后累积梯度钩子注册到 Node 的后钩子被执行(如果 Node 被执行)
如果在同一个 Tensor 或 Node 上注册了多个相同类型的钩子,则它们按照注册顺序执行。较晚执行的钩子可以观察到较早钩子对梯度的修改。
特殊钩子¶
torch.autograd.graph.register_multi_grad_hook()
是使用注册到 Tensor 的钩子实现的。每个单独的 Tensor 钩子都按照上面定义的 Tensor 钩子顺序触发,并且当最后一个 Tensor 梯度被计算时,注册的多梯度钩子被调用。
torch.nn.modules.module.register_module_full_backward_hook()
是使用注册到 Node 的钩子实现的。在计算前向传播时,会将钩子注册到与模块的输入和输出相对应的 grad_fn。由于一个模块可能接受多个输入并返回多个输出,因此在向前传播之前首先将一个虚拟的自定义自动微分函数应用于模块的输入,并在返回向前传播的输出之前将模块的输出应用于模块的输出,以确保这些 Tensor 共享一个 grad_fn,然后我们可以将我们的钩子附加到该 grad_fn 上。
Tensor 被就地修改时的 Tensor 钩子行为¶
通常,注册到 Tensor 的钩子接收输出相对于该 Tensor 的梯度,其中 Tensor 的值被认为是计算反向传播时 Tensor 的值。
但是,如果您向张量注册钩子,然后就地修改该张量,那么在就地修改之前注册的钩子也会收到输出相对于张量的梯度,但张量的值将被视为就地修改之前的其值。
如果您更喜欢前一种情况的行为,您应该在对张量进行所有就地修改后向其注册钩子。例如
t = torch.tensor(1., requires_grad=True).sin()
t.cos_()
t.register_hook(fn)
t.backward()
此外,了解以下内容可能会有所帮助:在幕后,当向张量注册钩子时,它们实际上会永久绑定到该张量的 grad_fn,因此如果该张量随后被就地修改,即使该张量现在具有新的 grad_fn,在就地修改之前注册的钩子将继续与旧的 grad_fn 相关联,例如,当自动梯度引擎在图中到达该张量的旧 grad_fn 时,它们将触发。