• 教程 >
  • 自动微分保存张量钩子
快捷方式

自动微分保存张量钩子

PyTorch 通常使用反向传播来计算梯度。但是,某些操作需要保存中间结果才能执行反向传播。本教程将介绍这些张量是如何保存和检索的,以及如何定义钩子来控制打包和解包过程。

本教程假设你熟悉反向传播在理论上的工作原理。如果不太了解,请先阅读 此处 的内容。

保存的张量

训练模型通常比运行推理消耗更多的内存。一般来说,可以这样说,“PyTorch 需要保存计算图,该图用于调用 backward”,因此会额外使用内存。本教程的一个目标是调整对这一理解的认识。

实际上,图本身有时不会消耗更多内存,因为它不会复制任何张量。但是,图可以保留对原本会超出范围的张量的引用:这些被称为**保存的张量**。

为什么训练模型(通常)比评估模型消耗更多的内存?

我们从一个简单的例子开始:\(y = a \cdot b\),我们知道 \(y\) 相对于 \(a\)\(b\) 的梯度

\[\frac{\partial y}{\partial a} = b \]
\[\frac{\partial y}{\partial b} = a \]
import torch

a = torch.randn(5, requires_grad=True)
b = torch.ones(5, requires_grad=True)
y = a * b

使用 torchviz,我们可以可视化计算图

https://user-images.githubusercontent.com/8019486/130124513-72e016a3-c36f-42b9-88e2-53baf3e016c5.png

在这个例子中,PyTorch 保存了中间值 \(a\)\(b\),以便在反向传播期间计算梯度。

https://user-images.githubusercontent.com/8019486/130124538-3da50977-6f0b-46d0-8909-5456ade9b598.png

这些中间值(在上面的橙色部分)可以通过查找 ygrad_fn 的以 _saved 为前缀的属性来访问(用于调试目的)

print(y.grad_fn._saved_self)
print(y.grad_fn._saved_other)
tensor([ 0.3367,  0.1288,  0.2345,  0.2303, -1.1229], requires_grad=True)
tensor([1., 1., 1., 1., 1.], requires_grad=True)

随着计算图的深度增加,它将存储更多的保存的张量。同时,如果没有图,这些张量将超出范围。

def f(x):
    return x * x

x = torch.randn(5, requires_grad=True)
y = f(f(f(x)))
https://user-images.githubusercontent.com/8019486/130124570-f1074098-1bb3-459e-bf5a-03bf6f65b403.png

在上面的例子中,执行没有梯度只会保留 xy 在范围内,但图还会额外存储 f(x)f(f(x))。因此,在训练期间运行前向传递在内存使用方面比在评估期间(更准确地说,当不需要自动微分时)更昂贵。

打包/解包的概念

回到第一个例子:y.grad_fn._saved_selfy.grad_fn._saved_other 分别指向原始张量对象 ab

a = torch.randn(5, requires_grad=True)
b = torch.ones(5, requires_grad=True)
y = a * b

print(y.grad_fn._saved_self is a)   # True
print(y.grad_fn._saved_other is b)  # True
True
True

但是,情况并非总是如此。

a = torch.randn(5, requires_grad=True)
y = torch.exp(a)
print(y.grad_fn._saved_result.equal(y))  # True
print(y.grad_fn._saved_result is y)      # False
True
False

在幕后,PyTorch 对张量 y 进行了打包解包以防止引用循环。

根据经验,你不应该依赖于访问为反向保存的张量将产生与原始张量相同的张量对象这一事实。但是它们将共享相同的存储

保存的张量钩子

PyTorch 提供了一个 API 来控制如何打包/解包保存的张量。

def pack_hook(x):
    print("Packing", x)
    return x

def unpack_hook(x):
    print("Unpacking", x)
    return x
a = torch.ones(5, requires_grad=True)
b = torch.ones(5, requires_grad=True) * 2

with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = a * b

y.sum().backward()
Packing tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
Packing tensor([1., 1., 1., 1., 1.], requires_grad=True)
Unpacking tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
Unpacking tensor([1., 1., 1., 1., 1.], requires_grad=True)

当操作为反向保存张量时,pack_hook 函数将被调用。然后,pack_hook 的输出将存储在计算图中,而不是原始张量。unpack_hook 使用该返回值来计算一个新的张量,该张量是反向传播期间实际使用的张量。一般来说,你希望 unpack_hook(pack_hook(t)) 等于 t

x = torch.randn(5, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(lambda x: x * 4, lambda x: x / 4):
    y = torch.pow(x, 2)
y.sum().backward()
assert(x.grad.equal(2 * x))

需要注意的一点是,pack_hook 的输出可以是任何 Python 对象,只要 unpack_hook 可以从中推导出具有正确值的张量即可。

一些非传统的例子

首先,一些愚蠢的例子来说明什么是可能的,但你可能永远不想这样做。

返回一个 int

返回 Python 列表的索引 比较无害,但实用性值得商榷

storage = []

def pack(x):
    storage.append(x)
    return len(storage) - 1

def unpack(x):
    return storage[x]

x = torch.randn(5, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = x * x
y.sum().backward()

assert(x.grad.equal(2 * x))

返回一个元组

返回一些张量和一个函数,说明如何解包它 当前形式不太可能有用

def pack(x):
    delta = torch.randn(*x.size())
    return x - delta, lambda x: x + delta

def unpack(packed):
    x, f = packed
    return f(x)


x = torch.randn(5, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = x * x
y.sum().backward()

assert(torch.allclose(x.grad, 2 * x))

返回一个 str

返回张量的 __repr__ of 可能永远不要这样做

x = torch.randn(5, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(lambda x: repr(x), lambda x: eval("torch." + x)):
    y = x * x
y.sum().backward()
assert(torch.all(x.grad - 2 * x <= 1e-4))

尽管这些例子在实践中没有用,但它们说明了 pack_hook 的输出实际上可以是任何 Python 对象,只要它包含足够的信息来检索原始张量的内容即可。在接下来的部分中,我们将重点介绍更有用的应用程序。

将张量保存到 CPU

通常,计算图中涉及的张量驻留在 GPU 上。在图中保留对这些张量的引用是导致大多数模型在训练期间耗尽 GPU 内存的原因,而它们在评估期间本可以正常运行。

钩子提供了一个非常简单的方法来实现这一点。

def pack_hook(x):
    return (x.device, x.cpu())

def unpack_hook(packed):
    device, tensor = packed
    return tensor.to(device)

x = torch.randn(5, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = x * x
y.sum().backward()

torch.allclose(x.grad, (2 * x))
True

事实上,PyTorch 提供了一个 API 来方便地使用这些钩子(以及使用固定内存的能力)。

import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.randn(5))

    def forward(self, x):
        with torch.autograd.graph.save_on_cpu(pin_memory=True):
            # some computation
            return self.w * x

x = torch.randn(5)
model = Model()
loss = model(x).sum()
loss.backward()

在实践中,在一个 A100 GPU 上,对于批大小为 256 的 ResNet-152,这对应于 GPU 内存使用量从 48GB 减少到 5GB,代价是速度降低 6 倍。

当然,你可以通过只将网络的某些部分保存到 CPU 来调节这种权衡。

例如,你可以定义一个特殊的 nn.Module 来包装任何模块并将它的张量保存到 CPU。

class SaveToCpu(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        with torch.autograd.graph.save_on_cpu(pin_memory=True):
            return self.module(*args, **kwargs)

model = nn.Sequential(
    nn.Linear(10, 100),
    SaveToCpu(nn.Linear(100, 100)),
    nn.Linear(100, 10),
)

x = torch.randn(10)
loss = model(x).sum()
loss.backward()

将张量保存到磁盘

类似地,你可能希望将这些张量保存到磁盘。同样,这可以通过这些钩子实现。

一个简单的版本看起来像这样。

# Naive version - HINT: Don't do this

import uuid
tmp_dir = "temp"

def pack_hook(tensor):
    name = os.path.join(tmp_dir, str(uuid.uuid4()))
    torch.save(tensor, name)
    return name

def unpack_hook(name):
    return torch.load(name, weights_only=True)

上面的代码之所以不好,是因为我们在磁盘上泄漏了文件,而且它们永远不会被清除。修复它并不像看起来那样简单。

# Incorrect version - HINT: Don't do this

import uuid
import os
import tempfile
tmp_dir_obj = tempfile.TemporaryDirectory()
tmp_dir = tmp_dir_obj.name

def pack_hook(tensor):
    name = os.path.join(tmp_dir, str(uuid.uuid4()))
    torch.save(tensor, name)
    return name

def unpack_hook(name):
    tensor = torch.load(name, weights_only=True)
    os.remove(name)
    return tensor

上面的代码之所以不起作用,是因为 unpack_hook 可以被调用多次。如果我们在第一次解包时删除了文件,那么在第二次访问保存的张量时该文件将不可用,这会导致错误。

x = torch.ones(5, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = x.pow(2)
print(y.grad_fn._saved_self)
try:
    print(y.grad_fn._saved_self)
    print("Double access succeeded!")
except:
    print("Double access failed!")
tensor([1., 1., 1., 1., 1.], requires_grad=True)
Double access failed!

为了解决这个问题,我们可以编写一个版本的这些钩子,利用 PyTorch 自动释放(删除)不再需要保存的数据这一事实。

class SelfDeletingTempFile():
    def __init__(self):
        self.name = os.path.join(tmp_dir, str(uuid.uuid4()))

    def __del__(self):
        os.remove(self.name)

def pack_hook(tensor):
    temp_file = SelfDeletingTempFile()
    torch.save(tensor, temp_file.name)
    return temp_file

def unpack_hook(temp_file):
    return torch.load(temp_file.name, weights_only=True)

当我们调用 backward 时,pack_hook 的输出将被删除,这会导致文件被删除,因此我们不再泄漏文件。

然后,这可以在你的模型中以以下方式使用

# Only save on disk tensors that have size >= 1000
SAVE_ON_DISK_THRESHOLD = 1000

def pack_hook(x):
    if x.numel() < SAVE_ON_DISK_THRESHOLD:
        return x
    temp_file = SelfDeletingTempFile()
    torch.save(tensor, temp_file.name)
    return temp_file

def unpack_hook(tensor_or_sctf):
    if isinstance(tensor_or_sctf, torch.Tensor):
        return tensor_or_sctf
    return torch.load(tensor_or_sctf.name)

class SaveToDisk(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
            return self.module(*args, **kwargs)

net = nn.DataParallel(SaveToDisk(Model()))

在这个最后一个例子中,我们还演示了如何过滤哪些张量应该被保存(这里,那些元素数量大于 1000 的张量),以及如何将此功能与 nn.DataParallel 相结合。

如果你已经看到了这里,恭喜你!你现在知道如何使用保存的张量钩子,以及它们如何在一些场景中用于权衡内存和计算。

脚本的总运行时间:(0 分钟 0.035 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得解答

查看资源