ActorCriticWrapper
- class torchrl.modules.tensordict_module.ActorCriticWrapper(*args, **kwargs)
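Actor-value operator without a common module.
This class wraps together an actor and a value model that do not share a common observation embedding network.
To facilitate the workflow, this class comes with the get_policy_operator() and get_value_operator() methods, each of which returns a standalone TDModule with the dedicated functionality.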
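- Parameters:
policy_operator (TensorDictModule) – a policy operator that reads the hidden variable and returns an action
value_operator (TensorDictModule) – a value operator that reads the hidden variable and returns a value
Examples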
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from torchrl.modules import (
...     ActorCriticWrapper,
...     ProbabilisticActor,
...     NormalParamExtractor,
...     TanhNormal,
...     ValueOperator,
... )
>>> # Policy network: maps the observation to the loc and scale of a distribution.
>>> action_module = TensorDictModule(
...     torch.nn.Sequential(torch.nn.Linear(4, 8), NormalParamExtractor()),
...     in_keys=["observation"],
...     out_keys=["loc", "scale"],
... )
>>> td_module_action = ProbabilisticActor(
...     module=action_module,
...     in_keys=["loc", "scale"],
...     distribution_class=TanhNormal,
...     return_log_prob=True,
... )
>>> # Value network: reads the observation directly (no shared embedding).
>>> module_value = torch.nn.Linear(4, 1)
>>> td_module_value = ValueOperator(
...     module=module_value,
...     in_keys=["observation"],
... )
>>> td_module = ActorCriticWrapper(td_module_action, td_module_value)
>>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,])
>>> td_clone = td_module(td.clone())
>>> print(td_clone)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        state_value: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> td_clone = td_module.get_policy_operator()(td.clone())
>>> print(td_clone)  # no value
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> td_clone = td_module.get_value_operator()(td.clone())
>>> print(td_clone)  # no action
TensorDict(
    fields={
        observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        state_value: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
- get_policy_head() → SafeSequential
Returns a standalone policy operator that maps an observation to an action.
- get_policy_operator() → SafeSequential
Returns a standalone policy operator that maps an observation to an action.
- get_value_head() → SafeSequential
Returns a standalone value network operator that maps an observation to a value estimate.
- get_value_operator() → SafeSequential
Returns a standalone value network operator that maps an observation to a value estimate.