QValueActor¶
- class torchrl.modules.tensordict_module.QValueActor(*args, **kwargs)[source]¶
Q 值 Actor 类。
此类在输入模块后附加一个
QValueModule
,以便使用动作值来选择动作。- 参数:
module (nn.Module) – 一个
torch.nn.Module
,用于将输入映射到输出参数空间。如果提供的类与tensordict.nn.TensorDictModuleBase
不兼容,它将被包装在一个tensordict.nn.TensorDictModule
中,其in_keys
由以下关键字参数指示。- 关键字参数:
in_keys (str 的可迭代对象, 可选) – 如果提供的类与
tensordict.nn.TensorDictModuleBase
不兼容,此键列表指示需要将哪些观察结果传递给包装的模块以获取动作值。默认为["observation"]
。spec (TensorSpec, 可选) – 仅限关键字的参数。输出张量的规格。如果模块输出多个输出张量,则 spec 描述第一个输出张量的空间。
safe (bool) – 仅限关键字的参数。如果
True
,则检查输出值是否与输入规范匹配。由于探索策略或数值下溢/上溢问题,可能会发生域外采样。如果此值超出范围,则使用TensorSpec.project
方法将其投影回所需的空间。默认为False
。action_space (str, 可选) – 动作空间。必须是
"one-hot"
、"mult-one-hot"
、"binary"
或"categorical"
之一。此参数与spec
互斥,因为spec
决定了动作空间。action_value_key (str 或 str 的元组, 可选) – 如果输入模块是
tensordict.nn.TensorDictModuleBase
实例,则它必须与它的一个输出键匹配。否则,此字符串表示输出 tensordict 中动作值条目的名称。action_mask_key (str 或 str 的元组, 可选) – 表示动作掩码的输入键。默认为
"None"
(相当于没有掩码)。
注意
out_keys
不能传递。如果模块是tensordict.nn.TensorDictModule
实例,则将相应更新 out_keys。对于常规torch.nn.Module
实例,将使用三元组["action", action_value_key, "chosen_action_value"]
。示例
>>> import torch >>> from tensordict import TensorDict >>> from torch import nn >>> from torchrl.data import OneHotDiscreteTensorSpec >>> from torchrl.modules.tensordict_module.actors import QValueActor >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5]) >>> # with a regular nn.Module >>> module = nn.Linear(4, 4) >>> action_spec = OneHotDiscreteTensorSpec(4) >>> qvalue_actor = QValueActor(module=module, spec=action_spec) >>> td = qvalue_actor(td) >>> print(td) TensorDict( fields={ action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False), action_value: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False), chosen_action_value: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False), observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([5]), device=None, is_shared=False) >>> # with a TensorDictModule >>> td = TensorDict({'obs': torch.randn(5, 4)}, [5]) >>> module = TensorDictModule(lambda x: x, in_keys=["obs"], out_keys=["action_value"]) >>> action_spec = OneHotDiscreteTensorSpec(4) >>> qvalue_actor = QValueActor(module=module, spec=action_spec) >>> td = qvalue_actor(td) >>> print(td) TensorDict( fields={ action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False), action_value: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False), chosen_action_value: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False), obs: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([5]), device=None, is_shared=False)