VmapModule¶
- class torchrl.modules.VmapModule(*args, **kwargs)[source]¶
一个 TensorDictModule 包装器,用于对输入进行 vmap。
它旨在与接受数据(其批量维度比提供的维度少一个)的模块一起使用。通过使用此包装器,可以隐藏一个批量维度并满足包装的模块。
- 参数:
module (TensorDictModuleBase) – 要进行 vmap 的模块。
vmap_dim (int, 可选) – vmap 输入和输出维度。如果没有提供,则假定为 tensordict 的最后一个维度。
注意
由于 vmap 需要控制输入的批量大小,因此此模块不支持分派参数
示例
>>> lam = TensorDictModule(lambda x: x[0], in_keys=["x"], out_keys=["y"]) >>> sample_in = torch.ones((10,3,2)) >>> sample_in_td = TensorDict({"x":sample_in}, batch_size=[10]) >>> lam(sample_in) >>> vm = VmapModule(lam, 0) >>> vm(sample_in_td) >>> assert (sample_in_td["x"][:, 0] == sample_in_td["y"]).all()