雅可比矩阵、Hessian 矩阵、hvp、vhp 等:组合 functorch 变换¶
计算雅可比矩阵或 Hessian 矩阵在许多非传统的深度学习模型中非常有用。使用标准自动微分系统(如 PyTorch Autograd)来高效计算这些量非常困难(或令人烦恼);functorch 提供了有效计算各种高阶自动微分量的方法。
计算雅可比矩阵¶
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
_ = torch.manual_seed(0)
让我们从一个想要计算其雅可比矩阵的函数开始。这是一个带有非线性激活的简单线性函数。
def predict(weight, bias, x):
return F.linear(x, weight, bias).tanh()
让我们添加一些虚拟数据:权重、偏差和特征向量 x。
D = 16
weight = torch.randn(D, D)
bias = torch.randn(D)
x = torch.randn(D) # feature vector
让我们将 predict
视为一个将输入 x
从 \(R^D -> R^D\) 映射的函数。PyTorch Autograd 计算向量-雅可比矩阵积。为了计算此 \(R^D -> R^D\) 函数的完整雅可比矩阵,我们必须每次使用不同的单位向量逐行计算它。
def compute_jac(xp):
jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]
for vec in unit_vectors]
return torch.stack(jacobian_rows)
xp = x.clone().requires_grad_()
unit_vectors = torch.eye(D)
jacobian = compute_jac(xp)
print(jacobian.shape)
print(jacobian[0]) # show first row
torch.Size([16, 16])
tensor([-0.5956, -0.6096, -0.1326, -0.2295, 0.4490, 0.3661, -0.1672, -1.1190,
0.1705, -0.6683, 0.1851, 0.1630, 0.0634, 0.6547, 0.5908, -0.1308])
与其逐行计算雅可比矩阵,我们可以使用 vmap 来消除 for 循环并向量化计算。我们不能直接将 vmap 应用于 PyTorch Autograd;相反,functorch 提供了一个 vjp 变换
from functorch import vmap, vjp
_, vjp_fn = vjp(partial(predict, weight, bias), x)
ft_jacobian, = vmap(vjp_fn)(unit_vectors)
# lets confirm both methods compute the same result
assert torch.allclose(ft_jacobian, jacobian)
在未来的教程中,反向模式自动微分和 vmap 的组合将为我们提供每个样本的梯度。在本教程中,反向模式自动微分和 vmap 的组合为我们提供了雅可比矩阵计算!vmap 和自动微分变换的各种组合可以为我们提供不同的有趣量。
functorch 提供了 **jacrev** 作为执行 vmap-vjp 组合以计算雅可比矩阵的便捷函数。**jacrev** 接受一个 argnums 参数,该参数指定我们想要计算相对于哪个参数的雅可比矩阵。
from functorch import jacrev
ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)
# confirm
assert torch.allclose(ft_jacobian, jacobian)
让我们比较两种计算雅可比矩阵方法的性能。functorch 版本快得多(并且随着输出数量的增加而变得更快)。
通常,我们预计通过 vmap 进行向量化可以帮助消除开销并更好地利用硬件。
Vmap 通过将外循环下推到函数的基本操作中来实现此魔术,从而获得更好的性能。
让我们快速创建一个函数来评估性能并处理微秒和毫秒的测量
def get_perf(first, first_descriptor, second, second_descriptor):
""" takes torch.benchmark objects and compares delta of second vs first. """
faster = second.times[0]
slower = first.times[0]
gain = (slower-faster)/slower
if gain < 0: gain *=-1
final_gain = gain*100
print(f" Performance delta: {final_gain:.4f} percent improvement with {second_descriptor} ")
然后运行性能比较
from torch.utils.benchmark import Timer
without_vmap = Timer(stmt="compute_jac(xp)", globals=globals())
with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
no_vmap_timer = without_vmap.timeit(500)
with_vmap_timer = with_vmap.timeit(500)
print(no_vmap_timer)
print(with_vmap_timer)
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa9a911b350>
compute_jac(xp)
2.25 ms
1 measurement, 500 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa9a6a99d50>
jacrev(predict, argnums=2)(weight, bias, x)
884.34 us
1 measurement, 500 runs , 1 thread
让我们使用 get_perf 函数对上述内容进行相对性能比较
get_perf(no_vmap_timer, "without vmap", with_vmap_timer, "vmap");
Performance delta: 60.7170 percent improvement with vmap
此外,很容易将问题反过来,说我们想要计算模型参数(权重、偏差)而不是输入的雅可比矩阵。
# note the change in input via argnums params of 0,1 to map to weight and bias
ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)
反向模式雅可比矩阵 (jacrev) 与正向模式雅可比矩阵 (jacfwd)¶
我们提供了两个 API 来计算雅可比矩阵:**jacrev** 和 **jacfwd**
jacrev 使用反向模式自动微分。如上所述,它是我们的 vjp 和 vmap 变换的组合。
jacfwd 使用正向模式自动微分。它实现为我们的 jvp 和 vmap 变换的组合。
jacfwd 和 jacrev 可以互相替换,但它们具有不同的性能特征。
一般来说,如果您正在计算一个 \(𝑅^N \to R^M\) 函数的雅可比矩阵,并且输出比输入多得多(即 \(M > N\)),则 jacfwd 是首选,否则使用 jacrev。此规则有一些例外,但以下是一个非严格的论证
在反向模式自动微分中,我们逐行计算雅可比矩阵,而在正向模式自动微分(计算雅可比矩阵-向量积)中,我们逐列计算它。雅可比矩阵有 M 行和 N 列,因此如果它在一方面更高或更宽,我们可能更喜欢处理行或列较少的方法。
from functorch import jacrev, jacfwd
首先,让我们对输入多于输出的情况进行基准测试
Din = 32
Dout = 2048
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)
# remember the general rule about taller vs wider...here we have a taller matrix:
print(weight.shape)
using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
jacfwd_timing = using_fwd.timeit(500)
jacrev_timing = using_bwd.timeit(500)
print(f'jacfwd time: {jacfwd_timing}')
print(f'jacrev time: {jacrev_timing}')
torch.Size([2048, 32])
jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fa9a5d792d0>
jacfwd(predict, argnums=2)(weight, bias, x)
1.32 ms
1 measurement, 500 runs , 1 thread
jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fa9a4dee450>
jacrev(predict, argnums=2)(weight, bias, x)
12.46 ms
1 measurement, 500 runs , 1 thread
然后进行相对基准测试
get_perf(jacfwd_timing, "jacfwd", jacrev_timing, "jacrev", );
Performance delta: 842.8274 percent improvement with jacrev
现在反过来 - 输出 (M) 多于输入 (N)
Din = 2048
Dout = 32
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)
using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
jacfwd_timing = using_fwd.timeit(500)
jacrev_timing = using_bwd.timeit(500)
print(f'jacfwd time: {jacfwd_timing}')
print(f'jacrev time: {jacrev_timing}')
jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fa9a5d64790>
jacfwd(predict, argnums=2)(weight, bias, x)
7.99 ms
1 measurement, 500 runs , 1 thread
jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fa9a5d67b50>
jacrev(predict, argnums=2)(weight, bias, x)
1.09 ms
1 measurement, 500 runs , 1 thread
以及相对性能比较
get_perf(jacrev_timing, "jacrev", jacfwd_timing, "jacfwd")
Performance delta: 635.2095 percent improvement with jacfwd
使用 functorch.hessian 计算 Hessian 矩阵¶
我们提供了一个方便的 API 来计算 Hessian 矩阵:functorch.hessian
。Hessian 矩阵是雅可比矩阵的雅可比矩阵(或偏导数的偏导数,即二阶导数)。
这表明人们可以简单地组合 functorch 的雅可比矩阵变换来计算 Hessian 矩阵。实际上,在幕后,hessian(f)
只是 jacfwd(jacrev(f))
。
注意:为了提高性能:根据您的模型,您可能还需要使用 jacfwd(jacfwd(f))
或 jacrev(jacrev(f))
来计算 Hessian 矩阵,利用上面关于更宽与更高矩阵的经验法则。
from functorch import hessian
# lets reduce the size in order not to blow out colab. Hessians require significant memory:
Din = 512
Dout = 32
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)
hess_api = hessian(predict, argnums=2)(weight, bias, x)
hess_fwdfwd = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)
#hess_revrev = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)
让我们验证无论使用 hessian api 还是使用 jacfwd(jacfwd()),我们是否都得到相同的结果
torch.allclose(hess_api, hess_fwdfwd)
True
批处理雅可比矩阵和批处理 Hessian 矩阵¶
在上面的示例中,我们一直在使用单个特征向量。在某些情况下,您可能希望获取一批输出相对于一批输入的雅可比矩阵。也就是说,给定一个形状为 (B, N)
的输入批次和一个从 \(R^N \to R^M\) 的函数,我们想要一个形状为 (B, M, N)
的雅可比矩阵。
最简单的方法是使用 vmap
batch_size = 64
Din = 31
Dout = 33
weight = torch.randn(Dout, Din)
print(f"weight shape = {weight.shape}")
bias = torch.randn(Dout)
x = torch.randn(batch_size, Din)
weight shape = torch.Size([33, 31])
compute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0))
batch_jacobian0 = compute_batch_jacobian(weight, bias, x)
如果您有一个从 (B, N) -> (B, M) 的函数,并且确定每个输入都会产生一个独立的输出,那么有时也可以在不使用 vmap 的情况下通过对输出求和然后计算该函数的雅可比矩阵来做到这一点
def predict_with_output_summed(weight, bias, x):
return predict(weight, bias, x).sum(0)
batch_jacobian1 = jacrev(predict_with_output_summed, argnums=2)(weight, bias, x).movedim(1, 0)
assert torch.allclose(batch_jacobian0, batch_jacobian1)
如果您有一个从 \(𝑅^𝑁 \to 𝑅^𝑀\) 的函数,但输入是批处理的,则将 vmap 与 jacrev 组合以计算批处理雅可比矩阵
最后,批处理 Hessian 矩阵可以类似地计算。最简单的方法是使用 vmap 对 Hessian 矩阵计算进行批处理,但在某些情况下,求和技巧也有效。
compute_batch_hessian = vmap(hessian(predict, argnums=2), in_dims=(None, None, 0))
batch_hess = compute_batch_hessian(weight, bias, x)
batch_hess.shape
torch.Size([64, 33, 31, 31])
计算 Hessian 矩阵-向量积¶
计算 Hessian 矩阵-向量积 (hvp) 的简单方法是具体化完整的 Hessian 矩阵并与向量执行点积。我们可以做得更好:事实证明,我们不需要具体化完整的 Hessian 矩阵来做到这一点。我们将介绍两种(在许多种中)不同的计算 Hessian 矩阵-向量积的策略
组合反向模式自动微分与反向模式自动微分
组合反向模式自动微分与正向模式自动微分
组合反向模式自动微分与正向模式自动微分(而不是反向模式与反向模式)通常是计算 hvp 更节省内存的方法,因为正向模式自动微分不需要构建 Autograd 图并保存反向的中间值
from functorch import jvp, grad, vjp
def hvp(f, primals, tangents):
return jvp(grad(f), primals, tangents)[1]
以下是一些示例用法。
def f(x):
return x.sin().sum()
x = torch.randn(2048)
tangent = torch.randn(2048)
result = hvp(f, (x,), (tangent,))
如果 PyTorch 正向自动微分不覆盖您的操作,那么我们可以组合反向模式自动微分与反向模式自动微分
def hvp_revrev(f, primals, tangents):
_, vjp_fn = vjp(grad(f), *primals)
return vjp_fn(*tangents)
result_hvp_revrev = hvp_revrev(f, (x,), (tangent,))
assert torch.allclose(result, result_hvp_revrev[0])