DistributionalQValueActor¶
- class torchrl.modules.tensordict_module.DistributionalQValueActor(*args, **kwargs)[source]¶
一个分布式的 DQN Actor 类。
此类在输入模块后附加一个
QValueModule
,以便使用动作值来选择动作。- 参数:
module (nn.Module) – 一个
torch.nn.Module
,用于将输入映射到输出参数空间。如果 module 不是torchrl.modules.DistributionalDQNnet
类型,DistributionalQValueActor
将确保对动作值张量沿维度-2
应用 log-softmax 操作。可以通过关闭make_log_softmax
关键字参数来禁用此操作。- 关键字参数:
in_keys (字符串的可迭代对象, 可选) – 要从输入 tensordict 中读取并传递给 module 的键。如果包含多个元素,则值将按照 in_keys 可迭代对象给出的顺序传递。默认为
["observation"]
。spec (TensorSpec, 可选) – 仅关键字参数。输出张量的规范。如果 module 输出多个输出张量,则 spec 表征第一个输出张量的空间。
safe (布尔值) – 仅关键字参数。如果为
True
,则输出值将根据输入规范进行检查。由于探索策略或数值下溢/溢出问题,可能会发生域外采样。如果此值超出范围,则使用TensorSpec.project
方法将其投影回期望空间。默认为False
。var_nums (整数, 可选) – 如果
action_space = "mult-one-hot"
,则此值表示每个动作组件的基数。support (torch.Tensor) – 动作值的支撑。
action_space (字符串, 可选) – 动作空间。必须是
"one-hot"
、"mult-one-hot"
、"binary"
或"categorical"
之一。此参数与spec
互斥,因为spec
条件化了 action_space。make_log_softmax (布尔值, 可选) – 如果为
True
且 module 不是torchrl.modules.DistributionalDQNnet
类型,则将沿动作值张量的维度 -2 应用 log-softmax 操作。action_value_key (字符串 或 字符串元组, 可选) – 如果输入 module 是
tensordict.nn.TensorDictModuleBase
实例,则它必须匹配其输出键之一。否则,此字符串表示输出 tensordict 中动作值条目的名称。action_mask_key (字符串 或 字符串元组, 可选) – 表示动作掩码的输入键。默认为
"None"
(等效于无掩码)。
示例
>>> import torch >>> from tensordict import TensorDict >>> from tensordict.nn import TensorDictModule, TensorDictSequential >>> from torch import nn >>> from torchrl.data import OneHot >>> from torchrl.modules import DistributionalQValueActor, MLP >>> td = TensorDict({'observation': torch.randn(5, 4)}, [5]) >>> nbins = 3 >>> module = MLP(out_features=(nbins, 4), depth=2) >>> # let us make sure that the output is a log-softmax >>> module = TensorDictSequential( ... TensorDictModule(module, ["observation"], ["action_value"]), ... TensorDictModule(lambda x: x.log_softmax(-2), ["action_value"], ["action_value"]), ... ) >>> action_spec = OneHot(4) >>> qvalue_actor = DistributionalQValueActor( ... module=module, ... spec=action_spec, ... support=torch.arange(nbins)) >>> td = qvalue_actor(td) >>> print(td) TensorDict( fields={ action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False), action_value: Tensor(shape=torch.Size([5, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False), observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([5]), device=None, is_shared=False)