ModelBasedEnvBase

class torchrl.envs.ModelBasedEnvBase(*args, **kwargs)[source]

Base environment for state-of-the-art Model-Based Reinforcement Learning (MBRL) implementations.

Wrapper around the model of an MBRL algorithm. It is meant to give an environment framework to a world model (including, but not limited to, observation, reward, done-state and safety-constraint models) and to behave like a classical environment.

This is a base class for other environments; it should not be used directly.

Examples

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
>>> from torchrl.envs import ModelBasedEnvBase
>>> class MyMBEnv(ModelBasedEnvBase):
...     def __init__(self, world_model, device="cpu", dtype=None, batch_size=None):
...         super().__init__(world_model, device=device, dtype=dtype, batch_size=batch_size)
...         self.observation_spec = CompositeSpec(
...             hidden_observation=UnboundedContinuousTensorSpec((4,))
...         )
...         self.state_spec = CompositeSpec(
...             hidden_observation=UnboundedContinuousTensorSpec((4,)),
...         )
...         self.action_spec = UnboundedContinuousTensorSpec((1,))
...         self.reward_spec = UnboundedContinuousTensorSpec((1,))
...
...     def _reset(self, tensordict: TensorDict) -> TensorDict:
...         tensordict = TensorDict({},
...             batch_size=self.batch_size,
...             device=self.device,
...         )
...         tensordict = tensordict.update(self.state_spec.rand())
...         tensordict = tensordict.update(self.observation_spec.rand())
...         return tensordict
>>> # This environment is used as follows:
>>> import torch.nn as nn
>>> from torchrl.modules import MLP, WorldModelWrapper
>>> from tensordict.nn import TensorDictModule
>>> world_model = WorldModelWrapper(
...     TensorDictModule(
...         MLP(out_features=4, activation_class=nn.ReLU, activate_last_layer=True, depth=0),
...         in_keys=["hidden_observation", "action"],
...         out_keys=["hidden_observation"],
...     ),
...     TensorDictModule(
...         nn.Linear(4, 1),
...         in_keys=["hidden_observation"],
...         out_keys=["reward"],
...     ),
... )
>>> env = MyMBEnv(world_model)
>>> tensordict = env.rollout(max_steps=10)
>>> print(tensordict)
TensorDict(
    fields={
        action: Tensor(torch.Size([10, 1]), dtype=torch.float32),
        done: Tensor(torch.Size([10, 1]), dtype=torch.bool),
        hidden_observation: Tensor(torch.Size([10, 4]), dtype=torch.float32),
        next: LazyStackedTensorDict(
            fields={
                hidden_observation: Tensor(torch.Size([10, 4]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=cpu,
            is_shared=False),
        reward: Tensor(torch.Size([10, 1]), dtype=torch.float32)},
    batch_size=torch.Size([10]),
    device=cpu,
    is_shared=False)
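
The wrapped world model can also be stepped manually, like a classical environment. Below is a minimal sketch, reusing env from the example above; the key names follow the specs defined in MyMBEnv:

>>> td = env.reset()                       # samples an initial hidden_observation from state_spec
>>> td["action"] = env.action_spec.rand()  # draw a random action
>>> td = env.step(td)                      # the world model fills in the "next" entries
>>> td = env.rand_step(env.reset())        # equivalent, with the random action drawn automatically
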
Properties:
  • observation_spec (CompositeSpec): sampling spec of the observations (see the short sketch after this list);

  • action_spec (TensorSpec): sampling spec of the actions;

  • reward_spec (TensorSpec): sampling spec of the rewards;

  • input_spec (CompositeSpec): sampling spec of the inputs;

  • batch_size (torch.Size): batch size to be used by the environment. If not set, the environment accepts tensordicts of all batch sizes.

  • device (torch.device): device where the environment input and output are expected to live
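
As a short, hedged illustration of these specs (reusing the MyMBEnv instance built in the example above), each spec can be sampled or zeroed directly with the standard TensorSpec methods rand() and zero():

>>> env.observation_spec.rand()   # TensorDict with a random "hidden_observation" of shape (4,)
>>> env.action_spec.rand().shape  # torch.Size([1])
>>> env.reward_spec.zero()        # zero-filled reward tensor of shape (1,)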

Parameters:
  • world_model (nn.Module) – model that generates world states and their corresponding rewards;

  • params (List[torch.Tensor], optional) – list of parameters of the world model;

  • buffers (List[torch.Tensor], optional) – list of buffers of the world model;

  • device (torch.device, optional) – device where the environment input and output are expected to live

  • dtype (torch.dtype, optional) – dtype of the environment input and output

  • batch_size (torch.Size, optional) – number of environments contained in the instance

  • run_type_check (bool, optional) – whether to run type checks on the steps of the environment

Methods:

torchrl.envs.step(TensorDict -> TensorDict)

Step in the environment.

torchrl.envs.reset(TensorDict, optional -> TensorDict)

Reset the environment.

torchrl.envs.set_seed(int -> int)

Set the seed of the environment.

torchrl.envs.rand_step(TensorDict, optional -> TensorDict)

Take a random step given the action spec.

torchrl.envs.rollout(Callable, ... -> TensorDict)

Execute a rollout in the environment with the given policy (random steps if no policy is provided); a short sketch follows this list.
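
As a minimal sketch of rollout with an explicit policy (an illustrative assumption, not part of the class API: any TensorDictModule that reads the observation keys and writes an "action" entry works; here a plain linear layer, reusing env and nn from the example above):

>>> policy = TensorDictModule(
...     nn.Linear(4, 1),               # hypothetical toy policy for illustration only
...     in_keys=["hidden_observation"],
...     out_keys=["action"],
... )
>>> rollout = env.rollout(max_steps=10, policy=policy)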
