快捷方式

排列转换

class torchrl.envs.transforms.PermuteTransform(dims, in_keys=None, out_keys=None, in_keys_inv=None, out_keys_inv=None)[源代码]

排列转换。

沿所需维度对输入张量进行排列。排列必须沿特征维度提供(而不是批次维度)。

参数::
  • dims (int 列表) – 维度的排列顺序。必须是维度 [-(len(dims)), ..., -1] 的重新排序。

  • in_keys (嵌套键列表) – 输入项(读取)。

  • out_keys (嵌套键列表) – 输入项(写入)。如果未提供,则默认为 in_keys

  • in_keys_inv (嵌套键列表) – 在 inv() 调用期间的输入项(读取)。

  • out_keys_inv (嵌套键列表) – 在 inv() 调用期间的输入项(写入)。如果未提供,则默认为 in_keys_in

示例

>>> from torchrl.envs.libs.gym import GymEnv
>>> base_env = GymEnv("ALE/Pong-v5")
>>> base_env.rollout(2)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([2, 6]), device=cpu, dtype=torch.int64, is_shared=False),
        done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                pixels: Tensor(shape=torch.Size([2, 210, 160, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([2]),
            device=cpu,
            is_shared=False),
        pixels: Tensor(shape=torch.Size([2, 210, 160, 3]), device=cpu, dtype=torch.uint8, is_shared=False)},
    batch_size=torch.Size([2]),
    device=cpu,
    is_shared=False)
>>> env = TransformedEnv(base_env, PermuteTransform((-1, -3, -2), in_keys=["pixels"]))
>>> env.rollout(2)  # channels are at the end
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([2, 6]), device=cpu, dtype=torch.int64, is_shared=False),
        done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                pixels: Tensor(shape=torch.Size([2, 3, 210, 160]), device=cpu, dtype=torch.uint8, is_shared=False),
                reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([2]),
            device=cpu,
            is_shared=False),
        pixels: Tensor(shape=torch.Size([2, 3, 210, 160]), device=cpu, dtype=torch.uint8, is_shared=False)},
    batch_size=torch.Size([2]),
    device=cpu,
    is_shared=False)
transform_input_spec(input_spec)[源代码]

转换输入规范,使结果规范与转换映射匹配。

参数::

input_spec (TensorSpec) – 转换前的规范

返回值::

转换后的预期规范

transform_observation_spec(observation_spec: TensorSpec) TensorSpec[源代码]

转换观察规范,使结果规范与转换映射匹配。

参数::

observation_spec (TensorSpec) – 转换前的规范

返回值::

转换后的预期规范

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源