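TensorDictModuleBase
- class tensordict.nn.TensorDictModuleBase(*args, **kwargs)
Base class for TensorDict modules.
TensorDictModule subclasses are identified by their in_keys and out_keys key lists, which indicate which input entries are read and which output entries are written. The input/output signature of the forward method should always follow this convention: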
>>> tensordict_out = module.forward(tensordict_in)
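Unlike TensorDictModule, TensorDictModuleBase is typically used through subclassing: you can wrap any Python function in a TensorDictModuleBase subclass, as long as the subclass's forward method reads from and writes to a tensordict (or related type) instance. The in_keys and out_keys must be specified correctly. For example, out_keys can be dynamically reduced using select_out_keys().
Examples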
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModuleBase
>>> class Mod(TensorDictModuleBase):
...     in_keys = ["a"]  # can also be specified during __init__
...     out_keys = ["b", "c"]
...     def forward(self, tensordict):
...         b = tensordict["a"].clone()
...         c = b + 1
...         return tensordict.replace({"b": b, "c": c})
>>> mod = Mod()
>>> td = mod(TensorDict(a=0))
>>> td["b"]
tensor(0)
>>> td["c"]
tensor(1)
>>> mod.select_out_keys("c")
>>> td = mod(TensorDict(a=0))
>>> td["c"]
tensor(1)
>>> assert "b" not in td
- static is_tdmodule_compatible(module)
Checks whether a module is compatible with the TensorDictModule API.
- reset_out_keys()
Resets the out_keys attribute to its original value.
Returns: the same module, with its original out_keys values.
Examples
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> import torch
>>> mod = TensorDictModule(lambda x, y: (x+2, y+2), in_keys=["a", "b"], out_keys=["c", "d"])
>>> mod.select_out_keys("d")
>>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, [])
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> mod.reset_out_keys()
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
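- reset_parameters_recursive(parameters: Optional[TensorDictBase] = None) → Optional[TensorDictBase]
Recursively resets the parameters of the module and its children.
- Parameters:
parameters (TensorDict of parameters, optional) – If set to None, the module resets using self.parameters(). Otherwise, the parameters in the tensordict are reset in-place. This is useful for functional modules, where the parameters are not stored in the module itself.
- Returns:
A tensordict of the new parameters, only if parameters was not None.
Examples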
>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> net = nn.Sequential(nn.Linear(2,3), nn.ReLU())
>>> old_param = net[0].weight.clone()
>>> module = TensorDictModule(net, in_keys=['bork'], out_keys=['dork'])
>>> module.reset_parameters_recursive()
>>> (old_param == net[0].weight).any()
tensor(False)
This method also supports functional parameter sampling:
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> net = nn.Sequential(nn.Linear(2,3), nn.ReLU())
>>> module = TensorDictModule(net, in_keys=['bork'], out_keys=['dork'])
>>> params = TensorDict.from_module(module)
>>> old_params = params.clone(recurse=True)
>>> module.reset_parameters_recursive(params)
>>> (old_params == params).any()
False
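- select_out_keys(*out_keys) → TensorDictModuleBase
Selects the keys that will be found in the output tensordict.
This is useful for removing intermediate keys in a complicated graph, or when the presence of these keys may trigger unexpected behavior.
The original out_keys can still be accessed via module.out_keys_source.
- Parameters:
*out_keys (a sequence of strings or tuples of strings) – the out_keys that should be found in the output tensordict.
- Returns:
the same module, modified in-place with updated out_keys.
The simplest usage is with TensorDictModule:
Examples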
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> import torch
>>> mod = TensorDictModule(lambda x, y: (x+2, y+2), in_keys=["a", "b"], out_keys=["c", "d"])
>>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, [])
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> mod.select_out_keys("d")
>>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, [])
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
This feature also works with dispatched arguments:
Examples
>>> mod(torch.zeros(()), torch.ones(()))
tensor(3.)
This change occurs in-place (i.e. the same module is returned, with an updated list of out_keys). It can be reverted using the TensorDictModuleBase.reset_out_keys() method.
Examples
>>> mod.reset_out_keys()
>>> mod(TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, []))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
This also works with other classes, such as Sequential:
Examples
>>> from tensordict.nn import TensorDictSequential
>>> seq = TensorDictSequential(
...     TensorDictModule(lambda x: x+1, in_keys=["x"], out_keys=["y"]),
...     TensorDictModule(lambda x: x+1, in_keys=["y"], out_keys=["z"]),
... )
>>> td = TensorDict({"x": torch.zeros(())}, [])
>>> seq(td)
TensorDict(
    fields={
        x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        y: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        z: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> seq.select_out_keys("z")
>>> td = TensorDict({"x": torch.zeros(())}, [])
>>> seq(td)
TensorDict(
    fields={
        x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        z: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)