
TensorDictModule

In this tutorial you will learn how to use TensorDictModule and TensorDictSequential to create generic, reusable modules that can accept TensorDict as input.

To make it easy to use the TensorDict class with nn.Module, tensordict provides an interface between the two, named TensorDictModule. The TensorDictModule class is an nn.Module that takes a TensorDict as input when called. It is up to the user to define the keys to be read as input and written as output.

TensorDictModule by examples

import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential

Example 1: Simple usage

We have a TensorDict with two entries, "a" and "b", but the network only needs to read the value associated with "a".

tensordict = TensorDict(
    {"a": torch.randn(5, 3), "b": torch.zeros(5, 4, 3)},
    batch_size=[5],
)
linear = TensorDictModule(nn.Linear(3, 10), in_keys=["a"], out_keys=["a_out"])
linear(tensordict)
assert (tensordict.get("b") == 0).all()
print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        a_out: Tensor(shape=torch.Size([5, 10]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([5, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([5]),
    device=None,
    is_shared=False)

Example 2: Multiple inputs

Suppose we have a slightly more complex network that takes two entries and averages them into a single output tensor. To make a TensorDictModule instance read multiple input values, the corresponding keys must be registered in the in_keys keyword argument of the constructor.

class MergeLinear(nn.Module):
    def __init__(self, in_1, in_2, out):
        super().__init__()
        self.linear_1 = nn.Linear(in_1, out)
        self.linear_2 = nn.Linear(in_2, out)

    def forward(self, x_1, x_2):
        return (self.linear_1(x_1) + self.linear_2(x_2)) / 2
tensordict = TensorDict(
    {
        "a": torch.randn(5, 3),
        "b": torch.randn(5, 4),
    },
    batch_size=[5],
)

mergelinear = TensorDictModule(
    MergeLinear(3, 4, 10), in_keys=["a", "b"], out_keys=["output"]
)

mergelinear(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        output: Tensor(shape=torch.Size([5, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([5]),
    device=None,
    is_shared=False)

Example 3: Multiple outputs

Similarly, TensorDictModule supports not only multiple inputs but also multiple outputs. To make a TensorDictModule instance write multiple output values, the corresponding keys must be registered in the out_keys keyword argument of the constructor.

class MultiHeadLinear(nn.Module):
    def __init__(self, in_1, out_1, out_2):
        super().__init__()
        self.linear_1 = nn.Linear(in_1, out_1)
        self.linear_2 = nn.Linear(in_1, out_2)

    def forward(self, x):
        return self.linear_1(x), self.linear_2(x)
tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5])

splitlinear = TensorDictModule(
    MultiHeadLinear(3, 4, 10),
    in_keys=["a"],
    out_keys=["output_1", "output_2"],
)
splitlinear(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        output_1: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        output_2: Tensor(shape=torch.Size([5, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([5]),
    device=None,
    is_shared=False)

When there are multiple input and/or output keys, make sure they match the order of the arguments and return values of the wrapped module: in Example 3 above, swapping the out_keys would store the 4-dimensional head under "output_2" and the 10-dimensional head under "output_1".

TensorDictModule can be used with TensorDict instances that contain more tensors than those indicated by the in_keys attribute.

Unless the vmap operator is used, the TensorDict is modified in place.
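
A minimal check of both notes above, reusing the linear module from Example 1 (a sketch; the extra entry "c" is arbitrary):

td_check = TensorDict({"a": torch.randn(5, 3), "c": torch.randn(5, 2)}, batch_size=[5])
out = linear(td_check)
assert out is td_check  # the input TensorDict is modified in place and returned
assert set(td_check.keys()) == {"a", "a_out", "c"}  # the extra entry "c" is left untouched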

Ignoring some outputs

Note that "_" can be used in out_keys to avoid writing certain tensors to the output TensorDict.
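
For instance, a short sketch reusing MultiHeadLinear from Example 3 that keeps only the first head (the handling of "_" follows the note above):

td_drop = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5])
keep_first_head = TensorDictModule(
    MultiHeadLinear(3, 4, 10),
    in_keys=["a"],
    out_keys=["output_1", "_"],  # the second head is not stored in the output
)
keep_first_head(td_drop)
print(list(td_drop.keys()))  # "output_1" is present; per the note above, the second head is discarded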

Example 4: Combining multiple TensorDictModule instances with TensorDictSequential

To combine multiple TensorDictModule instances, we can use TensorDictSequential. We pass a list of TensorDictModule instances that will be executed in order: TensorDictSequential will read and write keys to the tensordict following the sequence of modules provided.

We can also gather the inputs needed by TensorDictSequential through its in_keys attribute, and the output keys can be found in its out_keys attribute (a quick check of these attributes follows the example below).

tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5])

splitlinear = TensorDictModule(
    MultiHeadLinear(3, 4, 10),
    in_keys=["a"],
    out_keys=["output_1", "output_2"],
)
mergelinear = TensorDictModule(
    MergeLinear(4, 10, 13),
    in_keys=["output_1", "output_2"],
    out_keys=["output"],
)

split_and_merge_linear = TensorDictSequential(splitlinear, mergelinear)

assert split_and_merge_linear(tensordict)["output"].shape == torch.Size([5, 13])
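
As mentioned above, the aggregated keys can be inspected directly on the sequence (the exact ordering and content of out_keys may vary slightly across tensordict versions):

print(split_and_merge_linear.in_keys)   # expected: ["a"]
print(split_and_merge_linear.out_keys)  # contains "output" (intermediate keys may also be listed)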

Do's and don'ts with TensorDictModule

Do not use nn.Sequential to chain TensorDictModule instances: like a plain nn.Module container, it would break features such as functorch compatibility. Use TensorDictSequential instead.

Do not assign the output tensordict to a new variable, as the output tensordict is just the input tensordict modified in place.

tensordict = module(tensordict)  # ok!

tensordict_out = module(tensordict)  # don't!

ProbabilisticTensorDictModule

ProbabilisticTensorDictModule is a non-parametric module representing a probability distribution. The distribution parameters are read from the tensordict input, and the output is written to an output tensordict. The output is sampled according to a rule specified by the input default_interaction_type argument and the exploration_mode() global function. If they conflict, the context manager takes precedence.

It can be wired together, using ProbabilisticTensorDictSequential, with a TensorDictModule that returns a tensordict updated with the distribution parameters. ProbabilisticTensorDictSequential is a special case of TensorDictSequential that terminates in a ProbabilisticTensorDictModule.

ProbabilisticTensorDictModule is responsible for constructing the distribution (through the get_dist() method) and/or sampling from this distribution (through a regular __call__() to the module). The same get_dist() method is exposed on ProbabilisticTensorDictSequential.

The distribution parameters, as well as the log-probability if needed, can be found in the output tensordict.

from tensordict.nn import (
    ProbabilisticTensorDictModule,
    ProbabilisticTensorDictSequential,
)
from tensordict.nn.distributions import NormalParamExtractor
from torch import distributions as dist

td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3])
net = torch.nn.GRUCell(4, 8)
net = TensorDictModule(net, in_keys=["input", "hidden"], out_keys=["hidden"])
extractor = NormalParamExtractor()
extractor = TensorDictModule(extractor, in_keys=["hidden"], out_keys=["loc", "scale"])
td_module = ProbabilisticTensorDictSequential(
    net,
    extractor,
    ProbabilisticTensorDictModule(
        in_keys=["loc", "scale"],
        out_keys=["action"],
        distribution_class=dist.Normal,
        return_log_prob=True,
    ),
)
print(f"TensorDict before going through module: {td}")
td_module(td)
print(f"TensorDict after going through module now as keys action, loc and scale: {td}")
TensorDict before going through module: TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
TensorDict after going through module now as keys action, loc and scale: TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
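
As noted above, get_dist() lets us build the distribution object without sampling. A short sketch using the module defined above (printed values will differ between runs):

dist = td_module.get_dist(td)
print(dist)              # a torch.distributions.Normal with batch shape [3, 4]
sample = dist.rsample()  # a differentiable sample drawn outside of the TensorDict workflow
print(sample.shape)      # torch.Size([3, 4])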

Showcase: Implementing a transformer with TensorDictModule

To showcase the flexibility of TensorDictModule, we are going to create a transformer that reads TensorDict objects using TensorDictModule.

The following figure shows the classical transformer architecture (Vaswani et al., 2017).

(Figure: the classical transformer architecture)

We have left the positional encoders aside for simplicity.

Let's first rewrite the classical transformer blocks:

class TokensToQKV(nn.Module):
    def __init__(self, to_dim, from_dim, latent_dim):
        super().__init__()
        self.q = nn.Linear(to_dim, latent_dim)
        self.k = nn.Linear(from_dim, latent_dim)
        self.v = nn.Linear(from_dim, latent_dim)

    def forward(self, X_to, X_from):
        Q = self.q(X_to)
        K = self.k(X_from)
        V = self.v(X_from)
        return Q, K, V


class SplitHeads(nn.Module):
    def __init__(self, num_heads):
        super().__init__()
        self.num_heads = num_heads

    def forward(self, Q, K, V):
        batch_size, to_num, latent_dim = Q.shape
        _, from_num, _ = K.shape
        d_tensor = latent_dim // self.num_heads
        Q = Q.reshape(batch_size, to_num, self.num_heads, d_tensor).transpose(1, 2)
        K = K.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)
        V = V.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)
        return Q, K, V


class Attention(nn.Module):
    def __init__(self, latent_dim, to_dim):
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)
        self.out = nn.Linear(latent_dim, to_dim)

    def forward(self, Q, K, V):
        batch_size, n_heads, to_num, d_in = Q.shape
        attn = self.softmax(Q @ K.transpose(2, 3) / d_in)
        out = attn @ V
        out = self.out(out.transpose(1, 2).reshape(batch_size, to_num, n_heads * d_in))
        return out, attn


class SkipLayerNorm(nn.Module):
    def __init__(self, to_len, to_dim):
        super().__init__()
        self.layer_norm = nn.LayerNorm((to_len, to_dim))

    def forward(self, x_0, x_1):
        return self.layer_norm(x_0 + x_1)


class FFN(nn.Module):
    def __init__(self, to_dim, hidden_dim, dropout_rate=0.2):
        super().__init__()
        self.FFN = nn.Sequential(
            nn.Linear(to_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, to_dim),
            nn.Dropout(dropout_rate),
        )

    def forward(self, X):
        return self.FFN(X)


class AttentionBlock(nn.Module):
    def __init__(self, to_dim, to_len, from_dim, latent_dim, num_heads):
        super().__init__()
        self.tokens_to_qkv = TokensToQKV(to_dim, from_dim, latent_dim)
        self.split_heads = SplitHeads(num_heads)
        self.attention = Attention(latent_dim, to_dim)
        self.skip = SkipLayerNorm(to_len, to_dim)

    def forward(self, X_to, X_from):
        Q, K, V = self.tokens_to_qkv(X_to, X_from)
        Q, K, V = self.split_heads(Q, K, V)
        out, attention = self.attention(Q, K, V)
        out = self.skip(X_to, out)
        return out


class EncoderTransformerBlock(nn.Module):
    def __init__(self, to_dim, to_len, latent_dim, num_heads):
        super().__init__()
        self.attention_block = AttentionBlock(
            to_dim, to_len, to_dim, latent_dim, num_heads
        )
        self.FFN = FFN(to_dim, 4 * to_dim)
        self.skip = SkipLayerNorm(to_len, to_dim)

    def forward(self, X_to):
        X_to = self.attention_block(X_to, X_to)
        X_out = self.FFN(X_to)
        return self.skip(X_out, X_to)


class DecoderTransformerBlock(nn.Module):
    def __init__(self, to_dim, to_len, from_dim, latent_dim, num_heads):
        super().__init__()
        self.attention_block = AttentionBlock(
            to_dim, to_len, from_dim, latent_dim, num_heads
        )
        self.encoder_block = EncoderTransformerBlock(
            to_dim, to_len, latent_dim, num_heads
        )

    def forward(self, X_to, X_from):
        X_to = self.attention_block(X_to, X_from)
        X_to = self.encoder_block(X_to)
        return X_to


class TransformerEncoder(nn.Module):
    def __init__(self, num_blocks, to_dim, to_len, latent_dim, num_heads):
        super().__init__()
        self.encoder = nn.ModuleList(
            [
                EncoderTransformerBlock(to_dim, to_len, latent_dim, num_heads)
                for i in range(num_blocks)
            ]
        )

    def forward(self, X_to):
        for i in range(len(self.encoder)):
            X_to = self.encoder[i](X_to)
        return X_to


class TransformerDecoder(nn.Module):
    def __init__(self, num_blocks, to_dim, to_len, from_dim, latent_dim, num_heads):
        super().__init__()
        self.decoder = nn.ModuleList(
            [
                DecoderTransformerBlock(to_dim, to_len, from_dim, latent_dim, num_heads)
                for i in range(num_blocks)
            ]
        )

    def forward(self, X_to, X_from):
        for i in range(len(self.decoder)):
            X_to = self.decoder[i](X_to, X_from)
        return X_to


class Transformer(nn.Module):
    def __init__(
        self, num_blocks, to_dim, to_len, from_dim, from_len, latent_dim, num_heads
    ):
        super().__init__()
        self.encoder = TransformerEncoder(
            num_blocks, to_dim, to_len, latent_dim, num_heads
        )
        self.decoder = TransformerDecoder(
            num_blocks, from_dim, from_len, to_dim, latent_dim, num_heads
        )

    def forward(self, X_to, X_from):
        X_to = self.encoder(X_to)
        X_out = self.decoder(X_from, X_to)
        return X_out
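
A quick shape check of the vanilla blocks above, with small hypothetical dimensions (not part of the original listing):

vanilla = Transformer(
    num_blocks=2, to_dim=5, to_len=3, from_dim=6, from_len=10, latent_dim=10, num_heads=2
)
out = vanilla(torch.randn(2, 3, 5), torch.randn(2, 10, 6))
print(out.shape)  # torch.Size([2, 10, 6]): (batch, from_len, from_dim)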

We first create the AttentionBlockTensorDict, an attention block built using TensorDictModule and TensorDictSequential.

The wiring operation that connects the modules together requires us to indicate which keys each of them must read and write. Unlike nn.Sequential, a TensorDictSequential can read/write more than one input/output. Moreover, the inputs of its components need not be identical to the outputs of the previous layers, which allows us to code complicated neural architectures.

class AttentionBlockTensorDict(TensorDictSequential):
    def __init__(
        self,
        to_name,
        from_name,
        to_dim,
        to_len,
        from_dim,
        latent_dim,
        num_heads,
    ):
        super().__init__(
            TensorDictModule(
                TokensToQKV(to_dim, from_dim, latent_dim),
                in_keys=[to_name, from_name],
                out_keys=["Q", "K", "V"],
            ),
            TensorDictModule(
                SplitHeads(num_heads),
                in_keys=["Q", "K", "V"],
                out_keys=["Q", "K", "V"],
            ),
            TensorDictModule(
                Attention(latent_dim, to_dim),
                in_keys=["Q", "K", "V"],
                out_keys=["X_out", "Attn"],
            ),
            TensorDictModule(
                SkipLayerNorm(to_len, to_dim),
                in_keys=[to_name, "X_out"],
                out_keys=[to_name],
            ),
        )
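
A quick sanity check of the block above, with hypothetical key names and shapes (only meant to illustrate which keys are read and written):

block = AttentionBlockTensorDict(
    "X_to", "X_from", to_dim=5, to_len=3, from_dim=6, latent_dim=10, num_heads=2
)
td_block = TensorDict(
    {"X_to": torch.randn(2, 3, 5), "X_from": torch.randn(2, 10, 6)},
    batch_size=[2],
)
block(td_block)
print(td_block["X_to"].shape)  # torch.Size([2, 3, 5]); Q, K, V, X_out and Attn are also written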

We build the encoder and decoder blocks that will be part of the transformer, thanks to TensorDictModule.

class TransformerBlockEncoderTensorDict(TensorDictSequential):
    def __init__(
        self,
        to_name,
        from_name,
        to_dim,
        to_len,
        from_dim,
        latent_dim,
        num_heads,
    ):
        super().__init__(
            AttentionBlockTensorDict(
                to_name,
                from_name,
                to_dim,
                to_len,
                from_dim,
                latent_dim,
                num_heads,
            ),
            TensorDictModule(
                FFN(to_dim, 4 * to_dim),
                in_keys=[to_name],
                out_keys=["X_out"],
            ),
            TensorDictModule(
                SkipLayerNorm(to_len, to_dim),
                in_keys=[to_name, "X_out"],
                out_keys=[to_name],
            ),
        )


class TransformerBlockDecoderTensorDict(TensorDictSequential):
    def __init__(
        self,
        to_name,
        from_name,
        to_dim,
        to_len,
        from_dim,
        latent_dim,
        num_heads,
    ):
        super().__init__(
            AttentionBlockTensorDict(
                to_name,
                to_name,
                to_dim,
                to_len,
                to_dim,
                latent_dim,
                num_heads,
            ),
            TransformerBlockEncoderTensorDict(
                to_name,
                from_name,
                to_dim,
                to_len,
                from_dim,
                latent_dim,
                num_heads,
            ),
        )

We create the transformer encoder and decoder.

For the encoder, we simply feed the same tokens as queries, keys and values.

For the decoder, we can now extract information from X_from into X_to: X_to is mapped to the queries, whereas X_from is mapped to the keys and values.

class TransformerEncoderTensorDict(TensorDictSequential):
    def __init__(
        self,
        num_blocks,
        to_name,
        from_name,
        to_dim,
        to_len,
        from_dim,
        latent_dim,
        num_heads,
    ):
        super().__init__(
            *[
                TransformerBlockEncoderTensorDict(
                    to_name,
                    from_name,
                    to_dim,
                    to_len,
                    from_dim,
                    latent_dim,
                    num_heads,
                )
                for _ in range(num_blocks)
            ]
        )


class TransformerDecoderTensorDict(TensorDictSequential):
    def __init__(
        self,
        num_blocks,
        to_name,
        from_name,
        to_dim,
        to_len,
        from_dim,
        latent_dim,
        num_heads,
    ):
        super().__init__(
            *[
                TransformerBlockDecoderTensorDict(
                    to_name,
                    from_name,
                    to_dim,
                    to_len,
                    from_dim,
                    latent_dim,
                    num_heads,
                )
                for _ in range(num_blocks)
            ]
        )


class TransformerTensorDict(TensorDictSequential):
    def __init__(
        self,
        num_blocks,
        to_name,
        from_name,
        to_dim,
        to_len,
        from_dim,
        from_len,
        latent_dim,
        num_heads,
    ):
        super().__init__(
            TransformerEncoderTensorDict(
                num_blocks,
                to_name,
                to_name,
                to_dim,
                to_len,
                to_dim,
                latent_dim,
                num_heads,
            ),
            TransformerDecoderTensorDict(
                num_blocks,
                from_name,
                to_name,
                from_dim,
                from_len,
                to_dim,
                latent_dim,
                num_heads,
            ),
        )

We can now test our new TransformerTensorDict.

to_dim = 5
from_dim = 6
latent_dim = 10
to_len = 3
from_len = 10
batch_size = 8
num_heads = 2
num_blocks = 6

tokens = TensorDict(
    {
        "X_encode": torch.randn(batch_size, to_len, to_dim),
        "X_decode": torch.randn(batch_size, from_len, from_dim),
    },
    batch_size=[batch_size],
)

transformer = TransformerTensorDict(
    num_blocks,
    "X_encode",
    "X_decode",
    to_dim,
    to_len,
    from_dim,
    from_len,
    latent_dim,
    num_heads,
)

transformer(tokens)
tokens
TensorDict(
    fields={
        Attn: Tensor(shape=torch.Size([8, 2, 10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        K: Tensor(shape=torch.Size([8, 2, 3, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        Q: Tensor(shape=torch.Size([8, 2, 10, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        V: Tensor(shape=torch.Size([8, 2, 3, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        X_decode: Tensor(shape=torch.Size([8, 10, 6]), device=cpu, dtype=torch.float32, is_shared=False),
        X_encode: Tensor(shape=torch.Size([8, 3, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        X_out: Tensor(shape=torch.Size([8, 10, 6]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([8]),
    device=None,
    is_shared=False)

We have successfully created a transformer with TensorDictModule. This shows that TensorDictModule is a flexible module that can implement complex operations.

Benchmarking

tdtransformer = TransformerTensorDict(
    num_blocks,
    "X_encode",
    "X_decode",
    to_dim,
    to_len,
    from_dim,
    from_len,
    latent_dim,
    num_heads,
)
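
For the timing comparison below we also need the plain Transformer and its raw inputs, which are not shown in the snippet above; the following is an assumed setup reusing the dimensions from the earlier example:

transformer = Transformer(
    num_blocks, to_dim, to_len, from_dim, from_len, latent_dim, num_heads
)

td_tokens = TensorDict(
    {
        "X_encode": torch.randn(batch_size, to_len, to_dim),
        "X_decode": torch.randn(batch_size, from_len, from_dim),
    },
    batch_size=[batch_size],
)
X_encode = torch.randn(batch_size, to_len, to_dim)
X_decode = torch.randn(batch_size, from_len, from_dim)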

Inference time

import time
t1 = time.time()
tokens = tdtransformer(td_tokens)
t2 = time.time()
print("Execution time:", t2 - t1, "seconds")
Execution time: 0.009625911712646484 seconds
t3 = time.time()
X_out = transformer(X_encode, X_decode)
t4 = time.time()
print("Execution time:", t4 - t3, "seconds")
Execution time: 0.006480216979980469 seconds

On this minimal example, we can see that the overhead introduced by TensorDictModule is marginal.

Have fun with TensorDictModule!

Total running time of the script: (0 minutes 10.088 seconds)

