VmapModule¶
- class torchrl.modules.VmapModule(*args, **kwargs)[source]¶
TensorDictModule 包装器,用于在输入上进行 vmap 操作。
它旨在与接受批次维度比提供的数据少一个的模块一起使用。通过使用此包装器,可以隐藏一个批次维度并满足包装的模块。
- 参数:
module (TensorDictModuleBase) – 要进行 vmap 操作的模块。
vmap_dim (int, optional) – vmap 输入和输出维度。如果未提供,则假定为 tensordict 的最后一个维度。
注意
由于 vmap 需要控制输入的批次大小,因此此模块不支持调度的参数
示例
>>> lam = TensorDictModule(lambda x: x[0], in_keys=["x"], out_keys=["y"]) >>> sample_in = torch.ones((10,3,2)) >>> sample_in_td = TensorDict({"x":sample_in}, batch_size=[10]) >>> lam(sample_in) >>> vm = VmapModule(lam, 0) >>> vm(sample_in_td) >>> assert (sample_in_td["x"][:, 0] == sample_in_td["y"]).all()