DTActor¶
- class torchrl.modules.DTActor(state_dim: int, action_dim: int, transformer_config: Dict | DecisionTransformer.DTConfig = None, device: DEVICE_TYPING | None = None)[source]¶
决策 Transformer Actor 类。
Decision Transformer 的 Actor 类,用于输出确定性动作,如 “Decision Transformer” <https://arxiv.org/abs/2202.05607.pdf> 中所述。返回确定性动作。
- 参数:
state_dim (int) – 状态维度。
action_dim (int) – 动作维度。
transformer_config (Dict 或
DecisionTransformer.DTConfig
,可选) – GPT2 变换器的配置。默认为default_config()
。device (torch.device,可选) – 要使用的设备。默认为 None。
示例
>>> model = DTActor(state_dim=4, action_dim=2, ... transformer_config=DTActor.default_config()) >>> observation = torch.randn(32, 10, 4) >>> action = torch.randn(32, 10, 2) >>> return_to_go = torch.randn(32, 10, 1) >>> output = model(observation, action, return_to_go) >>> output.shape torch.Size([32, 10, 2])