快捷方式

VmapModule

class torchrl.modules.VmapModule(*args, **kwargs)[source]

一个 TensorDictModule 包装器,用于对输入进行 vmap。

它旨在与接受数据(其批量维度比提供的维度少一个)的模块一起使用。通过使用此包装器,可以隐藏一个批量维度并满足包装的模块。

参数:
  • module (TensorDictModuleBase) – 要进行 vmap 的模块。

  • vmap_dim (int, 可选) – vmap 输入和输出维度。如果没有提供,则假定为 tensordict 的最后一个维度。

注意

由于 vmap 需要控制输入的批量大小,因此此模块不支持分派参数

示例

>>> lam = TensorDictModule(lambda x: x[0], in_keys=["x"], out_keys=["y"])
>>> sample_in = torch.ones((10,3,2))
>>> sample_in_td = TensorDict({"x":sample_in}, batch_size=[10])
>>> lam(sample_in)
>>> vm = VmapModule(lam, 0)
>>> vm(sample_in_td)
>>> assert (sample_in_td["x"][:, 0] == sample_in_td["y"]).all()
forward(tensordict)[source]

定义每次调用时执行的计算。

所有子类都应覆盖此方法。

注意

尽管正向传递的步骤需要在此函数中定义,但应该随后调用 Module 实例,而不是此方法,因为前者负责运行已注册的钩子,而后者则会静默忽略它们。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源