快捷方式

VmapModule

class torchrl.modules.VmapModule(*args, **kwargs)[source]

一个 TensorDictModule 包装器,用于对输入进行 vmap 操作。

它旨在与接受比提供的数据少一个批处理维度的模块一起使用。通过使用此包装器,可以隐藏一个批处理维度并满足被包装模块的要求。

参数:
  • module (TensorDictModuleBase) – 要进行 vmap 操作的模块。

  • vmap_dim (int, optional) – vmap 输入和输出维度。如果未提供,则假定为 tensordict 的最后一个维度。

注意

由于 vmap 要求控制输入的批大小,此模块不支持分派的参数

示例

>>> lam = TensorDictModule(lambda x: x[0], in_keys=["x"], out_keys=["y"])
>>> sample_in = torch.ones((10,3,2))
>>> sample_in_td = TensorDict({"x":sample_in}, batch_size=[10])
>>> lam(sample_in)
>>> vm = VmapModule(lam, 0)
>>> vm(sample_in_td)
>>> assert (sample_in_td["x"][:, 0] == sample_in_td["y"]).all()
forward(tensordict)[source]

定义每次调用时执行的计算。

应由所有子类覆盖。

注意

尽管需要在该函数内部定义前向传播的实现方式,但之后应调用 Module 实例而不是该函数本身,因为前者负责运行已注册的钩子,而后者会默默忽略它们。

文档

查阅 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源