快捷方式

ActorValueOperator

class torchrl.modules.tensordict_module.ActorValueOperator(*args, **kwargs)[源代码]

Actor-value 算子。

此类将一个演员和一个价值模型包装在一起,它们共享一个共同的观察嵌入网络。

../../_images/aafig-2229301c32d3e27b4cec9be5284f11e681ba0607.svg

注意

对于返回动作和质量值 \(Q(s, a)\) 的类似类,请参阅 ActorCriticOperator。对于没有共同嵌入的版本,请参考 ActorCriticWrapper

为了简化工作流程,此类提供了一个 get_policy_operator() 和 get_value_operator() 方法,这两个方法都将返回一个具有专用功能的独立 TDModule。

参数::
  • common_operator (TensorDictModule) – 一个读取观察值并生成隐藏变量的通用算子

  • policy_operator (TensorDictModule) – 一个读取隐藏变量并返回动作的策略算子

  • value_operator (TensorDictModule) – 一个价值算子,读取隐藏变量并返回一个价值

示例

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.modules import ProbabilisticActor, SafeModule
>>> from torchrl.modules import ValueOperator, TanhNormal, ActorValueOperator, NormalParamExtractor
>>> module_hidden = torch.nn.Linear(4, 4)
>>> td_module_hidden = SafeModule(
...    module=module_hidden,
...    in_keys=["observation"],
...    out_keys=["hidden"],
...    )
>>> module_action = TensorDictModule(
...     nn.Sequential(torch.nn.Linear(4, 8), NormalParamExtractor()),
...     in_keys=["hidden"],
...     out_keys=["loc", "scale"],
...     )
>>> td_module_action = ProbabilisticActor(
...    module=module_action,
...    in_keys=["loc", "scale"],
...    out_keys=["action"],
...    distribution_class=TanhNormal,
...    return_log_prob=True,
...    )
>>> module_value = torch.nn.Linear(4, 1)
>>> td_module_value = ValueOperator(
...    module=module_value,
...    in_keys=["hidden"],
...    )
>>> td_module = ActorValueOperator(td_module_hidden, td_module_action, td_module_value)
>>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,])
>>> td_clone = td_module(td.clone())
>>> print(td_clone)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        hidden: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        state_value: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> td_clone = td_module.get_policy_operator()(td.clone())
>>> print(td_clone)  # no value
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        hidden: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> td_clone = td_module.get_value_operator()(td.clone())
>>> print(td_clone)  # no action
TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        state_value: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
get_policy_head() SafeSequential[源代码]

返回策略头部。

get_policy_operator() SafeSequential[源代码]

返回一个将观察值映射到动作的独立策略算子。

get_value_head() SafeSequential[源代码]

返回价值头部。

get_value_operator() SafeSequential[源代码]

返回一个将观察值映射到价值估计的独立价值网络算子。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得解答

查看资源