快捷方式

UnaryTransform

class torchrl.envs.transforms.UnaryTransform(in_keys: Sequence[NestedKey], out_keys: Sequence[NestedKey], in_keys_inv: Sequence[NestedKey] | None = None, out_keys_inv: Sequence[NestedKey] | None = None, *, fn: Callable[[Any], Tensor | TensorDictBase], inv_fn: Callable[[Any], Any] | None = None, use_raw_nontensor: bool = False)[source]

对指定的输入应用一元操作。

参数:
  • in_keys (sequence of NestedKey) – 一元操作的输入键。

  • out_keys (sequence of NestedKey) – 一元操作的输出键。

  • in_keys_inv (sequence of NestedKey, optional) – 逆向调用期间一元操作的输入键。

  • out_keys_inv (sequence of NestedKey, optional) – 逆向调用期间一元操作的输出键。

关键字参数:
  • fn (Callable[[Any], Tensor | TensorDictBase]) – 用作一元操作的函数。如果它接受非张量输入,它也必须接受 None

  • inv_fn (Callable[[Any], Any], optional) – 在逆向调用期间用作一元操作的函数。如果它接受非张量输入,它也必须接受 None。可以省略,在这种情况下 fn 将用于逆向映射。

  • use_raw_nontensor (bool, optional) – 如果为 False,则在调用 fn 之前,从 NonTensorData/NonTensorStack 输入中提取数据。如果为 True,则直接将原始 NonTensorData/NonTensorStack 输入提供给 fn,它必须支持这些输入。默认为 False

示例

>>> from torchrl.envs import GymEnv, UnaryTransform
>>> env = GymEnv("Pendulum-v1")
>>> env = env.append_transform(
...     UnaryTransform(
...         in_keys=["observation"],
...         out_keys=["observation_trsf"],
...             fn=lambda tensor: str(tensor.numpy().tobytes())))
>>> env.observation_spec
Composite(
    observation: BoundedContinuous(
        shape=torch.Size([3]),
        space=ContinuousBox(
            low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True),
            high=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True)),
        device=cpu,
        dtype=torch.float32,
        domain=continuous),
    observation_trsf: NonTensor(
        shape=torch.Size([]),
        space=None,
        device=cpu,
        dtype=None,
        domain=None),
    device=None,
    shape=torch.Size([]))
>>> env.rollout(3)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                observation_trsf: NonTensorStack(
                    ["b'\\xbe\\xbc\\x7f?8\\x859=/\\x81\\xbe;'", "b'\\x...,
                    batch_size=torch.Size([3]),
                    device=None),
                reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        observation_trsf: NonTensorStack(
            ["b'\\x9a\\xbd\\x7f?\\xb8T8=8.c>'", "b'\\xbe\\xbc\...,
            batch_size=torch.Size([3]),
            device=None),
        terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> env.check_env_specs()
[torchrl][INFO] check_env_specs succeeded!
transform_action_spec(action_spec: TensorSpec, test_input_spec: TensorSpec) TensorSpec[source]

转换 action spec 以便生成的 spec 与 transform 映射匹配。

参数:

action_spec (TensorSpec) – transform 前的 spec

返回:

transform 后的预期 spec

transform_done_spec(done_spec: TensorSpec, test_output_spec: TensorSpec) TensorSpec[source]

转换 done spec 以便生成的 spec 与 transform 映射匹配。

参数:

done_spec (TensorSpec) – transform 前的 spec

返回:

transform 后的预期 spec

transform_input_spec(input_spec: Composite) Composite[source]

转换 input spec 以便生成的 spec 与 transform 映射匹配。

参数:

input_spec (TensorSpec) – transform 前的 spec

返回:

transform 后的预期 spec

transform_observation_spec(observation_spec: TensorSpec, test_output_spec: TensorSpec) TensorSpec[source]

转换 observation spec 以便生成的 spec 与 transform 映射匹配。

参数:

observation_spec (TensorSpec) – transform 前的 spec

返回:

transform 后的预期 spec

transform_output_spec(output_spec: Composite) Composite[source]

转换 output spec 以便生成的 spec 与 transform 映射匹配。

此方法通常不应修改。更改应使用 transform_observation_spec()transform_reward_spec()transform_full_done_spec() 实现。 :param output_spec: transform 前的 spec :type output_spec: TensorSpec

返回:

transform 后的预期 spec

transform_reward_spec(reward_spec: TensorSpec, test_output_spec: TensorSpec) TensorSpec[source]

转换 reward spec 以便生成的 spec 与 transform 映射匹配。

参数:

reward_spec (TensorSpec) – transform 前的 spec

返回:

transform 后的预期 spec

transform_state_spec(state_spec: TensorSpec, test_input_spec: TensorSpec) TensorSpec[source]

转换 state spec 以便生成的 spec 与 transform 映射匹配。

参数:

state_spec (TensorSpec) – transform 前的 spec

返回:

transform 后的预期 spec

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获取问题解答

查看资源