UnaryTransform¶
- class torchrl.envs.transforms.UnaryTransform(in_keys: Sequence[NestedKey], out_keys: Sequence[NestedKey], in_keys_inv: Sequence[NestedKey] | None = None, out_keys_inv: Sequence[NestedKey] | None = None, *, fn: Callable[[Any], Tensor | TensorDictBase], inv_fn: Callable[[Any], Any] | None = None, use_raw_nontensor: bool = False)[source]¶
对指定的输入应用一元操作。
- 参数:
in_keys (sequence of NestedKey) – 一元操作的输入键。
out_keys (sequence of NestedKey) – 一元操作的输出键。
in_keys_inv (sequence of NestedKey, optional) – 逆向调用期间一元操作的输入键。
out_keys_inv (sequence of NestedKey, optional) – 逆向调用期间一元操作的输出键。
- 关键字参数:
fn (Callable[[Any], Tensor | TensorDictBase]) – 用作一元操作的函数。如果它接受非张量输入,它也必须接受
None
。inv_fn (Callable[[Any], Any], optional) – 在逆向调用期间用作一元操作的函数。如果它接受非张量输入,它也必须接受
None
。可以省略,在这种情况下fn
将用于逆向映射。use_raw_nontensor (bool, optional) – 如果为
False
,则在调用fn
之前,从NonTensorData
/NonTensorStack
输入中提取数据。如果为True
,则直接将原始NonTensorData
/NonTensorStack
输入提供给fn
,它必须支持这些输入。默认为False
。
示例
>>> from torchrl.envs import GymEnv, UnaryTransform >>> env = GymEnv("Pendulum-v1") >>> env = env.append_transform( ... UnaryTransform( ... in_keys=["observation"], ... out_keys=["observation_trsf"], ... fn=lambda tensor: str(tensor.numpy().tobytes()))) >>> env.observation_spec Composite( observation: BoundedContinuous( shape=torch.Size([3]), space=ContinuousBox( low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True), high=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True)), device=cpu, dtype=torch.float32, domain=continuous), observation_trsf: NonTensor( shape=torch.Size([]), space=None, device=cpu, dtype=None, domain=None), device=None, shape=torch.Size([])) >>> env.rollout(3) TensorDict( fields={ action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False), done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict( fields={ done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False), observation_trsf: NonTensorStack( ["b'\\xbe\\xbc\\x7f?8\\x859=/\\x81\\xbe;'", "b'\\x..., batch_size=torch.Size([3]), device=None), reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False), terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False), observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False), observation_trsf: NonTensorStack( ["b'\\x9a\\xbd\\x7f?\\xb8T8=8.c>'", "b'\\xbe\\xbc\..., batch_size=torch.Size([3]), device=None), terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False) >>> env.check_env_specs() [torchrl][INFO] check_env_specs succeeded!
- transform_action_spec(action_spec: TensorSpec, test_input_spec: TensorSpec) TensorSpec [source]¶
转换 action spec 以便生成的 spec 与 transform 映射匹配。
- 参数:
action_spec (TensorSpec) – transform 前的 spec
- 返回:
transform 后的预期 spec
- transform_done_spec(done_spec: TensorSpec, test_output_spec: TensorSpec) TensorSpec [source]¶
转换 done spec 以便生成的 spec 与 transform 映射匹配。
- 参数:
done_spec (TensorSpec) – transform 前的 spec
- 返回:
transform 后的预期 spec
- transform_input_spec(input_spec: Composite) Composite [source]¶
转换 input spec 以便生成的 spec 与 transform 映射匹配。
- 参数:
input_spec (TensorSpec) – transform 前的 spec
- 返回:
transform 后的预期 spec
- transform_observation_spec(observation_spec: TensorSpec, test_output_spec: TensorSpec) TensorSpec [source]¶
转换 observation spec 以便生成的 spec 与 transform 映射匹配。
- 参数:
observation_spec (TensorSpec) – transform 前的 spec
- 返回:
transform 后的预期 spec
- transform_output_spec(output_spec: Composite) Composite [source]¶
转换 output spec 以便生成的 spec 与 transform 映射匹配。
此方法通常不应修改。更改应使用
transform_observation_spec()
、transform_reward_spec()
和transform_full_done_spec()
实现。 :param output_spec: transform 前的 spec :type output_spec: TensorSpec- 返回:
transform 后的预期 spec
- transform_reward_spec(reward_spec: TensorSpec, test_output_spec: TensorSpec) TensorSpec [source]¶
转换 reward spec 以便生成的 spec 与 transform 映射匹配。
- 参数:
reward_spec (TensorSpec) – transform 前的 spec
- 返回:
transform 后的预期 spec
- transform_state_spec(state_spec: TensorSpec, test_input_spec: TensorSpec) TensorSpec [source]¶
转换 state spec 以便生成的 spec 与 transform 映射匹配。
- 参数:
state_spec (TensorSpec) – transform 前的 spec
- 返回:
transform 后的预期 spec