tensordict.nn 包¶
tensordict.nn 包使得在 ML 管道中灵活使用 TensorDict 成为可能。
由于 TensorDict 将代码的部分转换为基于键的结构,现在可以使用这些键作为钩子来构建复杂的图结构。基本构建块是 TensorDictModule
,它使用输入和输出键的列表包装 torch.nn.Module
实例
>>> from torch.nn import Transformer
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> import torch
>>> module = TensorDictModule(Transformer(), in_keys=["feature", "target"], out_keys=["prediction"])
>>> data = TensorDict({"feature": torch.randn(10, 11, 512), "target": torch.randn(10, 11, 512)}, [10, 11])
>>> data = module(data)
>>> print(data)
TensorDict(
fields={
feature: Tensor(torch.Size([10, 11, 512]), dtype=torch.float32),
prediction: Tensor(torch.Size([10, 11, 512]), dtype=torch.float32),
target: Tensor(torch.Size([10, 11, 512]), dtype=torch.float32)},
batch_size=torch.Size([10, 11]),
device=None,
is_shared=False)
不一定需要使用 TensorDictModule
,自定义的 torch.nn.Module
以及有序的输入和输出键列表(命名为 module.in_keys
和 module.out_keys
)就足够了。
多个 PyTorch 用户的痛点是 nn.Sequential 无法处理具有多个输入的模块。使用基于键的图可以轻松解决该问题,因为序列中的每个节点都知道需要读取哪些数据以及将数据写入何处。
为此,我们提供了 TensorDictSequential 类,它通过 TensorDictModules 序列传递数据。序列中的每个模块都从原始 TensorDict 中获取输入,并将输出写入其中,这意味着序列中的模块可以忽略其前任的输出,或者根据需要从 tensordict 中获取额外的输入。这是一个例子
>>> from tensordict.nn import TensorDictSequential
>>> class Net(nn.Module):
... def __init__(self, input_size=100, hidden_size=50, output_size=10):
... super().__init__()
... self.fc1 = nn.Linear(input_size, hidden_size)
... self.fc2 = nn.Linear(hidden_size, output_size)
...
... def forward(self, x):
... x = torch.relu(self.fc1(x))
... return self.fc2(x)
...
>>> class Masker(nn.Module):
... def forward(self, x, mask):
... return torch.softmax(x * mask, dim=1)
...
>>> net = TensorDictModule(
... Net(), in_keys=[("input", "x")], out_keys=[("intermediate", "x")]
... )
>>> masker = TensorDictModule(
... Masker(),
... in_keys=[("intermediate", "x"), ("input", "mask")],
... out_keys=[("output", "probabilities")],
... )
>>> module = TensorDictSequential(net, masker)
>>>
>>> td = TensorDict(
... {
... "input": TensorDict(
... {"x": torch.rand(32, 100), "mask": torch.randint(2, size=(32, 10))},
... batch_size=[32],
... )
... },
... batch_size=[32],
... )
>>> td = module(td)
>>> print(td)
TensorDict(
fields={
input: TensorDict(
fields={
mask: Tensor(torch.Size([32, 10]), dtype=torch.int64),
x: Tensor(torch.Size([32, 100]), dtype=torch.float32)},
batch_size=torch.Size([32]),
device=None,
is_shared=False),
intermediate: TensorDict(
fields={
x: Tensor(torch.Size([32, 10]), dtype=torch.float32)},
batch_size=torch.Size([32]),
device=None,
is_shared=False),
output: TensorDict(
fields={
probabilities: Tensor(torch.Size([32, 10]), dtype=torch.float32)},
batch_size=torch.Size([32]),
device=None,
is_shared=False)},
batch_size=torch.Size([32]),
device=None,
is_shared=False)
我们还可以通过 select_subsequence()
方法轻松选择子图
>>> sub_module = module.select_subsequence(out_keys=[("intermediate", "x")])
>>> td = TensorDict(
... {
... "input": TensorDict(
... {"x": torch.rand(32, 100), "mask": torch.randint(2, size=(32, 10))},
... batch_size=[32],
... )
... },
... batch_size=[32],
... )
>>> sub_module(td)
>>> print(td) # the "output" has not been computed
TensorDict(
fields={
input: TensorDict(
fields={
mask: Tensor(torch.Size([32, 10]), dtype=torch.int64),
x: Tensor(torch.Size([32, 100]), dtype=torch.float32)},
batch_size=torch.Size([32]),
device=None,
is_shared=False),
intermediate: TensorDict(
fields={
x: Tensor(torch.Size([32, 10]), dtype=torch.float32)},
batch_size=torch.Size([32]),
device=None,
is_shared=False)},
batch_size=torch.Size([32]),
device=None,
is_shared=False)
最后,tensordict.nn
附带了一个 ProbabilisticTensorDictModule
,允许从网络输出构建分布,并从中获取汇总统计信息或样本(以及分布参数)
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from tensordict.nn.distributions import NormalParamWrapper
>>> from tensordict.nn.prototype import (
... ProbabilisticTensorDictModule,
... ProbabilisticTensorDictSequential,
... )
>>> from torch.distributions import Normal
>>> td = TensorDict(
... {"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3]
... )
>>> net = torch.nn.GRUCell(4, 8)
>>> module = TensorDictModule(
... NormalParamWrapper(net), in_keys=["input", "hidden"], out_keys=["loc", "scale"]
... )
>>> prob_module = ProbabilisticTensorDictModule(
... in_keys=["loc", "scale"],
... out_keys=["sample"],
... distribution_class=Normal,
... return_log_prob=True,
... )
>>> td_module = ProbabilisticTensorDictSequential(module, prob_module)
>>> td_module(td)
>>> print(td)
TensorDict(
fields={
action: Tensor(torch.Size([3, 4]), dtype=torch.float32),
hidden: Tensor(torch.Size([3, 8]), dtype=torch.float32),
input: Tensor(torch.Size([3, 4]), dtype=torch.float32),
loc: Tensor(torch.Size([3, 4]), dtype=torch.float32),
sample_log_prob: Tensor(torch.Size([3, 4]), dtype=torch.float32),
scale: Tensor(torch.Size([3, 4]), dtype=torch.float32)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
|
TensorDict 模块的基类。 |
|
TensorDictModule 是 |
|
概率 TD 模块。 |
|
TensorDictModules 的序列。 |
|
TensorDictModule 对象的包装器类。 |
|
PyTorch 可调用对象的 cudagraph 包装器。 |
集成¶
函数式方法支持直接的集成实现。我们可以使用 tensordict.nn.EnsembleModule
复制和重新初始化模型副本
>>> import torch
>>> from torch import nn
>>> from tensordict.nn import TensorDictModule
>>> from torchrl.modules import EnsembleModule
>>> from tensordict import TensorDict
>>> net = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))
>>> mod = TensorDictModule(net, in_keys=['a'], out_keys=['b'])
>>> ensemble = EnsembleModule(mod, num_copies=3)
>>> data = TensorDict({'a': torch.randn(10, 4)}, batch_size=[10])
>>> ensemble(data)
TensorDict(
fields={
a: Tensor(shape=torch.Size([3, 10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([3, 10, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3, 10]),
device=None,
is_shared=False)
|
包装模块并重复它以形成集成的模块。 |
编译 TensorDictModules¶
自 v0.5 起,TensorDict 组件与 compile()
兼容。例如,TensorDictSequential
模块可以使用 torch.compile
进行编译,并达到类似于包装在 TensorDictModule
中的常规 PyTorch 模块的运行时性能。
分布¶
|
一个非参数化的 nn.Module,它将其输入拆分为 loc 和 scale 参数。 |
一个 nn.Module,它添加可训练的与状态无关的 scale 参数。 |
|
|
分布的组合。 |
|
Delta 分布。 |
|
One-hot 分类分布。 |
|
截断正态分布。 |
实用工具¶
|
返回从关键字参数或输入字典创建的 TensorDict。 |
|
允许使用 kwargs 调用期望 TensorDict 的函数。 |
|
将所有 ProbabilisticTDModules 的采样设置为所需的类型。 |
|
反向 softplus 函数。 |
|
有偏 softplus 模块。 |
|
用于跳过 TensorDict 图中现有节点的上下文管理器。 |
返回是否应由模块重新计算 tensordict 中的现有条目。 |