快捷方式

DistributionalQValueModule

class torchrl.modules.tensordict_module.DistributionalQValueModule(*args, **kwargs)[source]

用于 Q 值策略的分布式 Q 值钩子。

此模块将包含动作值 logits 的张量处理为其 argmax 分量(即由此产生的贪婪动作),遵循给定的动作空间(独热编码、二进制或分类)。它适用于 tensordict 和常规张量。

预期输入动作值是 log-softmax 操作的结果。

有关分布式 DQN 的更多详细信息,请参阅“强化学习的分布式视角”,https://arxiv.org/pdf/1707.06887.pdf

参数:
  • action_space (str, 可选) – 动作空间。必须是以下之一:"one-hot""mult-one-hot""binary""categorical"。此参数与 spec 互斥,因为 spec 决定了动作空间。

  • support (torch.Tensor) – 动作值的支撑。

  • action_value_key (strstr 元组, 可选) – 表示动作值的输入键。默认值为 "action_value"

  • action_mask_key (strstr 元组, 可选) – 表示动作掩码的输入键。默认值为 "None"(等效于无掩码)。

  • out_keys (str 列表str 元组, 可选) – 表示动作和动作值的输出键。默认值为 ["action", "action_value"]

  • var_nums (int, 可选) – 如果 action_space = "mult-one-hot",此值表示每个动作组件的基数。

  • spec (TensorSpec, 可选) – 如果提供,则为动作(和/或其他输出)的规范。这与 action_space 互斥,因为规范决定了动作空间。

  • safe (bool) – 如果为 True,则会根据输入规范检查输出值。由于探索策略或数值下溢/上溢问题,可能会发生域外采样。如果此值超出范围,则使用 TensorSpec.project 方法将其投影回所需空间。默认值为 False

示例

>>> from tensordict import TensorDict
>>> torch.manual_seed(0)
>>> action_space = "categorical"
>>> action_value_key = "my_action_value"
>>> support = torch.tensor([-1, 0.0, 1.0]) # the action value is between -1 and 1
>>> actor = DistributionalQValueModule(action_space, support=support, action_value_key=action_value_key)
>>> # This module works with both tensordict and regular tensors:
>>> value = torch.full((3, 4), -100)
>>> # the first bin (-1) of the first action is high: there's a high chance that it has a low value
>>> value[0, 0] = 0
>>> # the second bin (0) of the second action is high: there's a high chance that it has an intermediate value
>>> value[1, 1] = 0
>>> # the third bin (0) of the thid action is high: there's a high chance that it has an high value
>>> value[2, 2] = 0
>>> actor(my_action_value=value)
(tensor(2), tensor([[   0, -100, -100, -100],
        [-100,    0, -100, -100],
        [-100, -100,    0, -100]]))
>>> actor(value)
(tensor(2), tensor([[   0, -100, -100, -100],
        [-100,    0, -100, -100],
        [-100, -100,    0, -100]]))
>>> actor(TensorDict({action_value_key: value}, []))
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        my_action_value: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
forward(tensordict: Tensor) TensorDictBase[source]

定义每次调用时执行的计算。

所有子类都应该覆盖此方法。

注意

虽然前向传递的配方需要在此函数内定义,但之后应该调用 Module 实例,而不是此函数,因为前者负责运行注册的钩子,而后者会静默地忽略它们。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源