QValueModule¶
- class torchrl.modules.tensordict_module.QValueModule(*args, **kwargs)[源代码]¶
用于 Q 值策略的 Q 值 TensorDictModule。
此模块将包含动作值的张量处理为其 argmax 组件(即由此产生的贪婪动作),遵循给定的动作空间(独热、二进制或分类)。它可以使用张量字典和普通张量。
- 参数:
action_space (str, 可选) – 动作空间。必须是
"one-hot"
、"mult-one-hot"
、"binary"
或"categorical"
之一。此参数与spec
互斥,因为spec
决定了 action_space。action_value_key (str 或 str 元组, 可选) – 表示动作值的输入键。默认值为
"action_value"
。action_mask_key (str 或 str 元组, 可选) – 表示动作掩码的输入键。默认值为
"None"
(相当于没有掩码)。out_keys (str 列表 或 str 元组, 可选) – 表示动作、动作值和选定动作值的输出键。默认值为
["action", "action_value", "chosen_action_value"]
。var_nums (int, 可选) – 如果
action_space = "mult-one-hot"
,则此值表示每个动作组件的基数。spec (TensorSpec, 可选) – 如果提供,则提供动作(和/或其他输出)的规格。这与
action_space
互斥,因为规格决定了动作空间。safe (bool) – 如果为
True
,则会根据输入规格检查输出的值。由于探索策略或数值下溢/上溢问题,可能会出现域外采样。如果此值超出范围,则使用TensorSpec.project
方法将其投影回所需的空間。默认值为False
。
- 返回:
如果输入是单个张量,则返回包含选定动作、值和选定动作的值的三元组。如果提供了张量字典,则使用
out_keys
字段中指示的键更新该字典。
示例
>>> from tensordict import TensorDict >>> action_space = "categorical" >>> action_value_key = "my_action_value" >>> actor = QValueModule(action_space, action_value_key=action_value_key) >>> # This module works with both tensordict and regular tensors: >>> value = torch.zeros(4) >>> value[-1] = 1 >>> actor(my_action_value=value) (tensor(3), tensor([0., 0., 0., 1.]), tensor([1.])) >>> actor(value) (tensor(3), tensor([0., 0., 0., 1.]), tensor([1.])) >>> actor(TensorDict({action_value_key: value}, [])) TensorDict( fields={ action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), action_value: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False), chosen_action_value: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), my_action_value: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)