快捷方式

TensorDictModuleBase

class tensordict.nn.TensorDictModuleBase(*args, **kwargs)

TensorDict 模块的基类。

TensorDictModule 子类以 in_keysout_keys 键列表为特征,它们指示要读取哪些输入项以及预期写入哪些输出项。

前向方法的输入/输出签名应始终遵循以下约定

>>> tensordict_out = module.forward(tensordict_in)
static is_tdmodule_compatible(module)

检查模块是否与 TensorDictModule API 兼容。

reset_out_keys()

out_keys 属性重置为其原始值。

返回值:具有其原始 out_keys 值的相同模块。

示例

>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> import torch
>>> mod = TensorDictModule(lambda x, y: (x+2, y+2), in_keys=["a", "b"], out_keys=["c", "d"])
>>> mod.select_out_keys("d")
>>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, [])
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> mod.reset_out_keys()
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
reset_parameters_recursive(parameters: Optional[TensorDictBase] = None) Optional[TensorDictBase]

递归地重置模块及其子模块的参数。

参数:

parameters (参数的 TensorDict, 可选) – 如果设置为 None,则模块将使用 self.parameters() 重置。否则,我们将就地重置 tensordict 中的参数。这对于参数未存储在模块本身中的函数模块很有用。

返回值:

参数的 tensordict,仅当参数不为 None 时。

示例

>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> net = nn.Sequential(nn.Linear(2,3), nn.ReLU())
>>> old_param = net[0].weight.clone()
>>> module = TensorDictModule(net, in_keys=['bork'], out_keys=['dork'])
>>> module.reset_parameters()
>>> (old_param == net[0].weight).any()
tensor(False)

此方法还支持函数参数采样

>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> net = nn.Sequential(nn.Linear(2,3), nn.ReLU())
>>> module = TensorDictModule(net, in_keys=['bork'], out_keys=['dork'])
>>> params = TensorDict.from_module(module)
>>> old_params = params.clone(recurse=True)
>>> module.reset_parameters(params)
>>> (old_params == params).any()
False
select_out_keys(*out_keys)

选择在输出 tensordict 中找到的键。

这在希望从复杂的图形中去除中间键时很有用,或者当这些键的存在可能触发意外行为时。

原始的 out_keys 仍然可以通过 module.out_keys_source 访问。

参数:

*out_keys (字符串序列字符串元组) – 输出 tensordict 中应该找到的 out_keys。

返回值:具有更新的 out_keys 的相同模块,就地修改。

最简单的用法是使用 TensorDictModule

示例

>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> import torch
>>> mod = TensorDictModule(lambda x, y: (x+2, y+2), in_keys=["a", "b"], out_keys=["c", "d"])
>>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, [])
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> mod.select_out_keys("d")
>>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, [])
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

此功能也适用于分派的论据: .. rubric:: 示例

>>> mod(torch.zeros(()), torch.ones(()))
tensor(2.)

此更改将就地发生(即返回的将是具有更新的 out_keys 列表的相同模块)。可以使用 TensorDictModuleBase.reset_out_keys() 方法将其还原。

示例

>>> mod.reset_out_keys()
>>> mod(TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, []))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

这将与其他类(如 Sequential)一起使用: .. rubric:: 示例

>>> from tensordict.nn import TensorDictSequential
>>> seq = TensorDictSequential(
...     TensorDictModule(lambda x: x+1, in_keys=["x"], out_keys=["y"]),
...     TensorDictModule(lambda x: x+1, in_keys=["y"], out_keys=["z"]),
... )
>>> td = TensorDict({"x": torch.zeros(())}, [])
>>> seq(td)
TensorDict(
    fields={
        x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        y: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        z: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> seq.select_out_keys("z")
>>> td = TensorDict({"x": torch.zeros(())}, [])
>>> seq(td)
TensorDict(
    fields={
        x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        z: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源