快捷方式

QValueActor

class torchrl.modules.tensordict_module.QValueActor(*args, **kwargs)[source]

Q 值 Actor 类。

此类在输入模块后附加一个 QValueModule,以便使用动作值来选择动作。

参数:

module (nn.Module) – 一个 torch.nn.Module,用于将输入映射到输出参数空间。如果提供的类与 tensordict.nn.TensorDictModuleBase 不兼容,它将被包装在一个 tensordict.nn.TensorDictModule 中,其 in_keys 由以下关键字参数指示。

关键字参数:
  • in_keys (str 的可迭代对象, 可选) – 如果提供的类与 tensordict.nn.TensorDictModuleBase 不兼容,此键列表指示需要将哪些观察结果传递给包装的模块以获取动作值。默认为 ["observation"]

  • spec (TensorSpec, 可选) – 仅限关键字的参数。输出张量的规格。如果模块输出多个输出张量,则 spec 描述第一个输出张量的空间。

  • safe (bool) – 仅限关键字的参数。如果 True,则检查输出值是否与输入规范匹配。由于探索策略或数值下溢/上溢问题,可能会发生域外采样。如果此值超出范围,则使用 TensorSpec.project 方法将其投影回所需的空间。默认为 False

  • action_space (str, 可选) – 动作空间。必须是 "one-hot""mult-one-hot""binary""categorical" 之一。此参数与 spec 互斥,因为 spec 决定了动作空间。

  • action_value_key (strstr 的元组, 可选) – 如果输入模块是 tensordict.nn.TensorDictModuleBase 实例,则它必须与它的一个输出键匹配。否则,此字符串表示输出 tensordict 中动作值条目的名称。

  • action_mask_key (strstr 的元组, 可选) – 表示动作掩码的输入键。默认为 "None"(相当于没有掩码)。

注意

out_keys 不能传递。如果模块是 tensordict.nn.TensorDictModule 实例,则将相应更新 out_keys。对于常规 torch.nn.Module 实例,将使用三元组 ["action", action_value_key, "chosen_action_value"]

示例

>>> import torch
>>> from tensordict import TensorDict
>>> from torch import nn
>>> from torchrl.data import OneHotDiscreteTensorSpec
>>> from torchrl.modules.tensordict_module.actors import QValueActor
>>> td = TensorDict({'observation': torch.randn(5, 4)}, [5])
>>> # with a regular nn.Module
>>> module = nn.Linear(4, 4)
>>> action_spec = OneHotDiscreteTensorSpec(4)
>>> qvalue_actor = QValueActor(module=module, spec=action_spec)
>>> td = qvalue_actor(td)
>>> print(td)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False),
        action_value: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        chosen_action_value: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([5]),
    device=None,
    is_shared=False)
>>> # with a TensorDictModule
>>> td = TensorDict({'obs': torch.randn(5, 4)}, [5])
>>> module = TensorDictModule(lambda x: x, in_keys=["obs"], out_keys=["action_value"])
>>> action_spec = OneHotDiscreteTensorSpec(4)
>>> qvalue_actor = QValueActor(module=module, spec=action_spec)
>>> td = qvalue_actor(td)
>>> print(td)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False),
        action_value: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        chosen_action_value: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        obs: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([5]),
    device=None,
    is_shared=False)

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取初学者和高级开发人员的深度教程

查看教程

资源

查找开发资源并解答您的问题

查看资源