每样本梯度¶
它是什么?¶
每样本梯度计算是对一批数据中的每个样本计算梯度。它是差分隐私、元学习和优化研究中的一个有用量。
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
torch.manual_seed(0);
# Here's a simple CNN and loss function:
class SimpleCNN(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
output = x
return output
def loss_fn(predictions, targets):
return F.nll_loss(predictions, targets)
让我们生成一批虚拟数据,并假装我们正在使用 MNIST 数据集。
虚拟图像大小为 28x28,我们使用大小为 64 的小批量。
device = 'cuda'
num_models = 10
batch_size = 64
data = torch.randn(batch_size, 1, 28, 28, device=device)
targets = torch.randint(10, (64,), device=device)
在常规模型训练中,人们会将小批量通过模型进行前向传播,然后调用 .backward() 计算梯度。这将生成整个小批量的“平均”梯度。
model = SimpleCNN().to(device=device)
predictions = model(data) # move the entire mini-batch through the model
loss = loss_fn(predictions, targets)
loss.backward() # back propogate the 'average' gradient of this mini-batch
与上述方法相反,每样本梯度计算等同于
对于数据的每个单个样本,执行前向和后向传播以获得单个(每样本)梯度。
def compute_grad(sample, target):
sample = sample.unsqueeze(0) # prepend batch dimension for processing
target = target.unsqueeze(0)
prediction = model(sample)
loss = loss_fn(prediction, target)
return torch.autograd.grad(loss, list(model.parameters()))
def compute_sample_grads(data, targets):
""" manually process each sample with per sample gradient """
sample_grads = [compute_grad(data[i], targets[i]) for i in range(batch_size)]
sample_grads = zip(*sample_grads)
sample_grads = [torch.stack(shards) for shards in sample_grads]
return sample_grads
per_sample_grads = compute_sample_grads(data, targets)
sample_grads[0]
是模型 .conv1.weight 的每样本梯度。 model.conv1.weight.shape
为 [32, 1, 3, 3]
;注意,批次中每个样本都有一个梯度,总共 64 个。
print(per_sample_grads[0].shape)
torch.Size([64, 32, 1, 3, 3])
每样本梯度,高效的方式,使用 functorch¶
我们可以使用函数变换来高效地计算每样本梯度。
首先,让我们使用 functorch.make_functional_with_buffers
创建 model
的无状态函数版本。
这将分离状态(参数)和模型,并将模型变成一个纯函数
from functorch import make_functional_with_buffers, vmap, grad
fmodel, params, buffers = make_functional_with_buffers(model)
让我们回顾一下更改——首先,模型变成了无状态的 FunctionalModuleWithBuffers
fmodel
FunctionalModuleWithBuffers(
(stateless_model): SimpleCNN(
(conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
(conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
(fc1): Linear(in_features=9216, out_features=128, bias=True)
(fc2): Linear(in_features=128, out_features=10, bias=True)
)
)
并且模型参数现在独立于模型存在,存储为元组
for x in params:
print(f"{x.shape}")
print(f"\n{type(params)}")
torch.Size([32, 1, 3, 3])
torch.Size([32])
torch.Size([64, 32, 3, 3])
torch.Size([64])
torch.Size([128, 9216])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
<class 'tuple'>
接下来,让我们定义一个函数,用于计算模型在给定单个输入而不是一批输入时的损失。重要的是,此函数接受参数、输入和目标,因为我们将对它们进行变换。
注意 - 因为模型最初是为处理批次而编写的,所以我们将使用 torch.unsqueeze
添加一个批次维度。
def compute_loss_stateless_model (params, buffers, sample, target):
batch = sample.unsqueeze(0)
targets = target.unsqueeze(0)
predictions = fmodel(params, buffers, batch)
loss = loss_fn(predictions, targets)
return loss
现在,让我们使用 functorch 的 grad
创建一个新函数,用于计算 compute_loss
的第一个参数(即 params)的梯度。
ft_compute_grad = grad(compute_loss_stateless_model)
ft_compute_grad
函数计算单个 (sample, target) 对的梯度。我们可以使用 vmap 使其能够计算一整批样本和目标的梯度。请注意 in_dims=(None, None, 0, 0)
,因为我们希望对数据和目标的第 0 维映射 ft_compute_grad
,并对每个使用相同的 params 和 buffers。
ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))
最后,让我们使用我们转换后的函数计算每样本梯度
ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets)
# we can double check that the results using functorch grad and vmap match the results of hand processing each one individually:
for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads):
assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)
快速说明:关于 vmap 可以变换的函数类型存在一些限制。最适合变换的函数是纯函数:输出仅由输入决定的函数,并且没有副作用(例如,变异)。vmap 无法处理任意 Python 数据结构的变异,但它能够处理许多就地 PyTorch 操作。
性能比较¶
您是否想知道 vmap 的性能如何?
目前,在 A100(Ampere)等较新的 GPU 上获得了最佳结果,在这个示例中,我们已经看到了高达 25 倍的加速,但这里有一些在 Colab 中完成的结果
def get_perf(first, first_descriptor, second, second_descriptor):
""" takes torch.benchmark objects and compares delta of second vs first. """
second_res = second.times[0]
first_res = first.times[0]
gain = (first_res-second_res)/first_res
if gain < 0: gain *=-1
final_gain = gain*100
print(f" Performance delta: {final_gain:.4f} percent improvement with {first_descriptor} ")
from torch.utils.benchmark import Timer
without_vmap = Timer( stmt="compute_sample_grads(data, targets)", globals=globals())
with_vmap = Timer(stmt="ft_compute_sample_grad(params, buffers, data, targets)",globals=globals())
no_vmap_timing = without_vmap.timeit(100)
with_vmap_timing = with_vmap.timeit(100)
print(f'Per-sample-grads without vmap {no_vmap_timing}')
print(f'Per-sample-grads with vmap {with_vmap_timing}')
Per-sample-grads without vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7f71ac3f1850>
compute_sample_grads(data, targets)
79.86 ms
1 measurement, 100 runs , 1 thread
Per-sample-grads with vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7f7143e26f10>
ft_compute_sample_grad(params, buffers, data, targets)
12.93 ms
1 measurement, 100 runs , 1 thread
get_perf(with_vmap_timing, "vmap", no_vmap_timing,"no vmap" )
Performance delta: 517.5791 percent improvement with vmap
在 PyTorch 中计算每样本梯度还有其他优化解决方案(如 https://github.com/pytorch/opacus),这些解决方案的性能也优于朴素方法。但很酷的是,组合 vmap
和 grad
为我们提供了不错的加速。
一般来说,使用 vmap 进行向量化应该比在 for 循环中运行函数更快,并且与手动批处理具有竞争力。但是,也有一些例外,例如,如果我们还没有为特定操作实现 vmap 规则,或者如果底层内核没有针对旧硬件(GPU)进行优化。如果您遇到这些情况,请通过在我们 GitHub 上创建问题来告知我们!