DTypeCastTransform¶
- class torchrl.envs.transforms.DTypeCastTransform(dtype_in: torch.dtype, dtype_out: torch.dtype, in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, in_keys_inv: Sequence[NestedKey] | None = None, out_keys_inv: Sequence[NestedKey] | None = None)[source]¶
将选定键的数据类型从一种转换为另一种。
根据在构造期间是否提供了
in_keys
或in_keys_inv
,类的行为将发生变化如果提供了键,则仅转换这些条目,将
dtype_in
条目转换为dtype_out
条目;如果未提供键且对象位于环境转换注册表中,则 dtype 设置为
dtype_in
的输入和输出规范将分别用作 in_keys_inv / in_keys。如果未提供键且对象在没有环境的情况下使用,则
forward
/inverse
传递将扫描输入 tensordict 中的所有dtype_in
值,并将它们映射到dtype_out
张量。对于大型数据结构,这可能会影响性能,因为此扫描并非免费的。将不会缓存要转换的键。请注意,在这种情况下,不能传递 out_keys(或 out_keys_inv),因为无法精确预测处理键的顺序。
- 参数:
dtype_in (torch.dtype) – 输入数据类型(来自环境)。
dtype_out (torch.dtype) – 输出数据类型(用于模型训练)。
in_keys (嵌套键序列, 可选) –
dtype_in
键列表,在公开给外部对象和函数之前将其转换为dtype_out
。out_keys (嵌套键序列, 可选) – 目标键列表。如果未提供,则默认为
in_keys
。in_keys_inv (嵌套键序列, 可选) –
dtype_out
键列表,在传递给包含的 base_env 或存储之前将其转换为dtype_in
。out_keys_inv (嵌套键序列, 可选) – 逆转换的目标键列表。如果未提供,则默认为
in_keys_inv
。
示例
>>> td = TensorDict( ... {'obs': torch.ones(1, dtype=torch.double), ... 'not_transformed': torch.ones(1, dtype=torch.double), ... }, []) >>> transform = DTypeCastTransform(torch.double, torch.float, in_keys=["obs"]) >>> _ = transform(td) >>> print(td.get("obs").dtype) torch.float32 >>> print(td.get("not_transformed").dtype) torch.float64
在“自动”模式下,所有 float64 条目都会被转换
示例
>>> td = TensorDict( ... {'obs': torch.ones(1, dtype=torch.double), ... 'not_transformed': torch.ones(1, dtype=torch.double), ... }, []) >>> transform = DTypeCastTransform(torch.double, torch.float) >>> _ = transform(td) >>> print(td.get("obs").dtype) torch.float32 >>> print(td.get("not_transformed").dtype) torch.float32
在没有指定转换键的情况下构建环境时,相同的行为也是规则
示例
>>> class MyEnv(EnvBase): ... def __init__(self): ... super().__init__() ... self.observation_spec = CompositeSpec(obs=UnboundedContinuousTensorSpec((), dtype=torch.float64)) ... self.action_spec = UnboundedContinuousTensorSpec((), dtype=torch.float64) ... self.reward_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.float64) ... self.done_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.bool) ... def _reset(self, data=None): ... return TensorDict({"done": torch.zeros((1,), dtype=torch.bool), **self.observation_spec.rand()}, []) ... def _step(self, data): ... assert data["action"].dtype == torch.float64 ... reward = self.reward_spec.rand() ... done = torch.zeros((1,), dtype=torch.bool) ... obs = self.observation_spec.rand() ... assert reward.dtype == torch.float64 ... assert obs["obs"].dtype == torch.float64 ... return obs.empty().set("next", obs.update({"reward": reward, "done": done})) ... def _set_seed(self, seed): ... pass >>> env = TransformedEnv(MyEnv(), DTypeCastTransform(torch.double, torch.float)) >>> assert env.action_spec.dtype == torch.float32 >>> assert env.observation_spec["obs"].dtype == torch.float32 >>> assert env.reward_spec.dtype == torch.float32, env.reward_spec.dtype >>> print(env.rollout(2)) TensorDict( fields={ action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict( fields={ done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False), obs: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([2]), device=cpu, is_shared=False), obs: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([2]), device=cpu, is_shared=False) >>> assert env.transform.in_keys == ["obs", "reward"] >>> assert env.transform.in_keys_inv == ["action"]
- transform_input_spec(input_spec: TensorSpec) TensorSpec [source]¶
转换输入规范,使结果规范与转换映射匹配。
- 参数:
input_spec (TensorSpec) – 转换前的规范
- 返回:
转换后的预期规范
- transform_observation_spec(observation_spec)[source]¶
转换观察规范,使结果规范与转换映射匹配。
- 参数:
observation_spec (TensorSpec) – 转换前的规范
- 返回:
转换后的预期规范
- transform_output_spec(output_spec: CompositeSpec) CompositeSpec [source]¶
转换输出规范,使其结果规范与转换映射匹配。
此方法通常应保持不变。更改应使用
transform_observation_spec()
、transform_reward_spec()
和transformfull_done_spec()
实现。:param output_spec: 变换前的规范 :type output_spec: TensorSpec- 返回:
转换后的预期规范