TensorDictModule¶
In this tutorial you will learn how to use TensorDictModule and TensorDictSequential to create generic, reusable modules that can accept a TensorDict as input.
For a convenient use of the TensorDict class with nn.Module, tensordict provides an interface between the two, called TensorDictModule. The TensorDictModule class is an nn.Module that takes a TensorDict as input when called. It is up to the user to define the keys to be read as input and written as output.
TensorDictModule by examples¶
import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential
Example 1: Simple usage¶
We have a TensorDict with two entries, "a" and "b", but only the value associated with "a" has to be read by the network.
tensordict = TensorDict(
{"a": torch.randn(5, 3), "b": torch.zeros(5, 4, 3)},
batch_size=[5],
)
linear = TensorDictModule(nn.Linear(3, 10), in_keys=["a"], out_keys=["a_out"])
linear(tensordict)
assert (tensordict.get("b") == 0).all()
print(tensordict)
TensorDict(
fields={
a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
a_out: Tensor(shape=torch.Size([5, 10]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([5, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([5]),
device=None,
is_shared=False)
Example 2: Multiple inputs¶
Suppose we have a slightly more complex network that takes two entries and combines them into a single output tensor. To make a TensorDictModule instance read multiple input values, one must register them in the in_keys keyword argument of the constructor.
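The MergeLinear module used below is not defined in this excerpt. A minimal sketch consistent with the shapes of this example (inputs of size 3 and 4, output of size 10) could look like this; the exact definition is an assumption:
class MergeLinear(nn.Module):
    # Hypothetical definition: concatenate the two inputs and project them to a single output.
    def __init__(self, in_1, in_2, out):
        super().__init__()
        self.linear = nn.Linear(in_1 + in_2, out)
    def forward(self, x_1, x_2):
        # Concatenation along the last dimension, followed by a linear projection.
        return self.linear(torch.cat([x_1, x_2], dim=-1))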
tensordict = TensorDict(
{
"a": torch.randn(5, 3),
"b": torch.randn(5, 4),
},
batch_size=[5],
)
mergelinear = TensorDictModule(
MergeLinear(3, 4, 10), in_keys=["a", "b"], out_keys=["output"]
)
mergelinear(tensordict)
TensorDict(
fields={
a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
output: Tensor(shape=torch.Size([5, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([5]),
device=None,
is_shared=False)
Example 3: Multiple outputs¶
Similarly, TensorDictModule supports not only multiple inputs but also multiple outputs. To make a TensorDictModule instance write multiple output values, one must register them in the out_keys keyword argument of the constructor.
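The MultiHeadLinear module used below is likewise not defined in this excerpt. A plausible sketch matching the output shapes of this example (one input of size 3 mapped to two outputs of sizes 4 and 10); treat the exact definition as an assumption:
class MultiHeadLinear(nn.Module):
    # Hypothetical definition: a single input feeds two independent linear heads.
    def __init__(self, in_1, out_1, out_2):
        super().__init__()
        self.linear_1 = nn.Linear(in_1, out_1)
        self.linear_2 = nn.Linear(in_1, out_2)
    def forward(self, x):
        # Returns a tuple; each element is written under the corresponding out_key.
        return self.linear_1(x), self.linear_2(x)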
tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5])
splitlinear = TensorDictModule(
MultiHeadLinear(3, 4, 10),
in_keys=["a"],
out_keys=["output_1", "output_2"],
)
splitlinear(tensordict)
TensorDict(
fields={
a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
output_1: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
output_2: Tensor(shape=torch.Size([5, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([5]),
device=None,
is_shared=False)
When there are multiple input and output keys, make sure they match the order of the module's arguments and return values.
TensorDictModule can work with TensorDict instances that contain more tensors than what the in_keys attribute indicates.
Unless a vmap operator is used, the TensorDict is modified in-place.
Ignoring some outputs
Note that it is possible to avoid writing some tensors to the TensorDict output by using "_" in out_keys, as shown below.
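For instance, a module that produces two tensors but only needs to store the first one could be wired as follows (a minimal sketch reusing the hypothetical MultiHeadLinear above):
td = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5])
module = TensorDictModule(
    MultiHeadLinear(3, 4, 10),
    in_keys=["a"],
    out_keys=["output_1", "_"],  # the second output is not written to the tensordict
)
module(td)
print(td)  # "output_1" is present; the tensor mapped to "_" is discarded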
Example 4: Combining multiple TensorDictModule instances with TensorDictSequential¶
To combine multiple TensorDictModule instances, we can use TensorDictSequential. We create a list where each TensorDictModule must be executed sequentially. TensorDictSequential will read and write keys to the tensordict following the sequence of modules provided.
We can also gather the inputs needed by TensorDictSequential through its in_keys attribute, and the output keys are found in the out_keys attribute.
tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5])
splitlinear = TensorDictModule(
MultiHeadLinear(3, 4, 10),
in_keys=["a"],
out_keys=["output_1", "output_2"],
)
mergelinear = TensorDictModule(
MergeLinear(4, 10, 13),
in_keys=["output_1", "output_2"],
out_keys=["output"],
)
split_and_merge_linear = TensorDictSequential(splitlinear, mergelinear)
assert split_and_merge_linear(tensordict)["output"].shape == torch.Size([5, 13])
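The aggregated keys can be inspected directly on the sequence. For this example we would expect the following values; the exact output is an assumption based on the modules above:
print(split_and_merge_linear.in_keys)   # expected ["a"]: the only key not produced by an earlier module
print(split_and_merge_linear.out_keys)  # expected ["output_1", "output_2", "output"]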
Do's and don'ts with TensorDictModule¶
Do not use nn.Sequential; like a plain nn.Module, it would break features such as functorch compatibility. Do use TensorDictSequential instead.
Do not assign the output tensordict to a new variable, as the output tensordict is just the input modified in-place.
tensordict = module(tensordict)  # Ok!
tensordict_out = module(tensordict)  # Don't!
ProbabilisticTensorDictModule¶
ProbabilisticTensorDictModule is a non-parametric module representing a probability distribution. The distribution parameters are read from the tensordict input, and the output is written to an output tensordict. The output is sampled according to a rule specified by the input default_interaction_type argument and the exploration_mode() global function; if the two conflict, the context manager takes precedence.
It can be wired together with a TensorDictModule that returns a tensordict updated with the distribution parameters, using ProbabilisticTensorDictSequential. This is a special case of TensorDictSequential whose last module is a ProbabilisticTensorDictModule.
ProbabilisticTensorDictModule is responsible for constructing the distribution (through the get_dist() method) and/or sampling from it (through a regular __call__() to the module). The same get_dist() method is exposed on ProbabilisticTensorDictSequential.
The parameters can be found in the output tensordict, along with the log probability if needed.
from tensordict.nn import (
ProbabilisticTensorDictModule,
ProbabilisticTensorDictSequential,
)
from tensordict.nn.distributions import NormalParamExtractor
from torch import distributions as dist
td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3])
net = torch.nn.GRUCell(4, 8)
net = TensorDictModule(net, in_keys=["input", "hidden"], out_keys=["hidden"])
extractor = NormalParamExtractor()
extractor = TensorDictModule(extractor, in_keys=["hidden"], out_keys=["loc", "scale"])
td_module = ProbabilisticTensorDictSequential(
net,
extractor,
ProbabilisticTensorDictModule(
in_keys=["loc", "scale"],
out_keys=["action"],
distribution_class=dist.Normal,
return_log_prob=True,
),
)
print(f"TensorDict before going through module: {td}")
td_module(td)
print(f"TensorDict after going through module now as keys action, loc and scale: {td}")
TensorDict before going through module: TensorDict(
fields={
hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
TensorDict after going through module now as keys action, loc and scale: TensorDict(
fields={
action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
sample_log_prob: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
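As mentioned above, the distribution can also be retrieved without sampling through get_dist(). A minimal sketch, assuming the same td_module and td as in the example:
dist = td_module.get_dist(td)
print(dist)  # a torch.distributions.Normal built from the "loc" and "scale" entries
print(td["sample_log_prob"].shape)  # log-probabilities written because return_log_prob=True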
Showcase: Implementing a transformer using TensorDictModule¶
To demonstrate the flexibility of TensorDictModule, we are going to create a transformer that reads TensorDict objects using TensorDictModule.
We follow the classical transformer architecture (Vaswani et al., 2017), leaving the positional encoders aside for simplicity.
Let's re-write the classical transformer blocks:
class TokensToQKV(nn.Module):
def __init__(self, to_dim, from_dim, latent_dim):
super().__init__()
self.q = nn.Linear(to_dim, latent_dim)
self.k = nn.Linear(from_dim, latent_dim)
self.v = nn.Linear(from_dim, latent_dim)
def forward(self, X_to, X_from):
Q = self.q(X_to)
K = self.k(X_from)
V = self.v(X_from)
return Q, K, V
class SplitHeads(nn.Module):
def __init__(self, num_heads):
super().__init__()
self.num_heads = num_heads
def forward(self, Q, K, V):
batch_size, to_num, latent_dim = Q.shape
_, from_num, _ = K.shape
d_tensor = latent_dim // self.num_heads
Q = Q.reshape(batch_size, to_num, self.num_heads, d_tensor).transpose(1, 2)
K = K.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)
V = V.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)
return Q, K, V
class Attention(nn.Module):
def __init__(self, latent_dim, to_dim):
super().__init__()
self.softmax = nn.Softmax(dim=-1)
self.out = nn.Linear(latent_dim, to_dim)
def forward(self, Q, K, V):
batch_size, n_heads, to_num, d_in = Q.shape
attn = self.softmax(Q @ K.transpose(2, 3) / d_in)
out = attn @ V
out = self.out(out.transpose(1, 2).reshape(batch_size, to_num, n_heads * d_in))
return out, attn
class SkipLayerNorm(nn.Module):
def __init__(self, to_len, to_dim):
super().__init__()
self.layer_norm = nn.LayerNorm((to_len, to_dim))
def forward(self, x_0, x_1):
return self.layer_norm(x_0 + x_1)
class FFN(nn.Module):
def __init__(self, to_dim, hidden_dim, dropout_rate=0.2):
super().__init__()
self.FFN = nn.Sequential(
nn.Linear(to_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, to_dim),
nn.Dropout(dropout_rate),
)
def forward(self, X):
return self.FFN(X)
class AttentionBlock(nn.Module):
def __init__(self, to_dim, to_len, from_dim, latent_dim, num_heads):
super().__init__()
self.tokens_to_qkv = TokensToQKV(to_dim, from_dim, latent_dim)
self.split_heads = SplitHeads(num_heads)
self.attention = Attention(latent_dim, to_dim)
self.skip = SkipLayerNorm(to_len, to_dim)
def forward(self, X_to, X_from):
Q, K, V = self.tokens_to_qkv(X_to, X_from)
Q, K, V = self.split_heads(Q, K, V)
out, attention = self.attention(Q, K, V)
out = self.skip(X_to, out)
return out
class EncoderTransformerBlock(nn.Module):
def __init__(self, to_dim, to_len, latent_dim, num_heads):
super().__init__()
self.attention_block = AttentionBlock(
to_dim, to_len, to_dim, latent_dim, num_heads
)
self.FFN = FFN(to_dim, 4 * to_dim)
self.skip = SkipLayerNorm(to_len, to_dim)
def forward(self, X_to):
X_to = self.attention_block(X_to, X_to)
X_out = self.FFN(X_to)
return self.skip(X_out, X_to)
class DecoderTransformerBlock(nn.Module):
def __init__(self, to_dim, to_len, from_dim, latent_dim, num_heads):
super().__init__()
self.attention_block = AttentionBlock(
to_dim, to_len, from_dim, latent_dim, num_heads
)
self.encoder_block = EncoderTransformerBlock(
to_dim, to_len, latent_dim, num_heads
)
def forward(self, X_to, X_from):
X_to = self.attention_block(X_to, X_from)
X_to = self.encoder_block(X_to)
return X_to
class TransformerEncoder(nn.Module):
def __init__(self, num_blocks, to_dim, to_len, latent_dim, num_heads):
super().__init__()
self.encoder = nn.ModuleList(
[
EncoderTransformerBlock(to_dim, to_len, latent_dim, num_heads)
for i in range(num_blocks)
]
)
def forward(self, X_to):
for i in range(len(self.encoder)):
X_to = self.encoder[i](X_to)
return X_to
class TransformerDecoder(nn.Module):
def __init__(self, num_blocks, to_dim, to_len, from_dim, latent_dim, num_heads):
super().__init__()
self.decoder = nn.ModuleList(
[
DecoderTransformerBlock(to_dim, to_len, from_dim, latent_dim, num_heads)
for i in range(num_blocks)
]
)
def forward(self, X_to, X_from):
for i in range(len(self.decoder)):
X_to = self.decoder[i](X_to, X_from)
return X_to
class Transformer(nn.Module):
def __init__(
self, num_blocks, to_dim, to_len, from_dim, from_len, latent_dim, num_heads
):
super().__init__()
self.encoder = TransformerEncoder(
num_blocks, to_dim, to_len, latent_dim, num_heads
)
self.decoder = TransformerDecoder(
num_blocks, from_dim, from_len, to_dim, latent_dim, num_heads
)
def forward(self, X_to, X_from):
X_to = self.encoder(X_to)
X_out = self.decoder(X_from, X_to)
return X_out
We first create AttentionBlockTensorDict, an attention block built with TensorDictModule and TensorDictSequential.
The wiring operation that connects the modules to each other requires us to indicate which key each of them must read and write. Unlike nn.Sequential, a TensorDictSequential can read/write more than one input/output. Moreover, the inputs of its components do not have to be identical to the outputs of the previous layer, which allows us to code complicated neural architectures.
class AttentionBlockTensorDict(TensorDictSequential):
def __init__(
self,
to_name,
from_name,
to_dim,
to_len,
from_dim,
latent_dim,
num_heads,
):
super().__init__(
TensorDictModule(
TokensToQKV(to_dim, from_dim, latent_dim),
in_keys=[to_name, from_name],
out_keys=["Q", "K", "V"],
),
TensorDictModule(
SplitHeads(num_heads),
in_keys=["Q", "K", "V"],
out_keys=["Q", "K", "V"],
),
TensorDictModule(
Attention(latent_dim, to_dim),
in_keys=["Q", "K", "V"],
out_keys=["X_out", "Attn"],
),
TensorDictModule(
SkipLayerNorm(to_len, to_dim),
in_keys=[to_name, "X_out"],
out_keys=[to_name],
),
)
We build the encoder and decoder blocks that will be part of the transformer thanks to TensorDictModule.
class TransformerBlockEncoderTensorDict(TensorDictSequential):
def __init__(
self,
to_name,
from_name,
to_dim,
to_len,
from_dim,
latent_dim,
num_heads,
):
super().__init__(
AttentionBlockTensorDict(
to_name,
from_name,
to_dim,
to_len,
from_dim,
latent_dim,
num_heads,
),
TensorDictModule(
FFN(to_dim, 4 * to_dim),
in_keys=[to_name],
out_keys=["X_out"],
),
TensorDictModule(
SkipLayerNorm(to_len, to_dim),
in_keys=[to_name, "X_out"],
out_keys=[to_name],
),
)
class TransformerBlockDecoderTensorDict(TensorDictSequential):
def __init__(
self,
to_name,
from_name,
to_dim,
to_len,
from_dim,
latent_dim,
num_heads,
):
super().__init__(
AttentionBlockTensorDict(
to_name,
to_name,
to_dim,
to_len,
to_dim,
latent_dim,
num_heads,
),
TransformerBlockEncoderTensorDict(
to_name,
from_name,
to_dim,
to_len,
from_dim,
latent_dim,
num_heads,
),
)
We create the encoder and decoder of the transformer.
For the encoder, we just need to use the same tokens for the queries, keys and values.
For the decoder, we can now extract information from X_from into X_to: X_to maps to the queries, whereas X_from maps to the keys and values.
class TransformerEncoderTensorDict(TensorDictSequential):
def __init__(
self,
num_blocks,
to_name,
from_name,
to_dim,
to_len,
from_dim,
latent_dim,
num_heads,
):
super().__init__(
*[
TransformerBlockEncoderTensorDict(
to_name,
from_name,
to_dim,
to_len,
from_dim,
latent_dim,
num_heads,
)
for _ in range(num_blocks)
]
)
class TransformerDecoderTensorDict(TensorDictSequential):
def __init__(
self,
num_blocks,
to_name,
from_name,
to_dim,
to_len,
from_dim,
latent_dim,
num_heads,
):
super().__init__(
*[
TransformerBlockDecoderTensorDict(
to_name,
from_name,
to_dim,
to_len,
from_dim,
latent_dim,
num_heads,
)
for _ in range(num_blocks)
]
)
class TransformerTensorDict(TensorDictSequential):
def __init__(
self,
num_blocks,
to_name,
from_name,
to_dim,
to_len,
from_dim,
from_len,
latent_dim,
num_heads,
):
super().__init__(
TransformerEncoderTensorDict(
num_blocks,
to_name,
to_name,
to_dim,
to_len,
to_dim,
latent_dim,
num_heads,
),
TransformerDecoderTensorDict(
num_blocks,
from_name,
to_name,
from_dim,
from_len,
to_dim,
latent_dim,
num_heads,
),
)
We now test our new TransformerTensorDict.
to_dim = 5
from_dim = 6
latent_dim = 10
to_len = 3
from_len = 10
batch_size = 8
num_heads = 2
num_blocks = 6
tokens = TensorDict(
{
"X_encode": torch.randn(batch_size, to_len, to_dim),
"X_decode": torch.randn(batch_size, from_len, from_dim),
},
batch_size=[batch_size],
)
transformer = TransformerTensorDict(
num_blocks,
"X_encode",
"X_decode",
to_dim,
to_len,
from_dim,
from_len,
latent_dim,
num_heads,
)
transformer(tokens)
tokens
TensorDict(
fields={
Attn: Tensor(shape=torch.Size([8, 2, 10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
K: Tensor(shape=torch.Size([8, 2, 3, 5]), device=cpu, dtype=torch.float32, is_shared=False),
Q: Tensor(shape=torch.Size([8, 2, 10, 5]), device=cpu, dtype=torch.float32, is_shared=False),
V: Tensor(shape=torch.Size([8, 2, 3, 5]), device=cpu, dtype=torch.float32, is_shared=False),
X_decode: Tensor(shape=torch.Size([8, 10, 6]), device=cpu, dtype=torch.float32, is_shared=False),
X_encode: Tensor(shape=torch.Size([8, 3, 5]), device=cpu, dtype=torch.float32, is_shared=False),
X_out: Tensor(shape=torch.Size([8, 10, 6]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([8]),
device=None,
is_shared=False)
We have successfully created a transformer with TensorDictModule. This shows that TensorDictModule is a flexible module that can implement complex operations.
Benchmarking¶
to_dim = 5
from_dim = 6
latent_dim = 10
to_len = 3
from_len = 10
batch_size = 8
num_heads = 2
num_blocks = 6
td_tokens = TensorDict(
{
"X_encode": torch.randn(batch_size, to_len, to_dim),
"X_decode": torch.randn(batch_size, from_len, from_dim),
},
batch_size=[batch_size],
)
tdtransformer = TransformerTensorDict(
num_blocks,
"X_encode",
"X_decode",
to_dim,
to_len,
from_dim,
from_len,
latent_dim,
num_heads,
)
transformer = Transformer(
num_blocks, to_dim, to_len, from_dim, from_len, latent_dim, num_heads
)
Inference time
import time
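The timing code itself is not included in this excerpt. A minimal sketch of how the comparison could be run (hypothetical, reusing td_tokens and plain tensors with the shapes defined above):
X_encode = torch.randn(batch_size, to_len, to_dim)
X_decode = torch.randn(batch_size, from_len, from_dim)
start = time.time()
tdtransformer(td_tokens)  # TensorDict-based transformer
print(f"Execution time: {time.time() - start} seconds")
start = time.time()
transformer(X_encode, X_decode)  # plain nn.Module transformer
print(f"Execution time: {time.time() - start} seconds")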
Execution time: 0.009625911712646484 seconds
Execution time: 0.006480216979980469 seconds
On this minimal example, we can see that the overhead introduced by TensorDictModule is marginal.
Have fun with TensorDictModule!
Total running time of the script: (0 minutes 10.088 seconds)