注意
点击此处下载完整示例代码
神经正切核¶
创建于:2023 年 3 月 15 日 | 最后更新:2023 年 6 月 16 日 | 最后验证:未验证
神经正切核 (NTK) 是一个内核,描述了神经网络在训练期间如何演变。近年来,围绕它进行了大量研究。本教程受 JAX 中的 NTK 实现的启发(详见 快速有限宽度神经正切核),演示了如何使用 torch.func
(PyTorch 的可组合函数变换)轻松计算此量。
注意
本教程需要 PyTorch 2.0.0 或更高版本。
设置¶
首先,进行一些设置。让我们定义一个简单的 CNN,我们希望计算其 NTK。
import torch
import torch.nn as nn
from torch.func import functional_call, vmap, vjp, jvp, jacrev
device = 'cuda' if torch.cuda.device_count() > 0 else 'cpu'
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(3, 32, (3, 3))
self.conv2 = nn.Conv2d(32, 32, (3, 3))
self.conv3 = nn.Conv2d(32, 32, (3, 3))
self.fc = nn.Linear(21632, 10)
def forward(self, x):
x = self.conv1(x)
x = x.relu()
x = self.conv2(x)
x = x.relu()
x = self.conv3(x)
x = x.flatten(1)
x = self.fc(x)
return x
让我们生成一些随机数据
x_train = torch.randn(20, 3, 32, 32, device=device)
x_test = torch.randn(5, 3, 32, 32, device=device)
创建模型的函数版本¶
torch.func
变换作用于函数。特别是,为了计算 NTK,我们将需要一个函数,该函数接受模型的参数和单个输入(而不是一批输入!),并返回单个输出。
我们将使用 torch.func.functional_call
,它允许我们使用不同的参数/缓冲区调用 nn.Module
,以帮助完成第一步。
请记住,该模型最初编写为接受一批输入数据点。在我们的 CNN 示例中,没有批间操作。也就是说,批次中的每个数据点都独立于其他数据点。考虑到这一假设,我们可以轻松生成一个函数,该函数评估单个数据点上的模型
net = CNN().to(device)
# Detaching the parameters because we won't be calling Tensor.backward().
params = {k: v.detach() for k, v in net.named_parameters()}
def fnet_single(params, x):
return functional_call(net, params, (x.unsqueeze(0),)).squeeze(0)
计算 NTK:方法 1(雅可比矩阵收缩)¶
我们已准备好计算经验 NTK。两个数据点 \(x_1\) 和 \(x_2\) 的经验 NTK 定义为在 \(x_1\) 处评估的模型雅可比矩阵与在 \(x_2\) 处评估的模型雅可比矩阵之间的矩阵乘积
在批量情况下,其中 \(x_1\) 是一批数据点,\(x_2\) 是一批数据点,那么我们想要来自 \(x_1\) 和 \(x_2\) 的所有数据点组合的雅可比矩阵之间的矩阵乘积。
第一种方法包括执行此操作 - 计算两个雅可比矩阵,并收缩它们。以下是如何在批量情况下计算 NTK
def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2):
# Compute J(x1)
jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
jac1 = jac1.values()
jac1 = [j.flatten(2) for j in jac1]
# Compute J(x2)
jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
jac2 = jac2.values()
jac2 = [j.flatten(2) for j in jac2]
# Compute J(x1) @ J(x2).T
result = torch.stack([torch.einsum('Naf,Mbf->NMab', j1, j2) for j1, j2 in zip(jac1, jac2)])
result = result.sum(0)
return result
result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test)
print(result.shape)
torch.Size([20, 5, 10, 10])
在某些情况下,您可能只想获得此量的对角线或迹,特别是如果您事先知道网络架构会导致 NTK,其中非对角线元素可以近似为零。很容易调整上述函数来做到这一点
def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute='full'):
# Compute J(x1)
jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
jac1 = jac1.values()
jac1 = [j.flatten(2) for j in jac1]
# Compute J(x2)
jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
jac2 = jac2.values()
jac2 = [j.flatten(2) for j in jac2]
# Compute J(x1) @ J(x2).T
einsum_expr = None
if compute == 'full':
einsum_expr = 'Naf,Mbf->NMab'
elif compute == 'trace':
einsum_expr = 'Naf,Maf->NM'
elif compute == 'diagonal':
einsum_expr = 'Naf,Maf->NMa'
else:
assert False
result = torch.stack([torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)])
result = result.sum(0)
return result
result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test, 'trace')
print(result.shape)
torch.Size([20, 5])
此方法的渐近时间复杂度为 \(N O [FP]\) (计算雅可比矩阵的时间)+ \(N^2 O^2 P\) (收缩雅可比矩阵的时间),其中 \(N\) 是 \(x_1\) 和 \(x_2\) 的批大小,\(O\) 是模型的输出大小,\(P\) 是参数总数,\([FP]\) 是通过模型的单次前向传递的成本。有关详细信息,请参见 快速有限宽度神经正切核 中的第 3.2 节。
计算 NTK:方法 2(NTK 向量积)¶
我们将讨论的下一种方法是使用 NTK 向量积计算 NTK 的方法。
此方法将 NTK 重新表述为应用于大小为 \(O\times O\) 的单位矩阵 \(I_O\) 的列的 NTK 向量积堆栈(其中 \(O\) 是模型的输出大小)
其中 \(e_o\in \mathbb{R}^O\) 是单位矩阵 \(I_O\) 的列向量。
令 \(\textrm{vjp}_o = J_{net}^T(x_2) e_o\)。我们可以使用向量-雅可比矩阵积来计算它。
现在,考虑 \(J_{net}(x_1) \textrm{vjp}_o\)。这是一个雅可比矩阵-向量积!
最后,我们可以使用
vmap
并行运行上述计算,对 \(I_O\) 的所有列 \(e_o\) 进行运算。
这表明我们可以使用反向模式 AD(计算向量-雅可比矩阵积)和前向模式 AD(计算雅可比矩阵-向量积)的组合来计算 NTK。
让我们编写代码
def empirical_ntk_ntk_vps(func, params, x1, x2, compute='full'):
def get_ntk(x1, x2):
def func_x1(params):
return func(params, x1)
def func_x2(params):
return func(params, x2)
output, vjp_fn = vjp(func_x1, params)
def get_ntk_slice(vec):
# This computes ``vec @ J(x2).T``
# `vec` is some unit vector (a single slice of the Identity matrix)
vjps = vjp_fn(vec)
# This computes ``J(X1) @ vjps``
_, jvps = jvp(func_x2, (params,), vjps)
return jvps
# Here's our identity matrix
basis = torch.eye(output.numel(), dtype=output.dtype, device=output.device).view(output.numel(), -1)
return vmap(get_ntk_slice)(basis)
# ``get_ntk(x1, x2)`` computes the NTK for a single data point x1, x2
# Since the x1, x2 inputs to ``empirical_ntk_ntk_vps`` are batched,
# we actually wish to compute the NTK between every pair of data points
# between {x1} and {x2}. That's what the ``vmaps`` here do.
result = vmap(vmap(get_ntk, (None, 0)), (0, None))(x1, x2)
if compute == 'full':
return result
if compute == 'trace':
return torch.einsum('NMKK->NM', result)
if compute == 'diagonal':
return torch.einsum('NMKK->NMK', result)
# Disable TensorFloat-32 for convolutions on Ampere+ GPUs to sacrifice performance in favor of accuracy
with torch.backends.cudnn.flags(allow_tf32=False):
result_from_jacobian_contraction = empirical_ntk_jacobian_contraction(fnet_single, params, x_test, x_train)
result_from_ntk_vps = empirical_ntk_ntk_vps(fnet_single, params, x_test, x_train)
assert torch.allclose(result_from_jacobian_contraction, result_from_ntk_vps, atol=1e-5)
我们用于 empirical_ntk_ntk_vps
的代码看起来像是从上面的数学公式直接翻译过来的!这展示了函数变换的强大功能:祝您好运,尝试仅使用 torch.autograd.grad
编写上述公式的有效版本。
此方法的渐近时间复杂度为 \(N^2 O [FP]\),其中 \(N\) 是 \(x_1\) 和 \(x_2\) 的批大小,\(O\) 是模型的输出大小,\([FP]\) 是通过模型的单次前向传递的成本。因此,此方法执行的前向传递次数比方法 1 雅可比矩阵收缩更多(\(N^2 O\) 而不是 \(N O\)),但完全避免了收缩成本(没有 \(N^2 O^2 P\) 项,其中 \(P\) 是模型的参数总数)。因此,当 \(O P\) 相对于 \([FP]\) 较大时,此方法更可取,例如具有许多输出 \(O\) 的全连接模型(而非卷积模型)。在内存方面,这两种方法应该是可比的。有关详细信息,请参见 快速有限宽度神经正切核 中的第 3.3 节。
脚本的总运行时间: ( 0 分钟 0.541 秒)