ModelBasedEnvBase
- torchrl.envs.ModelBasedEnvBase(*args, **kwargs)[source]
Base environment for model-based reinforcement learning (MBRL) implementations.
Wrapper around the model of an MBRL algorithm. It is meant to give a world-model-based framework for environments (including, but not limited to, observation, reward, done-state, and safety-constraint models) and to behave like a classical environment.
This is a base class for other environments and should not be used directly.
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
>>> class MyMBEnv(ModelBasedEnvBase):
...     def __init__(self, world_model, device="cpu", dtype=None, batch_size=None):
...         super().__init__(world_model, device=device, dtype=dtype, batch_size=batch_size)
...         self.observation_spec = CompositeSpec(
...             hidden_observation=UnboundedContinuousTensorSpec((4,))
...         )
...         self.state_spec = CompositeSpec(
...             hidden_observation=UnboundedContinuousTensorSpec((4,)),
...         )
...         self.action_spec = UnboundedContinuousTensorSpec((1,))
...         self.reward_spec = UnboundedContinuousTensorSpec((1,))
...
...     def _reset(self, tensordict: TensorDict) -> TensorDict:
...         tensordict = TensorDict({},
...             batch_size=self.batch_size,
...             device=self.device,
...         )
...         tensordict = tensordict.update(self.state_spec.rand())
...         tensordict = tensordict.update(self.observation_spec.rand())
...         return tensordict
>>> # This environment is used as follows:
>>> import torch.nn as nn
>>> from torchrl.modules import MLP, WorldModelWrapper
>>> world_model = WorldModelWrapper(
...     TensorDictModule(
...         MLP(out_features=4, activation_class=nn.ReLU, activate_last_layer=True, depth=0),
...         in_keys=["hidden_observation", "action"],
...         out_keys=["hidden_observation"],
...     ),
...     TensorDictModule(
...         nn.Linear(4, 1),
...         in_keys=["hidden_observation"],
...         out_keys=["reward"],
...     ),
... )
>>> env = MyMBEnv(world_model)
>>> tensordict = env.rollout(max_steps=10)
>>> print(tensordict)
TensorDict(
    fields={
        action: Tensor(torch.Size([10, 1]), dtype=torch.float32),
        done: Tensor(torch.Size([10, 1]), dtype=torch.bool),
        hidden_observation: Tensor(torch.Size([10, 4]), dtype=torch.float32),
        next: LazyStackedTensorDict(
            fields={
                hidden_observation: Tensor(torch.Size([10, 4]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=cpu,
            is_shared=False),
        reward: Tensor(torch.Size([10, 1]), dtype=torch.float32)},
    batch_size=torch.Size([10]),
    device=cpu,
    is_shared=False)
- Attributes:
observation_spec (CompositeSpec): sampling spec of the observations;
action_spec (TensorSpec): sampling spec of the actions;
reward_spec (TensorSpec): sampling spec of the rewards;
input_spec (CompositeSpec): sampling spec of the inputs;
batch_size (torch.Size): batch size the environment will use. If unset, the environment accepts tensordicts of all batch sizes.
device (torch.device): device where the environment input and output are expected to live
- Parameters:
world_model (nn.Module) – model that generates world states and their corresponding rewards;
params (List[torch.Tensor], optional) – list of parameters of the world model;
buffers (List[torch.Tensor], optional) – list of buffers of the world model;
device (torch.device, optional) – device where the environment input and output are expected to live
dtype (torch.dtype, optional) – dtype of the environment input and output
batch_size (torch.Size, optional) – number of environments contained in the instance
run_type_check (bool, optional) – whether to run type checks on each step of the environment
- Methods:
- torchrl.envs.step(TensorDict -> TensorDict)
Step in the environment
- torchrl.envs.reset(TensorDict, optional -> TensorDict)
Reset the environment
- torchrl.envs.set_seed(int -> int)
Sets the seed of the environment
- torchrl.envs.rand_step(TensorDict, optional -> TensorDict)
Random step given the action spec
- torchrl.envs.rollout(Callable, ... -> TensorDict)
Executes a rollout in the environment with the given policy (random steps if no policy is provided)
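The reset/step/rollout contract described by the methods above can be sketched in plain Python without TorchRL. This is a minimal illustration, not the TorchRL implementation: `ToyEnv` and `rollout` below are hypothetical stand-ins showing how a rollout falls back to random actions when no policy is provided.

```python
import random


class ToyEnv:
    """Hypothetical stand-in mimicking the reset/step/set_seed/rand-action contract."""

    def __init__(self, seed=0):
        self.rng = random.Random(seed)
        self.state = 0.0

    def set_seed(self, seed):
        # Re-seed the environment's RNG and return the seed, as in set_seed(int -> int).
        self.rng = random.Random(seed)
        return seed

    def reset(self):
        self.state = 0.0
        return {"observation": self.state, "done": False}

    def rand_action(self):
        # Draw an action uniformly, standing in for sampling from the action spec.
        return self.rng.uniform(-1.0, 1.0)

    def step(self, action):
        self.state += action
        return {"observation": self.state, "reward": -abs(self.state), "done": False}


def rollout(env, max_steps, policy=None):
    """Execute a rollout; take random actions if no policy is provided."""
    td = env.reset()
    trajectory = []
    for _ in range(max_steps):
        action = policy(td) if policy is not None else env.rand_action()
        td = env.step(action)
        trajectory.append({"action": action, **td})
        if td["done"]:
            break
    return trajectory


traj = rollout(ToyEnv(seed=42), max_steps=10)
print(len(traj))  # 10 steps, since "done" never triggers here
```

Passing a callable as `policy` (e.g. `lambda td: -td["observation"]`) replaces the random action draw, mirroring how `rollout` uses a policy when one is given.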