DistributionalDQNnet¶
- class torchrl.modules.DistributionalDQNnet(*args, **kwargs)[source]¶
分布深度 Q 网络 Softmax 层。
此层应位于预测动作值的常规模型与作用于 logits 值的分布之间使用。
- 参数:
in_keys (str 列表或 str 元组) – log-softmax 操作的输入键。默认为
["action_value"]
。out_keys (str 列表或 str 元组) – log-softmax 操作的输出键。默认为
["action_value"]
。
示例
>>> import torch >>> from tensordict import TensorDict >>> net = DistributionalDQNnet() >>> td = TensorDict({"action_value": torch.randn(10, 5)}, batch_size=[10]) >>> net(td) TensorDict( fields={ action_value: Tensor(shape=torch.Size([10, 5]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([10]), device=None, is_shared=False)