DeviceCastTransform¶
- class torchrl.envs.transforms.DeviceCastTransform(device, orig_device=None, *, in_keys=None, out_keys=None, in_keys_inv=None, out_keys_inv=None)[源代码]¶
将数据从一个设备移动到另一个设备。
- 参数:
device (torch.device 或 等效) – 目标设备。
orig_device (torch.device 或 等效) – 源设备。如果未指定且存在父环境,则从父环境中检索。在所有其他情况下,它保持未指定。
示例
>>> td = TensorDict( ... {'obs': torch.ones(1, dtype=torch.double), ... }, [], device="cpu:0") >>> transform = DeviceCastTransform(device=torch.device("cpu:2")) >>> td = transform(td) >>> print(td.device) cpu:2
- transform_done_spec(full_done_spec: CompositeSpec) CompositeSpec [源代码]¶
转换 done spec,以使生成的 spec 匹配转换映射。
- 参数:
done_spec (TensorSpec) – 转换前的 spec
- 返回值:
转换后的预期 spec
- transform_input_spec(input_spec: CompositeSpec) CompositeSpec [源代码]¶
转换输入 spec,以使生成的 spec 匹配转换映射。
- 参数:
input_spec (TensorSpec) – 转换前的 spec
- 返回值:
转换后的预期 spec
- transform_observation_spec(observation_spec: CompositeSpec) CompositeSpec [源代码]¶
转换观察 spec,以使生成的 spec 匹配转换映射。
- 参数:
observation_spec (TensorSpec) – 转换前的 spec
- 返回值:
转换后的预期 spec
- transform_output_spec(output_spec: CompositeSpec) CompositeSpec [源代码]¶
转换输出 spec,以使生成的 spec 匹配转换映射。
此方法通常应保持不变。应使用
transform_observation_spec()
、transform_reward_spec()
和transformfull_done_spec()
来实现更改。 :param output_spec: 转换前的 spec :type output_spec: TensorSpec- 返回值:
转换后的预期 spec
- transform_reward_spec(full_reward_spec: CompositeSpec) CompositeSpec [source]¶
转换奖励规范,使结果规范与转换映射匹配。
- 参数:
reward_spec (TensorSpec) – 转换前的规范
- 返回值:
转换后的预期 spec