快捷方式

TensorDictModule

作者: Nicolas Dufour, Vincent Moens

在本教程中,您将学习如何使用 TensorDictModuleTensorDictSequential 创建通用的、可重用的模块,这些模块可以接受 TensorDict 作为输入。

为了方便 TensorDict 类与 Module 一起使用,tensordict 提供了两者之间的接口,名为 TensorDictModule

TensorDictModule 类是一个 Module,它在被调用时将 TensorDict 作为输入。它将读取一系列输入键,将它们作为输入传递给包装的模块或函数,并在完成执行后将输出写入同一个 tensordict 中。

由用户来定义要读取作为输入和输出的键。

import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential

简单示例:编码循环层

TensorDictModule 最简单的用法示例如下。如果乍一看,使用此类似乎引入了不必要的复杂性,我们稍后会看到,此 API 使​​用户可以编程方式将模块连接在一起,在模块之间缓存值或以编程方式构建模块。其中一个最简单的例子是 ResNet 等架构中的循环模块,其中模块的输入被缓存并添加到小型多层感知器 (MLP) 的输出中。

首先,让我们考虑一下您将如何分块 MLP,并使用 tensordict.nn 对其进行编码。堆栈的第一层可能是 Linear 层,它接受一个条目作为输入(我们将其命名为 x)并输出另一个条目(我们将其命名为 y)。

为了馈送到我们的模块,我们有一个 TensorDict 实例,其中包含一个条目,"x"

tensordict = TensorDict(
    x=torch.randn(5, 3),
    batch_size=[5],
)

现在,我们使用 tensordict.nn.TensorDictModule 构建我们的简单模块。默认情况下,此类在输入 tensordict 中就地写入(意味着条目写入与输入相同的 tensordict 中,而不是条目被就地覆盖!),因此我们无需显式指示输出是什么

linear0 = TensorDictModule(nn.Linear(3, 128), in_keys=["x"], out_keys=["linear0"])
linear0(tensordict)

assert "linear0" in tensordict

如果模块输出多个张量(或 tensordict!),则必须以正确的顺序将它们的条目传递给 TensorDictModule

对可调用对象的支持

在设计模型时,经常会发生您希望将任意非参数函数合并到网络中的情况。例如,您可能希望在图像传递到卷积网络或视觉 Transformer 时置换图像的维度,或者将值除以 255。有几种方法可以做到这一点:例如,您可以使用 forward_hook,或者设计一个新的 Module 来执行此操作。

TensorDictModule 可以与任何可调用对象一起使用,而不仅仅是模块,这使得将任意函数合并到模块中变得容易。例如,让我们看看如何在不使用 ReLU 模块的情况下集成 relu 激活函数

relu0 = TensorDictModule(torch.relu, in_keys=["linear0"], out_keys=["relu0"])

堆叠模块

我们的 MLP 不是由单层组成的,因此我们现在需要向其添加另一层。这一层将是一个激活函数,例如 ReLU。我们可以使用 TensorDictSequential 堆叠此模块和上一个模块。

注意

这是 tensordict.nn 的真正力量所在:与 Sequential 不同,TensorDictSequential 将在内存中保留所有先前的输入和输出(并有可能在之后过滤掉它们),从而可以轻松地动态且以编程方式构建复杂的网络结构。

block0 = TensorDictSequential(linear0, relu0)

block0(tensordict)
assert "linear0" in tensordict
assert "relu0" in tensordict

我们可以重复此逻辑以获得完整的 MLP

linear1 = TensorDictModule(nn.Linear(128, 128), in_keys=["relu0"], out_keys=["linear1"])
relu1 = TensorDictModule(nn.ReLU(), in_keys=["linear1"], out_keys=["relu1"])
linear2 = TensorDictModule(nn.Linear(128, 3), in_keys=["relu1"], out_keys=["linear2"])
block1 = TensorDictSequential(linear1, relu1, linear2)

多个输入键

残差网络的最后一步是将输入添加到最后一个线性层的输出。无需为此编写特殊的 Module 子类!TensorDictModule 也可以用于包装简单的函数

residual = TensorDictModule(
    lambda x, y: x + y, in_keys=["x", "linear2"], out_keys=["y"]
)

现在我们可以将 block0block1residual 放在一起,构成一个完整的残差块

block = TensorDictSequential(block0, block1, residual)
block(tensordict)
assert "y" in tensordict

一个真正的担忧可能是用作输入的 tensordict 中条目的累积:在某些情况下(例如,当需要梯度时),无论如何都可能缓存中间值,但这并非总是如此,并且让垃圾收集器知道可以丢弃某些条目可能很有用。tensordict.nn.TensorDictModuleBase 及其子类(包括 tensordict.nn.TensorDictModuletensordict.nn.TensorDictSequential)可以选择在执行后过滤其输出键。为此,只需调用 tensordict.nn.TensorDictModuleBase.select_out_keys 方法。这将就地更新模块,并且所有不需要的条目都将被丢弃

block.select_out_keys("y")

tensordict = TensorDict(x=torch.randn(1, 3), batch_size=[1])
block(tensordict)
assert "y" in tensordict

assert "linear1" not in tensordict

但是,输入键会被保留

assert "x" in tensordict

作为旁注,selected_out_keys 也可以传递给 TensorDictSequential,以避免单独调用此方法。

在没有 tensordict 的情况下使用 TensorDictModule

TensorDictSequential 提供的随时构建复杂架构的机会并不意味着必须切换到 tensordict 来表示数据。感谢 dispatch,来自 tensordict.nn 的模块也支持与条目名称匹配的参数和关键字参数

x = torch.randn(1, 3)
y = block(x=x)
assert isinstance(y, torch.Tensor)

在底层,dispatch 会重建 tensordict,运行模块,然后解构它。这可能会导致一些开销,但正如我们稍后将看到的,有一种解决方案可以消除这种开销。

运行时

tensordict.nn.TensorDictModuletensordict.nn.TensorDictSequential 在执行时确实会产生一些开销,因为它们需要从 tensordict 中读取和写入。但是,我们可以通过使用 compile() 大大减少这种开销。为此,让我们比较一下使用和不使用 compile 的这三种代码版本

class ResidualBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear0 = nn.Linear(3, 128)
        self.relu0 = nn.ReLU()
        self.linear1 = nn.Linear(128, 128)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(128, 3)

    def forward(self, x):
        y = self.linear0(x)
        y = self.relu0(y)
        y = self.linear1(y)
        y = self.relu1(y)
        return self.linear2(y) + x


print("Without compile")
x = torch.randn(256, 3)
block_notd = ResidualBlock()
block_tdm = TensorDictModule(block_notd, in_keys=["x"], out_keys=["y"])
block_tds = block

from torch.utils.benchmark import Timer

print(
    f"Regular: {Timer('block_notd(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
print(
    f"TDM: {Timer('block_tdm(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
print(
    f"Sequential: {Timer('block_tds(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)

print("Compiled versions")
block_notd_c = torch.compile(block_notd, mode="reduce-overhead")
for _ in range(5):  # warmup
    block_notd_c(x)
print(
    f"Compiled regular: {Timer('block_notd_c(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
block_tdm_c = torch.compile(block_tdm, mode="reduce-overhead")
for _ in range(5):  # warmup
    block_tdm_c(x=x)
print(
    f"Compiled TDM: {Timer('block_tdm_c(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
block_tds_c = torch.compile(block_tds, mode="reduce-overhead")
for _ in range(5):  # warmup
    block_tds_c(x=x)
print(
    f"Compiled sequential: {Timer('block_tds_c(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
Without compile
Regular:  219.5165 us
TDM:  260.3091 us
Sequential:  375.0590 us
Compiled versions
Compiled regular:  326.0555 us
Compiled TDM:  333.1850 us
Compiled sequential:  342.4750 us

可以看出,TensorDictSequential 引入的开销已完全解决。

TensorDictModule 的注意事项

  • 不要在来自 tensordict.nn 的模块周围使用 Sequence。它会破坏输入/输出键结构。始终尝试依赖 nn:TensorDictSequential 代替。

  • 不要将输出 tensordict 分配给新变量,因为输出 tensordict 只是就地修改的输入。分配新变量名并非严格禁止,但这表示您可能希望在删除其中一个变量时两者都消失,但实际上垃圾收集器仍会在工作区中看到张量,并且不会释放任何内存

    >>> tensordict = module(tensordict)  # ok!
    >>> tensordict_out = module(tensordict)  # don't!
    

使用分布:ProbabilisticTensorDictModule

ProbabilisticTensorDictModule 是一个非参数模块,表示概率分布。分布参数从 tensordict 输入中读取,输出写入输出 tensordict。根据某些规则(由输入 default_interaction_type 参数和 interaction_type() 全局函数指定)对输出进行采样。如果它们发生冲突,则上下文管理器优先。

它可以与 TensorDictModule 结合使用,后者返回使用分布参数更新的 tensordict,使用 ProbabilisticTensorDictSequential。这是 TensorDictSequential 的一个特例,其最后一层是 ProbabilisticTensorDictModule 实例。

ProbabilisticTensorDictModule 负责构建分布(通过 get_dist() 方法)和/或从此分布中采样(通过对模块进行常规 forward 调用)。相同的 get_dist() 方法在 ProbabilisticTensorDictSequential 中公开。

可以在输出 tensordict 中找到参数,以及在需要时找到对数概率。

from tensordict.nn import (
    ProbabilisticTensorDictModule,
    ProbabilisticTensorDictSequential,
)
from tensordict.nn.distributions import NormalParamExtractor
from torch import distributions as dist

td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3])
net = torch.nn.GRUCell(4, 8)
net = TensorDictModule(net, in_keys=["input", "hidden"], out_keys=["hidden"])
extractor = NormalParamExtractor()
extractor = TensorDictModule(extractor, in_keys=["hidden"], out_keys=["loc", "scale"])
td_module = ProbabilisticTensorDictSequential(
    net,
    extractor,
    ProbabilisticTensorDictModule(
        in_keys=["loc", "scale"],
        out_keys=["action"],
        distribution_class=dist.Normal,
        return_log_prob=True,
    ),
)
print(f"TensorDict before going through module: {td}")
td_module(td)
print(f"TensorDict after going through module now as keys action, loc and scale: {td}")
TensorDict before going through module: TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
TensorDict after going through module now as keys action, loc and scale: TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)

结论

我们已经了解了如何使用 tensordict.nn 动态地随时构建复杂的神经架构。这为构建与模型签名无关的管道开辟了可能性,即以灵活的方式编写通用代码,这些代码使用具有任意数量输入或输出的网络。

我们还了解了 dispatch 如何使​​用 tensordict.nn 构建此类网络,并在不直接使用 TensorDict 的情况下使用它们。感谢 compile()tensordict.nn.TensorDictSequential 引入的开销可以完全消除,为用户留下一个简洁的、无 tensordict 版本的模块。

在下一个教程中,我们将看到如何使用 torch.export 来隔离模块并导出它。

脚本的总运行时间: (0 分 18.375 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得解答

查看资源