快捷方式

DistributionalQValueActor

class torchrl.modules.tensordict_module.DistributionalQValueActor(*args, **kwargs)[source]

一个分布式的 DQN Actor 类。

此类在输入模块后附加一个 QValueModule,以便使用动作值来选择动作。

参数:

module (nn.Module) – 一个 torch.nn.Module,用于将输入映射到输出参数空间。如果 module 不是 torchrl.modules.DistributionalDQNnet 类型,DistributionalQValueActor 将确保对动作值张量沿维度 -2 应用 log-softmax 操作。可以通过关闭 make_log_softmax 关键字参数来禁用此操作。

关键字参数:
  • in_keys (字符串的可迭代对象, 可选) – 要从输入 tensordict 中读取并传递给 module 的键。如果包含多个元素,则值将按照 in_keys 可迭代对象给出的顺序传递。默认为 ["observation"]

  • spec (TensorSpec, 可选) – 仅关键字参数。输出张量的规范。如果 module 输出多个输出张量,则 spec 表征第一个输出张量的空间。

  • safe (布尔值) – 仅关键字参数。如果为 True,则输出值将根据输入规范进行检查。由于探索策略或数值下溢/溢出问题,可能会发生域外采样。如果此值超出范围,则使用 TensorSpec.project 方法将其投影回期望空间。默认为 False

  • var_nums (整数, 可选) – 如果 action_space = "mult-one-hot",则此值表示每个动作组件的基数。

  • support (torch.Tensor) – 动作值的支撑。

  • action_space (字符串, 可选) – 动作空间。必须是 "one-hot""mult-one-hot""binary""categorical" 之一。此参数与 spec 互斥,因为 spec 条件化了 action_space。

  • make_log_softmax (布尔值, 可选) – 如果为 True 且 module 不是 torchrl.modules.DistributionalDQNnet 类型,则将沿动作值张量的维度 -2 应用 log-softmax 操作。

  • action_value_key (字符串字符串元组, 可选) – 如果输入 module 是 tensordict.nn.TensorDictModuleBase 实例,则它必须匹配其输出键之一。否则,此字符串表示输出 tensordict 中动作值条目的名称。

  • action_mask_key (字符串字符串元组, 可选) – 表示动作掩码的输入键。默认为 "None"(等效于无掩码)。

示例

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> from torch import nn
>>> from torchrl.data import OneHot
>>> from torchrl.modules import DistributionalQValueActor, MLP
>>> td = TensorDict({'observation': torch.randn(5, 4)}, [5])
>>> nbins = 3
>>> module = MLP(out_features=(nbins, 4), depth=2)
>>> # let us make sure that the output is a log-softmax
>>> module = TensorDictSequential(
...     TensorDictModule(module, ["observation"], ["action_value"]),
...     TensorDictModule(lambda x: x.log_softmax(-2), ["action_value"], ["action_value"]),
... )
>>> action_spec = OneHot(4)
>>> qvalue_actor = DistributionalQValueActor(
...     module=module,
...     spec=action_spec,
...     support=torch.arange(nbins))
>>> td = qvalue_actor(td)
>>> print(td)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False),
        action_value: Tensor(shape=torch.Size([5, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([5]),
    device=None,
    is_shared=False)

© 版权所有 2022, Meta。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得解答

查看资源