快捷方式

TensorDictModule

作者: Nicolas Dufour, Vincent Moens

在本教程中,您将学习如何使用 TensorDictModuleTensorDictSequential 来创建可接受 TensorDict 作为输入的通用且可重用的模块。

为了方便地将 TensorDict 类与 Module 一起使用,tensordict 提供了一个名为 TensorDictModule 的接口来连接它们。

TensorDictModule 类是一个 Module,它在调用时接受 TensorDict 作为输入。它将读取一系列输入键,将它们作为输入传递给包装的模块或函数,并在执行完成后将输出写入同一个 tensordict 中。

输入和输出的键由用户定义。

import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential

简单示例:编写一个循环层

下面示例了 TensorDictModule 的最简单用法。虽然乍一看使用这个类似乎会引入不必要的复杂性,但我们稍后会看到,这个 API 允许用户以编程方式将模块连接在一起,在模块之间缓存值,或者以编程方式构建模块。其中一个最简单的例子是 ResNet 等架构中的循环模块,其中模块的输入被缓存并添加到微型多层感知器 (MLP) 的输出中。

首先,让我们考虑如何将一个 MLP 分块,并使用 tensordict.nn 对其进行编码。堆栈中的第一层可能是一个 Linear 层,它接受一个输入项(我们将其命名为 x),并输出另一个项(我们将将其命名为 y)。

为了馈送给我们的模块,我们有一个包含单个项 "x"TensorDict 实例

tensordict = TensorDict(
    x=torch.randn(5, 3),
    batch_size=[5],
)

现在,我们使用 tensordict.nn.TensorDictModule 构建我们的简单模块。默认情况下,此类会就地写入输入 tensordict(这意味着条目被写入与输入相同的 tensordict 中,而不是条目被就地覆盖!),这样我们就不需要显式地指示输出是什么

linear0 = TensorDictModule(nn.Linear(3, 128), in_keys=["x"], out_keys=["linear0"])
linear0(tensordict)

assert "linear0" in tensordict

如果模块输出多个张量(或 tensordicts!),它们的条目必须按照正确的顺序传递给 TensorDictModule

支持可调用对象

在设计模型时,您经常会希望在网络中包含任意的非参数函数。例如,您可能希望在图像传递给卷积网络或视觉 transformer 时对其维度进行置换,或者将值除以 255。有几种方法可以实现这一点:例如,您可以使用 forward_hook,或者设计一个新的 Module 来执行此操作。

TensorDictModule 可以与任何可调用对象一起使用,而不仅仅是模块,这使得将任意函数集成到模块中变得容易。例如,让我们看看如何在不使用 ReLU 模块的情况下集成 relu 激活函数

relu0 = TensorDictModule(torch.relu, in_keys=["linear0"], out_keys=["relu0"])

堆叠模块

我们的 MLP 不是由单层组成的,所以现在我们需要为其添加另一层。这一层将是一个激活函数,例如 ReLU。我们可以使用 TensorDictSequential 将此模块与前一个模块堆叠在一起。

注意

tensordict.nn 的真正强大之处在于:与 Sequential 不同,TensorDictSequential 会在内存中保留所有先前的输入和输出(之后可以选择过滤掉它们),这使得能够轻松地动态且以编程方式构建复杂的网络结构。

block0 = TensorDictSequential(linear0, relu0)

block0(tensordict)
assert "linear0" in tensordict
assert "relu0" in tensordict

我们可以重复这个逻辑来得到一个完整的 MLP

linear1 = TensorDictModule(nn.Linear(128, 128), in_keys=["relu0"], out_keys=["linear1"])
relu1 = TensorDictModule(nn.ReLU(), in_keys=["linear1"], out_keys=["relu1"])
linear2 = TensorDictModule(nn.Linear(128, 3), in_keys=["relu1"], out_keys=["linear2"])
block1 = TensorDictSequential(linear1, relu1, linear2)

多个输入键

残差网络的最后一步是将输入添加到最后一个线性层的输出中。无需为此编写特殊的 Module 子类!TensorDictModule 也可以用于包装简单的函数

residual = TensorDictModule(
    lambda x, y: x + y, in_keys=["x", "linear2"], out_keys=["y"]
)

现在我们可以将 block0block1residual 组合起来,形成一个完整的残差块

block = TensorDictSequential(block0, block1, residual)
block(tensordict)
assert "y" in tensordict

一个真正令人担忧的问题可能是作为输入的 tensordict 中条目的累积:在某些情况下(例如,需要梯度时),中间值无论如何都可能被缓存,但这并非总是如此,通知垃圾回收器某些条目可以被丢弃可能会很有用。tensordict.nn.TensorDictModuleBase 及其子类(包括 tensordict.nn.TensorDictModuletensordict.nn.TensorDictSequential)可以选择在执行后过滤其输出键。为此,只需调用 tensordict.nn.TensorDictModuleBase.select_out_keys 方法。这将就地更新模块,并且所有不需要的条目将被丢弃

block.select_out_keys("y")

tensordict = TensorDict(x=torch.randn(1, 3), batch_size=[1])
block(tensordict)
assert "y" in tensordict

assert "linear1" not in tensordict

但是,输入键会被保留

assert "x" in tensordict

附带一提,selected_out_keys 也可以传递给 tensordict.nn.TensorDictSequential,以避免单独调用此方法。

不使用 tensordict 来使用 TensorDictModule

tensordict.nn.TensorDictSequential 提供的动态构建复杂架构的能力并不意味着必须切换到 tensordict 来表示数据。借助 dispatch,tensordict.nn 中的模块也支持与条目名称匹配的参数和关键字参数

x = torch.randn(1, 3)
y = block(x=x)
assert isinstance(y, torch.Tensor)

在底层,dispatch 重建一个 tensordict,运行模块,然后将其分解。这可能会带来一些开销,但是,正如我们稍后将看到的,有办法解决这个问题。

运行时性能

tensordict.nn.TensorDictModuletensordict.nn.TensorDictSequential 在执行时确实会产生一些开销,因为它们需要从 tensordict 读取和写入。但是,我们可以通过使用 compile() 来大大减少这种开销。为此,让我们比较一下此代码在使用 compile 和不使用 compile 时的三种版本

class ResidualBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear0 = nn.Linear(3, 128)
        self.relu0 = nn.ReLU()
        self.linear1 = nn.Linear(128, 128)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(128, 3)

    def forward(self, x):
        y = self.linear0(x)
        y = self.relu0(y)
        y = self.linear1(y)
        y = self.relu1(y)
        return self.linear2(y) + x


print("Without compile")
x = torch.randn(256, 3)
block_notd = ResidualBlock()
block_tdm = TensorDictModule(block_notd, in_keys=["x"], out_keys=["y"])
block_tds = block

from torch.utils.benchmark import Timer

print(
    f"Regular: {Timer('block_notd(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
print(
    f"TDM: {Timer('block_tdm(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
print(
    f"Sequential: {Timer('block_tds(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)

print("Compiled versions")
block_notd_c = torch.compile(block_notd, mode="reduce-overhead")
for _ in range(5):  # warmup
    block_notd_c(x)
print(
    f"Compiled regular: {Timer('block_notd_c(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
block_tdm_c = torch.compile(block_tdm, mode="reduce-overhead")
for _ in range(5):  # warmup
    block_tdm_c(x=x)
print(
    f"Compiled TDM: {Timer('block_tdm_c(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
block_tds_c = torch.compile(block_tds, mode="reduce-overhead")
for _ in range(5):  # warmup
    block_tds_c(x=x)
print(
    f"Compiled sequential: {Timer('block_tds_c(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
Without compile
Regular:  215.9519 us
TDM:  276.4528 us
Sequential:  486.3226 us
Compiled versions
Compiled regular:  327.3375 us
Compiled TDM:  370.3780 us
Compiled sequential:  382.0440 us

正如大家所见,TensorDictSequential 引入的开销已得到完全解决。

使用 TensorDictModule 的注意事项

  • 不要在来自 tensordict.nn 的模块周围使用 Sequence。这会破坏输入/输出键结构。始终尝试依赖 nn:TensorDictSequential

  • 不要将输出 tensordict 赋值给一个新变量,因为输出 tensordict 只是就地修改的输入。赋值一个新变量名并非严格禁止,但这意味着当一个被删除时,您可能希望两者都消失,而实际上垃圾回收器仍然会看到工作区中的张量,并且不会释放内存

    >>> tensordict = module(tensordict)  # ok!
    >>> tensordict_out = module(tensordict)  # don't!
    

处理分布:ProbabilisticTensorDictModule

ProbabilisticTensorDictModule 是一个表示概率分布的非参数模块。分布参数从 tensordict 输入中读取,输出写入输出 tensordict。根据输入 default_interaction_type 参数和 interaction_type() 全局函数指定的规则对输出进行采样。如果它们冲突,上下文管理器优先。

它可以与使用 ProbabilisticTensorDictSequential 更新了分布参数的 TensorDictModule 结合使用。这是 TensorDictSequential 的一个特例,其最后一层是一个 ProbabilisticTensorDictModule 实例。

ProbabilisticTensorDictModule 负责构建分布(通过 get_dist() 方法)和/或从该分布中进行采样(通过对模块进行常规的 forward 调用)。相同的 get_dist() 方法也在 ProbabilisticTensorDictSequential 中公开。

如果需要,可以在输出 tensordict 中找到参数以及对数概率。

from tensordict.nn import (
    ProbabilisticTensorDictModule,
    ProbabilisticTensorDictSequential,
)
from tensordict.nn.distributions import NormalParamExtractor
from torch import distributions as dist

td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3])
net = torch.nn.GRUCell(4, 8)
net = TensorDictModule(net, in_keys=["input", "hidden"], out_keys=["hidden"])
extractor = NormalParamExtractor()
extractor = TensorDictModule(extractor, in_keys=["hidden"], out_keys=["loc", "scale"])
td_module = ProbabilisticTensorDictSequential(
    net,
    extractor,
    ProbabilisticTensorDictModule(
        in_keys=["loc", "scale"],
        out_keys=["action"],
        distribution_class=dist.Normal,
        return_log_prob=True,
    ),
)
print(f"TensorDict before going through module: {td}")
td_module(td)
print(f"TensorDict after going through module now as keys action, loc and scale: {td}")
TensorDict before going through module: TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
TensorDict after going through module now as keys action, loc and scale: TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)

结论

我们已经看到 tensordict.nn 如何用于动态地即时构建复杂的神经网络架构。这开启了构建对模型签名无感知的管道的可能性,也就是说,可以编写通用的代码,以灵活的方式使用具有任意数量输入或输出的网络。

我们还看到 dispatch 如何使得能够使用 tensordict.nn 构建此类网络并使用它们,而无需直接使用 TensorDict。得益于 compile()tensordict.nn.TensorDictSequential 引入的开销可以被完全消除,从而为用户提供了一个整洁的、无需 tensordict 的模块版本。

在下一个教程中,我们将看到如何使用 torch.export 来隔离模块并将其导出。

脚本总运行时间: (0 分钟 16.867 秒)

图库由 Sphinx-Gallery 生成

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

获取针对初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得疑问解答

查看资源