• 文档 >
  • TensorDictModule 的功能化
快捷键

TensorDictModule 的功能化

在本教程中,您将学习如何将 TensorDictModule 与 functorch 结合使用以创建功能化的模块。

在我们查看 tensordict.nn 中的功能化实用程序之前,让我们重新介绍 TensorDictModule 教程中的示例模块之一。

我们将创建一个简单的模块,该模块具有两个线性层,它们共享输入并返回单独的输出。

import functorch
import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule


class MultiHeadLinear(nn.Module):
    def __init__(self, in_1, out_1, out_2):
        super().__init__()
        self.linear_1 = nn.Linear(in_1, out_1)
        self.linear_2 = nn.Linear(in_1, out_2)

    def forward(self, x):
        return self.linear_1(x), self.linear_2(x)

现在,我们可以创建一个 TensorDictModule,它将从键 "a" 读取输入,并写入键 "output_1""output_2"

splitlinear = TensorDictModule(
    MultiHeadLinear(3, 4, 10), in_keys=["a"], out_keys=["output_1", "output_2"]
)

通常,我们会通过简单地对具有所需输入键的 TensorDict 调用该模块来使用它。

tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5])
splitlinear(tensordict)
print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        output_1: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        output_2: Tensor(shape=torch.Size([5, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([5]),
    device=None,
    is_shared=False)

但是,我们也可以使用 functorch.make_functional_with_buffers() 来功能化该模块。

func, params, buffers = functorch.make_functional_with_buffers(splitlinear)
print(func(params, buffers, tensordict))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        output_1: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        output_2: Tensor(shape=torch.Size([5, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([5]),
    device=None,
    is_shared=False)

这可以与 vmap 运算符一起使用。例如,我们使用参数和缓冲区的 3 个副本,并对这些副本执行向量化映射以处理一批数据。

params_expand = [p.expand(3, *p.shape) for p in params]
buffers_expand = [p.expand(3, *p.shape) for p in buffers]
print(torch.vmap(func, (0, 0, None))(params_expand, buffers_expand, tensordict))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        output_1: Tensor(shape=torch.Size([3, 5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        output_2: Tensor(shape=torch.Size([3, 5, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3, 5]),
    device=None,
    is_shared=False)

我们还可以使用来自 tensordict.nn` 的原生 make_functional 函数,它会修改该模块使其接受参数作为常规输入。

from tensordict.nn import make_functional

tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5])

num_models = 10
model = TensorDictModule(nn.Linear(3, 4), in_keys=["a"], out_keys=["output"])
params = make_functional(model)
# we stack two groups of parameters to show the vmap usage:
params = torch.stack([params, params.apply(lambda x: torch.zeros_like(x))], 0)
result_td = torch.vmap(model, (None, 0))(tensordict, params)
print("the output tensordict shape is: ", result_td.shape)
the output tensordict shape is:  torch.Size([2, 5])

脚本总运行时间: (0 分钟 0.006 秒)

Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取适合初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源