快捷方式

动作离散化

class torchrl.envs.transforms.ActionDiscretizer(num_intervals: int | torch.Tensor, action_key: NestedKey = 'action', out_action_key: NestedKey = None, sampling=None, categorical: bool = True)[源代码]

用于离散化连续动作空间的转换。

这种转换使得可以在具有连续动作空间的环境中使用为离散动作空间设计的算法,例如 DQN。

参数:
  • num_intervals (int torch.Tensor) – 动作空间中每个元素的离散值数量。如果提供单个整数,则所有动作项都使用相同数量的元素进行切片。如果提供张量,则它必须与动作空间具有相同数量的元素(即,num_intervals 张量的长度必须与动作空间的最后一个维度匹配)。

  • action_key (NestedKey可选) – 要使用的动作键。指向父环境的动作(浮点动作)。默认值为 "action"

  • out_action_key (NestedKey可选) – 应写入离散动作的键。如果提供 None,则默认为 action_key 的值。如果两个键不匹配,则将连续动作规格从 full_action_spec 环境属性移动到 full_state_spec 容器,因为只有离散动作应被采样才能执行动作。提供 out_action_key 可以确保浮点动作可用以记录。

  • sampling (ActionDiscretizer.SamplingStrategy可选) – ActionDiscretizer.SamplingStrategy IntEnum 对象的元素(MEDIANLOWHIGHRANDOM)。指示如何在提供的间隔中对连续动作进行采样。

  • categorical (bool可选) – 如果为 False,则使用独热编码。默认为 True

示例

>>> from torchrl.envs import GymEnv, check_env_specs
>>> import torch
>>> base_env = GymEnv("HalfCheetah-v4")
>>> num_intervals = torch.arange(5, 11)
>>> categorical = True
>>> sampling = ActionDiscretizer.SamplingStrategy.MEDIAN
>>> t = ActionDiscretizer(
...     num_intervals=num_intervals,
...     categorical=categorical,
...     sampling=sampling,
...     out_action_key="action_disc",
... )
>>> env = base_env.append_transform(t)
TransformedEnv(
    env=GymEnv(env=HalfCheetah-v4, batch_size=torch.Size([]), device=cpu),
    transform=ActionDiscretizer(
        num_intervals=tensor([ 5,  6,  7,  8,  9, 10]),
        action_key=action,
        out_action_key=action_disc,,
        sampling=0,
        categorical=True))
>>> check_env_specs(env)
>>> # Produce a rollout
>>> r = env.rollout(4)
>>> print(r)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([4, 6]), device=cpu, dtype=torch.float32, is_shared=False),
        action_disc: Tensor(shape=torch.Size([4, 6]), device=cpu, dtype=torch.int64, is_shared=False),
        done: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([4, 17]), device=cpu, dtype=torch.float64, is_shared=False),
                reward: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([4]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([4, 17]), device=cpu, dtype=torch.float64, is_shared=False),
        terminated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([4]),
    device=cpu,
    is_shared=False)
>>> assert r["action"].dtype == torch.float
>>> assert r["action_disc"].dtype == torch.int64
>>> assert (r["action"] < base_env.action_spec.high).all()
>>> assert (r["action"] > base_env.action_spec.low).all()
class SamplingStrategy(value)[源代码]

ActionDiscretizer 的采样策略。

transform_input_spec(input_spec)[源代码]

转换输入规格,以便生成的规格匹配转换映射。

参数:

input_spec (TensorSpec) – 转换前的规格

返回:

转换后的预期规格

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源