注意
转到结尾 下载完整示例代码。
TensorDictModule 的功能化¶
在本教程中,您将学习如何将 TensorDictModule
与 functorch 结合使用以创建功能化的模块。
在我们查看 tensordict.nn
中的功能化实用程序之前,让我们重新介绍 TensorDictModule
教程中的示例模块之一。
我们将创建一个简单的模块,该模块具有两个线性层,它们共享输入并返回单独的输出。
import functorch
import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
class MultiHeadLinear(nn.Module):
def __init__(self, in_1, out_1, out_2):
super().__init__()
self.linear_1 = nn.Linear(in_1, out_1)
self.linear_2 = nn.Linear(in_1, out_2)
def forward(self, x):
return self.linear_1(x), self.linear_2(x)
现在,我们可以创建一个 TensorDictModule
,它将从键 "a"
读取输入,并写入键 "output_1"
和 "output_2"
。
splitlinear = TensorDictModule(
MultiHeadLinear(3, 4, 10), in_keys=["a"], out_keys=["output_1", "output_2"]
)
通常,我们会通过简单地对具有所需输入键的 TensorDict
调用该模块来使用它。
tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5])
splitlinear(tensordict)
print(tensordict)
TensorDict(
fields={
a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
output_1: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
output_2: Tensor(shape=torch.Size([5, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([5]),
device=None,
is_shared=False)
但是,我们也可以使用 functorch.make_functional_with_buffers()
来功能化该模块。
TensorDict(
fields={
a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
output_1: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
output_2: Tensor(shape=torch.Size([5, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([5]),
device=None,
is_shared=False)
这可以与 vmap 运算符一起使用。例如,我们使用参数和缓冲区的 3 个副本,并对这些副本执行向量化映射以处理一批数据。
params_expand = [p.expand(3, *p.shape) for p in params]
buffers_expand = [p.expand(3, *p.shape) for p in buffers]
print(torch.vmap(func, (0, 0, None))(params_expand, buffers_expand, tensordict))
TensorDict(
fields={
a: Tensor(shape=torch.Size([3, 5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
output_1: Tensor(shape=torch.Size([3, 5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
output_2: Tensor(shape=torch.Size([3, 5, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3, 5]),
device=None,
is_shared=False)
我们还可以使用来自 tensordict.nn`
的原生 make_functional
函数,它会修改该模块使其接受参数作为常规输入。
from tensordict.nn import make_functional
tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5])
num_models = 10
model = TensorDictModule(nn.Linear(3, 4), in_keys=["a"], out_keys=["output"])
params = make_functional(model)
# we stack two groups of parameters to show the vmap usage:
params = torch.stack([params, params.apply(lambda x: torch.zeros_like(x))], 0)
result_td = torch.vmap(model, (None, 0))(tensordict, params)
print("the output tensordict shape is: ", result_td.shape)
the output tensordict shape is: torch.Size([2, 5])
脚本总运行时间: (0 分钟 0.006 秒)