快捷方式

VmapModule

class torchrl.modules.VmapModule(*args, **kwargs)[source]

TensorDictModule 包装器,用于在输入上进行 vmap 操作。

它旨在与接受批次维度比提供的数据少一个的模块一起使用。通过使用此包装器,可以隐藏一个批次维度并满足包装的模块。

参数:
  • module (TensorDictModuleBase) – 要进行 vmap 操作的模块。

  • vmap_dim (int, optional) – vmap 输入和输出维度。如果未提供,则假定为 tensordict 的最后一个维度。

注意

由于 vmap 需要控制输入的批次大小,因此此模块不支持调度的参数

示例

>>> lam = TensorDictModule(lambda x: x[0], in_keys=["x"], out_keys=["y"])
>>> sample_in = torch.ones((10,3,2))
>>> sample_in_td = TensorDict({"x":sample_in}, batch_size=[10])
>>> lam(sample_in)
>>> vm = VmapModule(lam, 0)
>>> vm(sample_in_td)
>>> assert (sample_in_td["x"][:, 0] == sample_in_td["y"]).all()
forward(tensordict)[source]

定义每次调用时执行的计算。

应由所有子类重写。

注意

虽然前向传递的配方需要在该函数中定义,但之后应该调用 Module 实例,而不是调用此函数,因为前者负责运行注册的钩子,而后者会静默地忽略它们。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源