DuelingCnnDQNet
- class torchrl.modules.DuelingCnnDQNet(out_features: int, out_features_value: int = 1, cnn_kwargs: dict | None = None, mlp_kwargs: dict | None = None, device: DEVICE_TYPING | None = None)
Dueling CNN Q-network.
Presented in https://arxiv.org/abs/1511.06581
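For context, the dueling architecture in the cited paper splits the network into a state-value stream and an advantage stream, then recombines them by subtracting the mean advantage. The following is a minimal sketch of that aggregation in plain PyTorch. The helper name `dueling_q_values` is hypothetical, and whether `DuelingCnnDQNet` applies exactly this variant internally is an assumption, not something this page states.

```python
import torch


def dueling_q_values(value: torch.Tensor, advantage: torch.Tensor) -> torch.Tensor:
    """Combine the two streams as Q(s, a) = V(s) + A(s, a) - mean_a A(s, a).

    Hypothetical helper for illustration; not part of torchrl.
    value:     shape (batch, 1)
    advantage: shape (batch, n_actions)
    """
    # Subtracting the per-state mean makes the advantages zero-mean, so the
    # value stream equals the average Q-value over actions.
    return value + advantage - advantage.mean(dim=-1, keepdim=True)


value = torch.tensor([[2.0]])
advantage = torch.tensor([[1.0, 0.0, -1.0, 4.0]])
q = dueling_q_values(value, advantage)
print(q)  # tensor([[2., 1., 0., 5.]])
```

With this aggregation the mean of the Q-values over actions equals the value estimate, which is why the value head only needs a single output feature by default.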
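- Parameters:
out_features (int) – number of output features of the advantage network. This is also the size of the last dimension of the module's output.
out_features_value (int) – number of output features of the value network. Defaults to 1.
cnn_kwargs (dict or list of dicts, optional) –
kwargs for the feature network. Defaults to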
>>> cnn_kwargs = {
...     'num_cells': [32, 64, 64],
...     'strides': [4, 2, 1],
...     'kernels': [8, 4, 3],
... }
mlp_kwargs (dict or list of dicts, optional) –
kwargs for the advantage and value networks. Defaults to
>>> mlp_kwargs = {
...     "depth": 1,
...     "activation_class": nn.ELU,
...     "num_cells": 512,
...     "bias_last_layer": True,
... }
device (torch.device, optional) – device to create the module on.
Examples
>>> import torch
>>> from torchrl.modules import DuelingCnnDQNet
>>> net = DuelingCnnDQNet(out_features=20)
>>> print(net)
DuelingCnnDQNet(
  (features): ConvNet(
    (0): LazyConv2d(0, 32, kernel_size=(8, 8), stride=(4, 4))
    (1): ELU(alpha=1.0)
    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
    (3): ELU(alpha=1.0)
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (5): ELU(alpha=1.0)
    (6): SquashDims()
  )
  (advantage): MLP(
    (0): LazyLinear(in_features=0, out_features=512, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=512, out_features=20, bias=True)
  )
  (value): MLP(
    (0): LazyLinear(in_features=0, out_features=512, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=512, out_features=1, bias=True)
  )
)
>>> x = torch.zeros(1, 3, 64, 64)
>>> y = net(x)
>>> print(y.shape)
torch.Size([1, 20])
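The `LazyConv2d` and `LazyLinear` layers print with 0 input features because their sizes are inferred on the first forward pass. As a rough check of what they resolve to, the sketch below applies the standard convolution output-size formula to the default `kernels` and `strides` for the 64×64 input used above. It assumes no padding and a dilation of 1 (the printed layers show neither), and that `SquashDims` flattens channels, height and width into one dimension. `conv_out_size` is a hypothetical helper, not a torchrl function.

```python
def conv_out_size(size: int, kernel: int, stride: int) -> int:
    """Output length of an unpadded, undilated convolution along one axis."""
    return (size - kernel) // stride + 1


# Defaults listed for cnn_kwargs above.
num_cells = [32, 64, 64]
kernels = [8, 4, 3]
strides = [4, 2, 1]

size = 64  # height and width of the example input
sizes = []
for kernel, stride in zip(kernels, strides):
    size = conv_out_size(size, kernel, stride)
    sizes.append(size)

# Features seen by the advantage and value heads after flattening
# the last feature map (channels * height * width).
flat_features = num_cells[-1] * size * size

print(sizes)          # [15, 6, 4]
print(flat_features)  # 1024
```

Under these assumptions, each `LazyLinear` layer would resolve to 1024 input features for a 64×64 input. A different input resolution changes this number, which is one reason lazy layers are convenient here.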