注意
点击 这里 下载完整示例代码
雅可比矩阵、Hessian 矩阵、hvp、vhp 等:组合函数变换¶
创建于: 2023 年 3 月 15 日 | 最后更新: 2023 年 4 月 18 日 | 最后验证: 2024 年 11 月 05 日
计算雅可比矩阵或 Hessian 矩阵在许多非传统的深度学习模型中非常有用。使用 PyTorch 常规的自动微分 API (Tensor.backward()
, torch.autograd.grad
) 高效计算这些量是困难的(或令人头疼的)。PyTorch 受 JAX 启发的 函数变换 API 提供了高效计算各种高阶自动微分量的方法。
注意
本教程需要 PyTorch 2.0.0 或更高版本。
计算雅可比矩阵¶
import torch
import torch.nn.functional as F
from functools import partial
_ = torch.manual_seed(0)
让我们从一个我们想计算其雅可比矩阵的函数开始。这是一个带有非线性激活的简单线性函数。
让我们添加一些 dummy 数据:一个权重、一个偏差和一个特征向量 x。
D = 16
weight = torch.randn(D, D)
bias = torch.randn(D)
x = torch.randn(D) # feature vector
让我们将 predict
视为一个将输入 x
从 \(R^D \to R^D\) 映射的函数。PyTorch Autograd 计算向量-雅可比积。为了计算这个 \(R^D \to R^D\) 函数的完整雅可比矩阵,我们必须每次使用不同的单位向量,逐行计算它。
def compute_jac(xp):
jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]
for vec in unit_vectors]
return torch.stack(jacobian_rows)
xp = x.clone().requires_grad_()
unit_vectors = torch.eye(D)
jacobian = compute_jac(xp)
print(jacobian.shape)
print(jacobian[0]) # show first row
torch.Size([16, 16])
tensor([-0.5956, -0.6096, -0.1326, -0.2295, 0.4490, 0.3661, -0.1672, -1.1190,
0.1705, -0.6683, 0.1851, 0.1630, 0.0634, 0.6547, 0.5908, -0.1308])
与其逐行计算雅可比矩阵,我们可以使用 PyTorch 的 torch.vmap
函数变换来消除 for 循环并将计算向量化。我们不能直接将 vmap
应用于 torch.autograd.grad
;相反,PyTorch 提供了一个 torch.func.vjp
变换,它可以与 torch.vmap
组合使用。
from torch.func import vmap, vjp
_, vjp_fn = vjp(partial(predict, weight, bias), x)
ft_jacobian, = vmap(vjp_fn)(unit_vectors)
# let's confirm both methods compute the same result
assert torch.allclose(ft_jacobian, jacobian)
在后续教程中,反向模式 AD 和 vmap
的组合将为我们提供单样本梯度。在本教程中,组合反向模式 AD 和 vmap
为我们提供了雅可比矩阵的计算!vmap
和自动微分变换的各种组合可以给我们带来不同的有趣量。
PyTorch 提供了 torch.func.jacrev
作为便利函数,它执行 vmap-vjp
组合来计算雅可比矩阵。jacrev
接受一个 argnums
参数,该参数指定我们希望计算相对于哪个参数的雅可比矩阵。
from torch.func import jacrev
ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)
# Confirm by running the following:
assert torch.allclose(ft_jacobian, jacobian)
让我们比较这两种计算雅可比矩阵的方法的性能。函数变换版本要快得多(输出越多,速度越快)。
通常,我们期望通过 vmap
进行向量化可以帮助消除开销并更好地利用你的硬件。
vmap
通过将外部循环推到函数的基本操作中来实现这种神奇的效果,从而获得更好的性能。
让我们创建一个快速函数来评估性能并处理微秒和毫秒的测量值。
def get_perf(first, first_descriptor, second, second_descriptor):
"""takes torch.benchmark objects and compares delta of second vs first."""
faster = second.times[0]
slower = first.times[0]
gain = (slower-faster)/slower
if gain < 0: gain *=-1
final_gain = gain*100
print(f" Performance delta: {final_gain:.4f} percent improvement with {second_descriptor} ")
然后运行性能比较。
from torch.utils.benchmark import Timer
without_vmap = Timer(stmt="compute_jac(xp)", globals=globals())
with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
no_vmap_timer = without_vmap.timeit(500)
with_vmap_timer = with_vmap.timeit(500)
print(no_vmap_timer)
print(with_vmap_timer)
<torch.utils.benchmark.utils.common.Measurement object at 0x7f238479e0b0>
compute_jac(xp)
1.49 ms
1 measurement, 500 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f237dddef50>
jacrev(predict, argnums=2)(weight, bias, x)
403.88 us
1 measurement, 500 runs , 1 thread
让我们使用我们的 get_perf
函数对上述方法进行相对性能比较。
get_perf(no_vmap_timer, "without vmap", with_vmap_timer, "vmap")
Performance delta: 72.8187 percent improvement with vmap
此外,很容易将问题反过来,说我们想计算模型参数(权重、偏差)的雅可比矩阵,而不是输入的雅可比矩阵。
# note the change in input via ``argnums`` parameters of 0,1 to map to weight and bias
ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)
反向模式雅可比矩阵 (jacrev
) 与前向模式雅可比矩阵 (jacfwd
) 对比¶
我们提供两个 API 来计算雅可比矩阵:jacrev
和 jacfwd
jacrev
使用反向模式 AD。如您所见,它是我们的vjp
和vmap
变换的组合。jacfwd
使用前向模式 AD。它被实现为我们的jvp
和vmap
变换的组合。
jacfwd
和 jacrev
可以相互替换,但它们具有不同的性能特征。
一般来说,如果你正在计算一个 \(R^N \to R^M\) 函数的雅可比矩阵,并且输出数量远多于输入数量(例如,\(M > N\)),那么首选 jacfwd
,否则使用 jacrev
。这条规则也有例外,下面是一个非严格的论证:
在反向模式 AD 中,我们逐行计算雅可比矩阵,而在前向模式 AD(计算雅可比向量积)中,我们逐列计算它。雅可比矩阵有 M 行 N 列,因此如果它在某个方向上更高或更宽,我们可能更喜欢处理较少行或列的方法。
首先,让我们在输入多于输出的情况下进行基准测试。
Din = 32
Dout = 2048
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)
# remember the general rule about taller vs wider... here we have a taller matrix:
print(weight.shape)
using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
jacfwd_timing = using_fwd.timeit(500)
jacrev_timing = using_bwd.timeit(500)
print(f'jacfwd time: {jacfwd_timing}')
print(f'jacrev time: {jacrev_timing}')
torch.Size([2048, 32])
jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f237dc7bbb0>
jacfwd(predict, argnums=2)(weight, bias, x)
743.46 us
1 measurement, 500 runs , 1 thread
jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f237d613d90>
jacrev(predict, argnums=2)(weight, bias, x)
8.31 ms
1 measurement, 500 runs , 1 thread
然后进行相对基准测试。
get_perf(jacfwd_timing, "jacfwd", jacrev_timing, "jacrev", );
Performance delta: 1017.9775 percent improvement with jacrev
现在反过来——输出 (M) 多于输入 (N)
Din = 2048
Dout = 32
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)
using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
jacfwd_timing = using_fwd.timeit(500)
jacrev_timing = using_bwd.timeit(500)
print(f'jacfwd time: {jacfwd_timing}')
print(f'jacrev time: {jacrev_timing}')
jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f238491ae90>
jacfwd(predict, argnums=2)(weight, bias, x)
6.94 ms
1 measurement, 500 runs , 1 thread
jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f237d5ab130>
jacrev(predict, argnums=2)(weight, bias, x)
476.96 us
1 measurement, 500 runs , 1 thread
并进行相对性能比较。
get_perf(jacrev_timing, "jacrev", jacfwd_timing, "jacfwd")
Performance delta: 1354.6663 percent improvement with jacfwd
使用 functorch.hessian 计算 Hessian 矩阵¶
我们提供一个便利的 API 来计算 Hessian 矩阵:torch.func.hessiani
。Hessian 矩阵是雅可比矩阵的雅可比矩阵(或偏导数的偏导数,即二阶导数)。
这表明我们可以简单地组合 functorch 雅可比变换来计算 Hessian 矩阵。实际上,在底层,hessian(f)
只是 jacfwd(jacrev(f))
的简写。
注意:为了提升性能:根据你的模型,你可能也想使用 jacfwd(jacfwd(f))
或 jacrev(jacrev(f))
来计算 Hessian 矩阵,这利用了上面关于更宽或更高矩阵的经验法则。
from torch.func import hessian
# lets reduce the size in order not to overwhelm Colab. Hessians require
# significant memory:
Din = 512
Dout = 32
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)
hess_api = hessian(predict, argnums=2)(weight, bias, x)
hess_fwdfwd = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)
hess_revrev = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)
让我们验证一下,无论使用 hessian API 还是使用 jacfwd(jacfwd())
,我们都能得到相同的结果。
True
批量雅可比矩阵和批量 Hessian 矩阵¶
在上面的示例中,我们一直在处理单个特征向量。在某些情况下,你可能想计算一批输出相对于一批输入的雅可比矩阵。也就是说,给定形状为 (B, N)
的一批输入和一个从 \(R^N \to R^M\) 的函数,我们希望得到形状为 (B, M, N)
的雅可比矩阵。
最简单的方法是使用 vmap
batch_size = 64
Din = 31
Dout = 33
weight = torch.randn(Dout, Din)
print(f"weight shape = {weight.shape}")
bias = torch.randn(Dout)
x = torch.randn(batch_size, Din)
compute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0))
batch_jacobian0 = compute_batch_jacobian(weight, bias, x)
weight shape = torch.Size([33, 31])
如果你有一个函数是从 (B, N) -> (B, M),并且确定每个输入都产生独立的输出,那么有时也可以通过对输出求和,然后计算该函数的雅可比矩阵来实现这一点,而无需使用 vmap
。
def predict_with_output_summed(weight, bias, x):
return predict(weight, bias, x).sum(0)
batch_jacobian1 = jacrev(predict_with_output_summed, argnums=2)(weight, bias, x).movedim(1, 0)
assert torch.allclose(batch_jacobian0, batch_jacobian1)
如果你有一个从 \(R^N \to R^M\) 的函数,但输入是批量处理的,则可以将 vmap
与 jacrev
组合来计算批量雅可比矩阵。
最后,批量 Hessian 矩阵也可以类似计算。最容易理解的方法是使用 vmap
对 Hessian 计算进行批量处理,但在某些情况下,求和技巧也适用。
compute_batch_hessian = vmap(hessian(predict, argnums=2), in_dims=(None, None, 0))
batch_hess = compute_batch_hessian(weight, bias, x)
batch_hess.shape
torch.Size([64, 33, 31, 31])
计算 Hessian-向量积¶
计算 Hessian-向量积 (hvp) 的朴素方法是具体化完整的 Hessian 矩阵并与向量进行点积。我们可以做得更好:事实证明,我们不需要具体化完整的 Hessian 矩阵来做到这一点。我们将介绍计算 Hessian-向量积的两种(多种)不同策略: - 组合反向模式 AD 与反向模式 AD - 组合反向模式 AD 与前向模式 AD
将反向模式 AD 与前向模式 AD 组合(而不是反向模式与反向模式组合)通常是计算 hvp 更节省内存的方式,因为前向模式 AD 不需要构建 Autograd 图并保存中间结果用于反向传播。
以下是一些示例用法。
def f(x):
return x.sin().sum()
x = torch.randn(2048)
tangent = torch.randn(2048)
result = hvp(f, (x,), (tangent,))
如果 PyTorch 前向 AD 不支持你的操作,那么我们可以改为组合反向模式 AD 与反向模式 AD。
脚本总运行时间:( 0 分 10.473 秒)