快捷方式

WrapModule

class tensordict.nn.WrapModule(*args, **kwargs)

一个用于包装处理 TensorDict 实例的任意可调用对象的包装器。

当构建 TensorDictSequential 栈时,以及当转换需要整个 TensorDict 实例可见时,这个包装器非常有用。

参数:

func (Callable[[TensorDictBase], TensorDictBase]) – 一个可调用函数,接受一个 TensorDictBase 实例并返回一个转换后的 TensorDictBase 实例。

关键字参数:
  • inplace (bool, optional) – 如果为 True,则原地修改输入的 TensorDict。否则,将返回一个新的 TensorDict(如果函数未原地修改并返回它)。默认为 False

  • in_keys (list of NestedKey, optional) – 如果提供,表示模块读取哪些条目。这不会被检查,仅用于通知 TensorDictSequential 关于包装模块的输入键。默认为 []

  • out_keys (list of NestedKey, optional) – 如果提供,表示模块写入哪些条目。这不会被检查,仅用于通知 TensorDictSequential 关于包装模块的输出键。默认为 []

示例

>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod, WrapModule
>>> seq = Seq(
...     Mod(lambda x: x * 2, in_keys=["x"], out_keys=["y"]),
...     WrapModule(lambda td: td.reshape(-1)),
... )
>>> td = TensorDict(x=torch.ones(3, 4, 5), batch_size=[3, 4])
>>> td = Seq(td)
>>> assert td.shape == (12,)
>>> assert (td["y"] == 2).all()
>>> assert td["y"].shape == (12, 5)
forward(data: TensorDictBase) TensorDictBase

定义每次调用时执行的计算。

应由所有子类覆盖。

注意

尽管正向传播的实现需要在该函数内部定义,但后续应调用 Module 实例而非此函数本身,因为前者负责运行注册的钩子,而后者则会默默忽略它们。


© 版权所有 2022, Meta.

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

查阅 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源