注意
转到结尾 下载完整的示例代码。
TensorDictModule¶
作者: Nicolas Dufour, Vincent Moens
在本教程中,您将学习如何使用 TensorDictModule
和 TensorDictSequential
创建通用的、可重用的模块,这些模块可以接受 TensorDict
作为输入。
为了方便 TensorDict
类与 Module
一起使用,tensordict
提供了两者之间的接口,名为 TensorDictModule
。
TensorDictModule
类是一个 Module
,它在被调用时将 TensorDict
作为输入。它将读取一系列输入键,将它们作为输入传递给包装的模块或函数,并在完成执行后将输出写入同一个 tensordict 中。
由用户来定义要读取作为输入和输出的键。
import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential
简单示例:编码循环层¶
TensorDictModule
最简单的用法示例如下。如果乍一看,使用此类似乎引入了不必要的复杂性,我们稍后会看到,此 API 使用户可以编程方式将模块连接在一起,在模块之间缓存值或以编程方式构建模块。其中一个最简单的例子是 ResNet 等架构中的循环模块,其中模块的输入被缓存并添加到小型多层感知器 (MLP) 的输出中。
首先,让我们考虑一下您将如何分块 MLP,并使用 tensordict.nn
对其进行编码。堆栈的第一层可能是 Linear
层,它接受一个条目作为输入(我们将其命名为 x)并输出另一个条目(我们将其命名为 y)。
为了馈送到我们的模块,我们有一个 TensorDict
实例,其中包含一个条目,"x"
tensordict = TensorDict(
x=torch.randn(5, 3),
batch_size=[5],
)
现在,我们使用 tensordict.nn.TensorDictModule
构建我们的简单模块。默认情况下,此类在输入 tensordict 中就地写入(意味着条目写入与输入相同的 tensordict 中,而不是条目被就地覆盖!),因此我们无需显式指示输出是什么
linear0 = TensorDictModule(nn.Linear(3, 128), in_keys=["x"], out_keys=["linear0"])
linear0(tensordict)
assert "linear0" in tensordict
如果模块输出多个张量(或 tensordict!),则必须以正确的顺序将它们的条目传递给 TensorDictModule
。
对可调用对象的支持¶
在设计模型时,经常会发生您希望将任意非参数函数合并到网络中的情况。例如,您可能希望在图像传递到卷积网络或视觉 Transformer 时置换图像的维度,或者将值除以 255。有几种方法可以做到这一点:例如,您可以使用 forward_hook,或者设计一个新的 Module
来执行此操作。
TensorDictModule
可以与任何可调用对象一起使用,而不仅仅是模块,这使得将任意函数合并到模块中变得容易。例如,让我们看看如何在不使用 ReLU
模块的情况下集成 relu
激活函数
relu0 = TensorDictModule(torch.relu, in_keys=["linear0"], out_keys=["relu0"])
堆叠模块¶
我们的 MLP 不是由单层组成的,因此我们现在需要向其添加另一层。这一层将是一个激活函数,例如 ReLU
。我们可以使用 TensorDictSequential
堆叠此模块和上一个模块。
注意
这是 tensordict.nn
的真正力量所在:与 Sequential
不同,TensorDictSequential
将在内存中保留所有先前的输入和输出(并有可能在之后过滤掉它们),从而可以轻松地动态且以编程方式构建复杂的网络结构。
block0 = TensorDictSequential(linear0, relu0)
block0(tensordict)
assert "linear0" in tensordict
assert "relu0" in tensordict
我们可以重复此逻辑以获得完整的 MLP
linear1 = TensorDictModule(nn.Linear(128, 128), in_keys=["relu0"], out_keys=["linear1"])
relu1 = TensorDictModule(nn.ReLU(), in_keys=["linear1"], out_keys=["relu1"])
linear2 = TensorDictModule(nn.Linear(128, 3), in_keys=["relu1"], out_keys=["linear2"])
block1 = TensorDictSequential(linear1, relu1, linear2)
多个输入键¶
残差网络的最后一步是将输入添加到最后一个线性层的输出。无需为此编写特殊的 Module
子类!TensorDictModule
也可以用于包装简单的函数
residual = TensorDictModule(
lambda x, y: x + y, in_keys=["x", "linear2"], out_keys=["y"]
)
现在我们可以将 block0
、block1
和 residual
放在一起,构成一个完整的残差块
block = TensorDictSequential(block0, block1, residual)
block(tensordict)
assert "y" in tensordict
一个真正的担忧可能是用作输入的 tensordict 中条目的累积:在某些情况下(例如,当需要梯度时),无论如何都可能缓存中间值,但这并非总是如此,并且让垃圾收集器知道可以丢弃某些条目可能很有用。tensordict.nn.TensorDictModuleBase
及其子类(包括 tensordict.nn.TensorDictModule
和 tensordict.nn.TensorDictSequential
)可以选择在执行后过滤其输出键。为此,只需调用 tensordict.nn.TensorDictModuleBase.select_out_keys
方法。这将就地更新模块,并且所有不需要的条目都将被丢弃
block.select_out_keys("y")
tensordict = TensorDict(x=torch.randn(1, 3), batch_size=[1])
block(tensordict)
assert "y" in tensordict
assert "linear1" not in tensordict
但是,输入键会被保留
assert "x" in tensordict
作为旁注,selected_out_keys
也可以传递给 TensorDictSequential
,以避免单独调用此方法。
在没有 tensordict 的情况下使用 TensorDictModule¶
TensorDictSequential
提供的随时构建复杂架构的机会并不意味着必须切换到 tensordict 来表示数据。感谢 dispatch
,来自 tensordict.nn 的模块也支持与条目名称匹配的参数和关键字参数
x = torch.randn(1, 3)
y = block(x=x)
assert isinstance(y, torch.Tensor)
在底层,dispatch
会重建 tensordict,运行模块,然后解构它。这可能会导致一些开销,但正如我们稍后将看到的,有一种解决方案可以消除这种开销。
运行时¶
tensordict.nn.TensorDictModule
和 tensordict.nn.TensorDictSequential
在执行时确实会产生一些开销,因为它们需要从 tensordict 中读取和写入。但是,我们可以通过使用 compile()
大大减少这种开销。为此,让我们比较一下使用和不使用 compile 的这三种代码版本
class ResidualBlock(nn.Module):
def __init__(self):
super().__init__()
self.linear0 = nn.Linear(3, 128)
self.relu0 = nn.ReLU()
self.linear1 = nn.Linear(128, 128)
self.relu1 = nn.ReLU()
self.linear2 = nn.Linear(128, 3)
def forward(self, x):
y = self.linear0(x)
y = self.relu0(y)
y = self.linear1(y)
y = self.relu1(y)
return self.linear2(y) + x
print("Without compile")
x = torch.randn(256, 3)
block_notd = ResidualBlock()
block_tdm = TensorDictModule(block_notd, in_keys=["x"], out_keys=["y"])
block_tds = block
from torch.utils.benchmark import Timer
print(
f"Regular: {Timer('block_notd(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
print(
f"TDM: {Timer('block_tdm(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
print(
f"Sequential: {Timer('block_tds(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
print("Compiled versions")
block_notd_c = torch.compile(block_notd, mode="reduce-overhead")
for _ in range(5): # warmup
block_notd_c(x)
print(
f"Compiled regular: {Timer('block_notd_c(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
block_tdm_c = torch.compile(block_tdm, mode="reduce-overhead")
for _ in range(5): # warmup
block_tdm_c(x=x)
print(
f"Compiled TDM: {Timer('block_tdm_c(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
block_tds_c = torch.compile(block_tds, mode="reduce-overhead")
for _ in range(5): # warmup
block_tds_c(x=x)
print(
f"Compiled sequential: {Timer('block_tds_c(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us"
)
Without compile
Regular: 219.5165 us
TDM: 260.3091 us
Sequential: 375.0590 us
Compiled versions
Compiled regular: 326.0555 us
Compiled TDM: 333.1850 us
Compiled sequential: 342.4750 us
可以看出,TensorDictSequential
引入的开销已完全解决。
TensorDictModule 的注意事项¶
不要在来自
tensordict.nn
的模块周围使用Sequence
。它会破坏输入/输出键结构。始终尝试依赖nn:TensorDictSequential
代替。不要将输出 tensordict 分配给新变量,因为输出 tensordict 只是就地修改的输入。分配新变量名并非严格禁止,但这表示您可能希望在删除其中一个变量时两者都消失,但实际上垃圾收集器仍会在工作区中看到张量,并且不会释放任何内存
>>> tensordict = module(tensordict) # ok! >>> tensordict_out = module(tensordict) # don't!
使用分布:ProbabilisticTensorDictModule
¶
ProbabilisticTensorDictModule
是一个非参数模块,表示概率分布。分布参数从 tensordict 输入中读取,输出写入输出 tensordict。根据某些规则(由输入 default_interaction_type
参数和 interaction_type()
全局函数指定)对输出进行采样。如果它们发生冲突,则上下文管理器优先。
它可以与 TensorDictModule
结合使用,后者返回使用分布参数更新的 tensordict,使用 ProbabilisticTensorDictSequential
。这是 TensorDictSequential
的一个特例,其最后一层是 ProbabilisticTensorDictModule
实例。
ProbabilisticTensorDictModule
负责构建分布(通过 get_dist()
方法)和/或从此分布中采样(通过对模块进行常规 forward 调用)。相同的 get_dist()
方法在 ProbabilisticTensorDictSequential
中公开。
可以在输出 tensordict 中找到参数,以及在需要时找到对数概率。
from tensordict.nn import (
ProbabilisticTensorDictModule,
ProbabilisticTensorDictSequential,
)
from tensordict.nn.distributions import NormalParamExtractor
from torch import distributions as dist
td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3])
net = torch.nn.GRUCell(4, 8)
net = TensorDictModule(net, in_keys=["input", "hidden"], out_keys=["hidden"])
extractor = NormalParamExtractor()
extractor = TensorDictModule(extractor, in_keys=["hidden"], out_keys=["loc", "scale"])
td_module = ProbabilisticTensorDictSequential(
net,
extractor,
ProbabilisticTensorDictModule(
in_keys=["loc", "scale"],
out_keys=["action"],
distribution_class=dist.Normal,
return_log_prob=True,
),
)
print(f"TensorDict before going through module: {td}")
td_module(td)
print(f"TensorDict after going through module now as keys action, loc and scale: {td}")
TensorDict before going through module: TensorDict(
fields={
hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
TensorDict after going through module now as keys action, loc and scale: TensorDict(
fields={
action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
sample_log_prob: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
结论¶
我们已经了解了如何使用 tensordict.nn 动态地随时构建复杂的神经架构。这为构建与模型签名无关的管道开辟了可能性,即以灵活的方式编写通用代码,这些代码使用具有任意数量输入或输出的网络。
我们还了解了 dispatch
如何使用 tensordict.nn 构建此类网络,并在不直接使用 TensorDict
的情况下使用它们。感谢 compile()
,tensordict.nn.TensorDictSequential
引入的开销可以完全消除,为用户留下一个简洁的、无 tensordict 版本的模块。
在下一个教程中,我们将看到如何使用 torch.export
来隔离模块并导出它。
脚本的总运行时间: (0 分 18.375 秒)