DistributionalDQNnet¶
- class torchrl.modules.DistributionalDQNnet(*args, **kwargs)[源代码]¶
分布式深度 Q 网络 softmax 层。
此层应用于预测动作值的常规模型和作用于 logits 值的分布之间。
- 参数:
in_keys (str 列表或 str 元组) – 对数 softmax 操作的输入键。默认为
["action_value"]
。out_keys (str 列表或 str 元组) – 对数 softmax 操作的输出键。默认为
["action_value"]
。
示例
>>> import torch >>> from tensordict import TensorDict >>> net = DistributionalDQNnet() >>> td = TensorDict({"action_value": torch.randn(10, 5)}, batch_size=[10]) >>> net(td) TensorDict( fields={ action_value: Tensor(shape=torch.Size([10, 5]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([10]), device=None, is_shared=False)