神经切线核¶
神经切线核 (NTK) 是一个描述神经网络在训练过程中如何演变的核函数。近年来,围绕它进行了大量的研究。本教程受JAX 中 NTK 的实现的启发(详情请参阅Fast Finite Width Neural Tangent Kernel),演示了如何使用 functorch轻松计算此量。
设置¶
首先,进行一些设置。让我们定义一个简单的 CNN,我们希望计算它的 NTK。
import torch
import torch.nn as nn
from functorch import make_functional, vmap, vjp, jvp, jacrev
device = 'cuda'
class CNN(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 32, (3, 3))
self.conv2 = nn.Conv2d(32, 32, (3, 3))
self.conv3 = nn.Conv2d(32, 32, (3, 3))
self.fc = nn.Linear(21632, 10)
def forward(self, x):
x = self.conv1(x)
x = x.relu()
x = self.conv2(x)
x = x.relu()
x = self.conv3(x)
x = x.flatten(1)
x = self.fc(x)
return x
让我们生成一些随机数据
x_train = torch.randn(20, 3, 32, 32, device=device)
x_test = torch.randn(5, 3, 32, 32, device=device)
创建模型的函数版本¶
functorch 变换作用于函数。特别是,为了计算 NTK,我们需要一个接受模型参数和单个输入(而不是一批输入!)并返回单个输出的函数。
我们将使用 functorch 的make_functional
来完成第一步。如果您的模块有缓冲区,您可能需要使用make_functional_with_buffers
代替。
net = CNN().to(device)
fnet, params = make_functional(net)
请记住,模型最初是编写为接受一批输入数据点的。在我们的 CNN 示例中,没有批间操作。也就是说,批次中的每个数据点都独立于其他数据点。考虑到这个假设,我们可以轻松地生成一个函数来评估模型在一个数据点上的值
def fnet_single(params, x):
return fnet(params, x.unsqueeze(0)).squeeze(0)
计算 NTK:方法 1(雅可比矩阵收缩)¶
我们准备计算经验 NTK。两个数据点\(x_1\) 和\(x_2\) 的经验 NTK 定义为在\(x_1\) 处评估的模型的雅可比矩阵与在\(x_2\) 处评估的模型的雅可比矩阵之间的矩阵乘积
在\(x_1\) 是一批数据点且\(x_2\) 是一批数据点的批处理情况下,我们希望\(x_1\) 和\(x_2\) 中所有数据点组合的雅可比矩阵之间的矩阵乘积。
第一种方法就是这样做——计算两个雅可比矩阵,并将其收缩。以下是计算批处理情况下 NTK 的方法
def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2):
# Compute J(x1)
jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
jac1 = [j.flatten(2) for j in jac1]
# Compute J(x2)
jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
jac2 = [j.flatten(2) for j in jac2]
# Compute J(x1) @ J(x2).T
result = torch.stack([torch.einsum('Naf,Mbf->NMab', j1, j2) for j1, j2 in zip(jac1, jac2)])
result = result.sum(0)
return result
result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test)
print(result.shape)
torch.Size([20, 5, 10, 10])
在某些情况下,您可能只需要此量的对角线或迹,尤其是在您事先知道网络架构会导致 NTK 中非对角线元素可以近似为零的情况下。调整上述函数以执行此操作很容易
def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute='full'):
# Compute J(x1)
jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
jac1 = [j.flatten(2) for j in jac1]
# Compute J(x2)
jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
jac2 = [j.flatten(2) for j in jac2]
# Compute J(x1) @ J(x2).T
einsum_expr = None
if compute == 'full':
einsum_expr = 'Naf,Mbf->NMab'
elif compute == 'trace':
einsum_expr = 'Naf,Maf->NM'
elif compute == 'diagonal':
einsum_expr = 'Naf,Maf->NMa'
else:
assert False
result = torch.stack([torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)])
result = result.sum(0)
return result
result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test, 'trace')
print(result.shape)
torch.Size([20, 5])
此方法的渐近时间复杂度为\(N O [FP]\)(计算雅可比矩阵的时间)\( + N^2 O^2 P\)(收缩雅可比矩阵的时间),其中\(N\) 是\(x_1\) 和\(x_2\) 的批大小,\(O\) 是模型的输出大小,\(P\) 是参数的总数,而\([FP]\) 是通过模型进行一次前向传递的成本。有关详细信息,请参阅Fast Finite Width Neural Tangent Kernel 中的第 3.2 节。
计算 NTK:方法 2(NTK-向量积)¶
我们将讨论的下一种方法是使用 NTK-向量积来计算 NTK 的方法。
此方法将 NTK 重构为应用于大小为\(O\times O\)(其中\(O\) 是模型的输出大小)的单位矩阵\(I_O\) 的列的一系列 NTK-向量积
其中\(e_o\in \mathbb{R}^O\) 是单位矩阵\(I_O\) 的列向量。
令\(\textrm{vjp}_o = J_{net}^T(x_2) e_o\)。我们可以使用向量-雅可比矩阵积来计算它。
现在,考虑\(J_{net}(x_1) \textrm{vjp}_o\)。这是一个雅可比矩阵-向量积!
最后,我们可以使用
vmap
对\(I_O\) 的所有列\(e_o\) 并行运行上述计算。
这表明我们可以结合反向模式自动微分(计算向量-雅可比矩阵积)和前向模式自动微分(计算雅可比矩阵-向量积)来计算 NTK。
让我们编写代码
def empirical_ntk_ntk_vps(func, params, x1, x2, compute='full'):
def get_ntk(x1, x2):
def func_x1(params):
return func(params, x1)
def func_x2(params):
return func(params, x2)
output, vjp_fn = vjp(func_x1, params)
def get_ntk_slice(vec):
# This computes vec @ J(x2).T
# `vec` is some unit vector (a single slice of the Identity matrix)
vjps = vjp_fn(vec)
# This computes J(X1) @ vjps
_, jvps = jvp(func_x2, (params,), vjps)
return jvps
# Here's our identity matrix
basis = torch.eye(output.numel(), dtype=output.dtype, device=output.device).view(output.numel(), -1)
return vmap(get_ntk_slice)(basis)
# get_ntk(x1, x2) computes the NTK for a single data point x1, x2
# Since the x1, x2 inputs to empirical_ntk_ntk_vps are batched,
# we actually wish to compute the NTK between every pair of data points
# between {x1} and {x2}. That's what the vmaps here do.
result = vmap(vmap(get_ntk, (None, 0)), (0, None))(x1, x2)
if compute == 'full':
return result
if compute == 'trace':
return torch.einsum('NMKK->NM', result)
if compute == 'diagonal':
return torch.einsum('NMKK->NMK', result)
result_from_jacobian_contraction = empirical_ntk_jacobian_contraction(fnet_single, params, x_test, x_train)
result_from_ntk_vps = empirical_ntk_ntk_vps(fnet_single, params, x_test, x_train)
assert torch.allclose(result_from_jacobian_contraction, result_from_ntk_vps, atol=1e-5)
我们为empirical_ntk_ntk_vps
编写的代码看起来像是上面数学公式的直接翻译!这展示了函数变换的强大功能:祝您好运尝试使用 PyTorch 编写上述内容的高效版本。
此方法的渐近时间复杂度为\(N^2 O [FP]\),其中\(N\) 是\(x_1\) 和\(x_2\) 的批大小,\(O\) 是模型的输出大小,而\([FP]\) 是通过模型进行一次前向传递的成本。因此,此方法执行比方法 1(雅可比矩阵收缩)更多的网络前向传递(\(N^2 O\) 而不是\(N O\)),但完全避免了收缩成本(没有\(N^2 O^2 P\) 项,其中\(P\) 是模型参数的总数)。因此,当\(O P\) 相对于\([FP]\) 较大时,此方法更可取,例如具有许多输出\(O\) 的全连接(而非卷积)模型。在内存方面,两种方法都应该具有可比性。有关详细信息,请参阅Fast Finite Width Neural Tangent Kernel 中的第 3.3 节。