快捷方式

QValueModule

class torchrl.modules.tensordict_module.QValueModule(*args, **kwargs)[源]

用于 Q 值策略的 Q 值 TensorDictModule。

此模块根据给定的动作空间(one-hot、binary 或 categorical),将包含动作值的张量处理为其 argmax 分量(即结果的贪婪动作)。它适用于 tensordict 和普通张量。

参数:
  • action_space (str, 可选) – 动作空间。必须是 "one-hot", "mult-one-hot", "binary""categorical" 之一。此参数与 spec 互斥,因为 spec 决定了动作空间。

  • action_value_key (strtuple of str, 可选) – 表示动作值的输入键。默认为 "action_value"

  • action_mask_key (strtuple of str, 可选) – 表示动作掩码的输入键。默认为 "None"(等同于无掩码)。

  • out_keys (list of strtuple of str, 可选) – 表示动作、动作值和所选动作值的输出键。默认为 ["action", "action_value", "chosen_action_value"]

  • var_nums (int, 可选) – 如果 action_space = "mult-one-hot",此值表示每个动作分量的基数。

  • spec (TensorSpec, 可选) – 如果提供,则表示动作(以及/或其它输出)的规格。此参数与 action_space 互斥,因为 spec 决定了动作空间。

  • safe (bool) – 如果为 True,则检查输出值是否符合输入规格。由于探索策略或数值下溢/溢出问题,可能会发生超出范围的采样。如果此值超出范围,将使用 TensorSpec.project 方法将其投影回所需的空间。默认为 False

返回:

如果输入是单个张量,则返回一个包含所选动作、值和所选动作值的三个元素组。如果提供了 tensordict,则使用 out_keys 字段指示的键在其中更新这些条目。

示例

>>> from tensordict import TensorDict
>>> action_space = "categorical"
>>> action_value_key = "my_action_value"
>>> actor = QValueModule(action_space, action_value_key=action_value_key)
>>> # This module works with both tensordict and regular tensors:
>>> value = torch.zeros(4)
>>> value[-1] = 1
>>> actor(my_action_value=value)
(tensor(3), tensor([0., 0., 0., 1.]), tensor([1.]))
>>> actor(value)
(tensor(3), tensor([0., 0., 0., 1.]), tensor([1.]))
>>> actor(TensorDict({action_value_key: value}, []))
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        action_value: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
        chosen_action_value: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        my_action_value: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
forward(tensordict: Tensor = None) TensorDictBase[源]

定义每次调用时执行的计算。

应被所有子类覆盖。

注意

尽管前向传播 (forward pass) 的实现需要在该函数内定义,但后续应调用 Module 实例而不是此函数本身,因为前者负责运行已注册的钩子,而后者会静默忽略它们。


© 版权所有 2022, Meta。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

查阅 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源