• 教程 >
  • autograd 保存张量的钩子
快捷方式

autograd 保存张量的钩子

创建于:2021 年 11 月 03 日 | 最后更新:2024 年 8 月 27 日 | 最后验证:未验证

PyTorch 通常使用反向传播计算梯度。然而,某些操作需要保存中间结果才能执行反向传播。本教程将介绍这些张量是如何保存/检索的,以及如何定义钩子来控制打包/解包过程。

本教程假设您熟悉反向传播的理论工作原理。如果不熟悉,请先阅读这个

保存的张量

训练模型通常比运行模型进行推理消耗更多内存。广义来说,可以认为这是因为“PyTorch 需要保存计算图,以便调用 backward”,因此增加了内存使用量。本教程的一个目标是微调这种理解。

事实上,图本身有时不会消耗更多内存,因为它从不复制任何张量。然而,图可以保留对原本会超出作用域的张量的引用:这些被称为保存的张量

为什么训练模型(通常)比评估模型需要更多内存?

我们从一个简单的例子开始:\(y = a \cdot b\),我们知道 \(y\)\(a\)\(b\) 的梯度

\[\frac{\partial y}{\partial a} = b \]
\[\frac{\partial y}{\partial b} = a \]
import torch

a = torch.randn(5, requires_grad=True)
b = torch.ones(5, requires_grad=True)
y = a * b

使用 torchviz,我们可以可视化计算图

https://user-images.githubusercontent.com/8019486/130124513-72e016a3-c36f-42b9-88e2-53baf3e016c5.png

在这个例子中,PyTorch 保存中间值 \(a\)\(b\),以便在反向传播期间计算梯度。

https://user-images.githubusercontent.com/8019486/130124538-3da50977-6f0b-46d0-8909-5456ade9b598.png

这些中间值(在上面的橙色部分中)可以通过查找以 _saved 为前缀的 ygrad_fn 的属性来访问(用于调试目的)

print(y.grad_fn._saved_self)
print(y.grad_fn._saved_other)
tensor([ 0.3367,  0.1288,  0.2345,  0.2303, -1.1229], requires_grad=True)
tensor([1., 1., 1., 1., 1.], requires_grad=True)

随着计算图深度的增加,它将存储更多保存的张量。与此同时,如果不是因为图,这些张量将超出作用域。

def f(x):
    return x * x

x = torch.randn(5, requires_grad=True)
y = f(f(f(x)))
https://user-images.githubusercontent.com/8019486/130124570-f1074098-1bb3-459e-bf5a-03bf6f65b403.png

在上面的例子中,在没有 grad 的情况下执行只会将 xy 保留在作用域中,但图还会额外存储 f(x)f(f(x))。因此,在训练期间运行前向传播在内存使用方面将比评估期间(更准确地说,当不需要 autograd 时)更昂贵。

打包 / 解包的概念

回到第一个例子:y.grad_fn._saved_selfy.grad_fn._saved_other 分别指向原始张量对象 ab

a = torch.randn(5, requires_grad=True)
b = torch.ones(5, requires_grad=True)
y = a * b

print(y.grad_fn._saved_self is a)   # True
print(y.grad_fn._saved_other is b)  # True
True
True

然而,情况并非总是如此。

a = torch.randn(5, requires_grad=True)
y = torch.exp(a)
print(y.grad_fn._saved_result.equal(y))  # True
print(y.grad_fn._saved_result is y)      # False
True
False

在底层,PyTorch 已经打包解包了张量 y 以防止引用循环。

根据经验,您不应依赖于访问为反向传播保存的张量会产生与原始张量相同的张量对象。但是,它们将共享相同的存储

保存的张量钩子

PyTorch 提供了一个 API 来控制应如何打包/解包保存的张量。

def pack_hook(x):
    print("Packing", x)
    return x

def unpack_hook(x):
    print("Unpacking", x)
    return x
a = torch.ones(5, requires_grad=True)
b = torch.ones(5, requires_grad=True) * 2

with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = a * b

y.sum().backward()
Packing tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
Packing tensor([1., 1., 1., 1., 1.], requires_grad=True)
Unpacking tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
Unpacking tensor([1., 1., 1., 1., 1.], requires_grad=True)

pack_hook 函数将在每次操作为反向传播保存张量时调用。pack_hook 的输出随后存储在计算图中,而不是原始张量。unpack_hook 使用该返回值来计算新张量,该张量是反向传播期间实际使用的张量。通常,您希望 unpack_hook(pack_hook(t)) 等于 t

x = torch.randn(5, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(lambda x: x * 4, lambda x: x / 4):
    y = torch.pow(x, 2)
y.sum().backward()
assert(x.grad.equal(2 * x))

需要注意的一点是,pack_hook 的输出可以是任何 Python 对象,只要 unpack_hook 可以从中派生出具有正确值的张量即可。

一些非常规的例子

首先,一些愚蠢的例子来说明什么是可能的,但您可能永远都不想这样做。

返回一个 int

返回 Python 列表的索引,相对无害,但用处值得商榷

storage = []

def pack(x):
    storage.append(x)
    return len(storage) - 1

def unpack(x):
    return storage[x]

x = torch.randn(5, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = x * x
y.sum().backward()

assert(x.grad.equal(2 * x))

返回一个元组

返回一些张量和一个解包它的函数,在其当前形式下不太可能有用

def pack(x):
    delta = torch.randn(*x.size())
    return x - delta, lambda x: x + delta

def unpack(packed):
    x, f = packed
    return f(x)


x = torch.randn(5, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = x * x
y.sum().backward()

assert(torch.allclose(x.grad, 2 * x))

返回一个 str

返回张量的 __repr__ of,可能永远不会这样做

x = torch.randn(5, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(lambda x: repr(x), lambda x: eval("torch." + x)):
    y = x * x
y.sum().backward()
assert(torch.all(x.grad - 2 * x <= 1e-4))

虽然这些例子在实践中没有用处,但它们说明 pack_hook 的输出实际上可以是任何 Python 对象,只要它包含足够的信息来检索原始张量的内容。在接下来的章节中,我们将重点介绍更有用的应用。

将张量保存到 CPU

通常,计算图中涉及的张量位于 GPU 上。在图中保留对这些张量的引用是导致大多数模型在训练期间耗尽 GPU 内存的原因,而它们在评估期间本可以正常运行。

Hook 提供了一种非常简单的方法来实现这一点。

def pack_hook(x):
    return (x.device, x.cpu())

def unpack_hook(packed):
    device, tensor = packed
    return tensor.to(device)

x = torch.randn(5, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = x * x
y.sum().backward()

torch.allclose(x.grad, (2 * x))
True

事实上,PyTorch 提供了一个 API 来方便地使用这些 hook(以及使用 pinned memory 的能力)。

import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.randn(5))

    def forward(self, x):
        with torch.autograd.graph.save_on_cpu(pin_memory=True):
            # some computation
            return self.w * x

x = torch.randn(5)
model = Model()
loss = model(x).sum()
loss.backward()

在实践中,在 A100 GPU 上,对于批量大小为 256 的 ResNet-152,这对应于 GPU 内存使用量从 48GB 减少到 5GB,但代价是速度降低 6 倍。

当然,您可以通过仅将网络Certain 部分保存到 CPU 来调节权衡。

例如,您可以定义一个特殊的 nn.Module,它包装任何模块并将其张量保存到 CPU。

class SaveToCpu(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        with torch.autograd.graph.save_on_cpu(pin_memory=True):
            return self.module(*args, **kwargs)

model = nn.Sequential(
    nn.Linear(10, 100),
    SaveToCpu(nn.Linear(100, 100)),
    nn.Linear(100, 10),
)

x = torch.randn(10)
loss = model(x).sum()
loss.backward()

将张量保存到磁盘

同样,您可能希望将这些张量保存到磁盘。同样,这可以通过这些 hook 来实现。

一个朴素的版本看起来像这样。

# Naive version - HINT: Don't do this

import uuid
tmp_dir = "temp"

def pack_hook(tensor):
    name = os.path.join(tmp_dir, str(uuid.uuid4()))
    torch.save(tensor, name)
    return name

def unpack_hook(name):
    return torch.load(name, weights_only=True)

上述代码不好的原因是我们在磁盘上泄漏了文件,并且它们永远不会被清除。修复这个问题并不像看起来那么简单。

# Incorrect version - HINT: Don't do this

import uuid
import os
import tempfile
tmp_dir_obj = tempfile.TemporaryDirectory()
tmp_dir = tmp_dir_obj.name

def pack_hook(tensor):
    name = os.path.join(tmp_dir, str(uuid.uuid4()))
    torch.save(tensor, name)
    return name

def unpack_hook(name):
    tensor = torch.load(name, weights_only=True)
    os.remove(name)
    return tensor

上述代码不起作用的原因是 unpack_hook 可以被多次调用。如果我们在第一次解包期间删除文件,则在第二次访问保存的张量时,该文件将不可用,这将引发错误。

x = torch.ones(5, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = x.pow(2)
print(y.grad_fn._saved_self)
try:
    print(y.grad_fn._saved_self)
    print("Double access succeeded!")
except:
    print("Double access failed!")
tensor([1., 1., 1., 1., 1.], requires_grad=True)
Double access failed!

为了解决这个问题,我们可以编写一个版本的 hook,它利用了 PyTorch 在不再需要保存的数据时自动释放(删除)数据的事实。

class SelfDeletingTempFile():
    def __init__(self):
        self.name = os.path.join(tmp_dir, str(uuid.uuid4()))

    def __del__(self):
        os.remove(self.name)

def pack_hook(tensor):
    temp_file = SelfDeletingTempFile()
    torch.save(tensor, temp_file.name)
    return temp_file

def unpack_hook(temp_file):
    return torch.load(temp_file.name, weights_only=True)

当我们调用 backward 时,pack_hook 的输出将被删除,这将导致文件被删除,因此我们不再泄漏文件。

然后可以在您的模型中以下列方式使用它

# Only save on disk tensors that have size >= 1000
SAVE_ON_DISK_THRESHOLD = 1000

def pack_hook(x):
    if x.numel() < SAVE_ON_DISK_THRESHOLD:
        return x
    temp_file = SelfDeletingTempFile()
    torch.save(tensor, temp_file.name)
    return temp_file

def unpack_hook(tensor_or_sctf):
    if isinstance(tensor_or_sctf, torch.Tensor):
        return tensor_or_sctf
    return torch.load(tensor_or_sctf.name)

class SaveToDisk(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
            return self.module(*args, **kwargs)

net = nn.DataParallel(SaveToDisk(Model()))

在最后一个示例中,我们还演示了如何过滤应保存哪些张量(此处为元素数量大于 1000 的张量),以及如何将此功能与 nn.DataParallel 结合使用。

如果您已经走到这一步,恭喜您!您现在知道如何使用保存的张量 hook,以及它们如何在一些场景中用于权衡内存以换取计算。

脚本的总运行时间: ( 0 分钟 0.036 秒)

由 Sphinx-Gallery 生成的图库


评价本教程

© Copyright 2024, PyTorch。

使用 Sphinx 构建,主题由 theme 提供,由 Read the Docs 提供。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的问题的解答

查看资源