快捷方式

DTypeCastTransform

class torchrl.envs.transforms.DTypeCastTransform(dtype_in: torch.dtype, dtype_out: torch.dtype, in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, in_keys_inv: Sequence[NestedKey] | None = None, out_keys_inv: Sequence[NestedKey] | None = None)[source]

将选定键的数据类型从一种转换为另一种。

根据在构造期间是否提供了 in_keysin_keys_inv,类的行为将发生变化

  • 如果提供了键,则仅转换这些条目,将 dtype_in 条目转换为 dtype_out 条目;

  • 如果未提供键且对象位于环境转换注册表中,则 dtype 设置为 dtype_in 的输入和输出规范将分别用作 in_keys_inv / in_keys。

  • 如果未提供键且对象在没有环境的情况下使用,则 forward / inverse 传递将扫描输入 tensordict 中的所有 dtype_in 值,并将它们映射到 dtype_out 张量。对于大型数据结构,这可能会影响性能,因为此扫描并非免费的。将不会缓存要转换的键。请注意,在这种情况下,不能传递 out_keys(或 out_keys_inv),因为无法精确预测处理键的顺序。

参数:
  • dtype_in (torch.dtype) – 输入数据类型(来自环境)。

  • dtype_out (torch.dtype) – 输出数据类型(用于模型训练)。

  • in_keys (嵌套键序列, 可选) – dtype_in 键列表,在公开给外部对象和函数之前将其转换为 dtype_out

  • out_keys (嵌套键序列, 可选) – 目标键列表。如果未提供,则默认为 in_keys

  • in_keys_inv (嵌套键序列, 可选) – dtype_out 键列表,在传递给包含的 base_env 或存储之前将其转换为 dtype_in

  • out_keys_inv (嵌套键序列, 可选) – 逆转换的目标键列表。如果未提供,则默认为 in_keys_inv

示例

>>> td = TensorDict(
...     {'obs': torch.ones(1, dtype=torch.double),
...     'not_transformed': torch.ones(1, dtype=torch.double),
... }, [])
>>> transform = DTypeCastTransform(torch.double, torch.float, in_keys=["obs"])
>>> _ = transform(td)
>>> print(td.get("obs").dtype)
torch.float32
>>> print(td.get("not_transformed").dtype)
torch.float64

在“自动”模式下,所有 float64 条目都会被转换

示例

>>> td = TensorDict(
...     {'obs': torch.ones(1, dtype=torch.double),
...     'not_transformed': torch.ones(1, dtype=torch.double),
... }, [])
>>> transform = DTypeCastTransform(torch.double, torch.float)
>>> _ = transform(td)
>>> print(td.get("obs").dtype)
torch.float32
>>> print(td.get("not_transformed").dtype)
torch.float32

在没有指定转换键的情况下构建环境时,相同的行为也是规则

示例

>>> class MyEnv(EnvBase):
...     def __init__(self):
...         super().__init__()
...         self.observation_spec = CompositeSpec(obs=UnboundedContinuousTensorSpec((), dtype=torch.float64))
...         self.action_spec = UnboundedContinuousTensorSpec((), dtype=torch.float64)
...         self.reward_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.float64)
...         self.done_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.bool)
...     def _reset(self, data=None):
...         return TensorDict({"done": torch.zeros((1,), dtype=torch.bool), **self.observation_spec.rand()}, [])
...     def _step(self, data):
...         assert data["action"].dtype == torch.float64
...         reward = self.reward_spec.rand()
...         done = torch.zeros((1,), dtype=torch.bool)
...         obs = self.observation_spec.rand()
...         assert reward.dtype == torch.float64
...         assert obs["obs"].dtype == torch.float64
...         return obs.empty().set("next", obs.update({"reward": reward, "done": done}))
...     def _set_seed(self, seed):
...         pass
>>> env = TransformedEnv(MyEnv(), DTypeCastTransform(torch.double, torch.float))
>>> assert env.action_spec.dtype == torch.float32
>>> assert env.observation_spec["obs"].dtype == torch.float32
>>> assert env.reward_spec.dtype == torch.float32, env.reward_spec.dtype
>>> print(env.rollout(2))
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                obs: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([2]),
            device=cpu,
            is_shared=False),
        obs: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2]),
    device=cpu,
    is_shared=False)
>>> assert env.transform.in_keys == ["obs", "reward"]
>>> assert env.transform.in_keys_inv == ["action"]
forward(tensordict: TensorDictBase) TensorDictBase[source]

读取输入 tensordict,并对选定的键应用转换。

transform_input_spec(input_spec: TensorSpec) TensorSpec[source]

转换输入规范,使结果规范与转换映射匹配。

参数:

input_spec (TensorSpec) – 转换前的规范

返回:

转换后的预期规范

transform_observation_spec(observation_spec)[source]

转换观察规范,使结果规范与转换映射匹配。

参数:

observation_spec (TensorSpec) – 转换前的规范

返回:

转换后的预期规范

transform_output_spec(output_spec: CompositeSpec) CompositeSpec[source]

转换输出规范,使其结果规范与转换映射匹配。

此方法通常应保持不变。更改应使用transform_observation_spec()transform_reward_spec()transformfull_done_spec()实现。:param output_spec: 变换前的规范 :type output_spec: TensorSpec

返回:

转换后的预期规范

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取适合初学者和高级开发人员的深度教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源