快捷方式

ValueOperator

class torchrl.modules.tensordict_module.ValueOperator(*args, **kwargs)[source]

强化学习中值函数的通用类。

ValueOperator 类带有 in_keys 和 out_keys 参数的默认值(分别为 [“observation”] 和 [“state_value”] 或 [“state_action_value”],具体取决于“action”键是否为 in_keys 列表的一部分)。

参数:
  • module (nn.Module) – 用于将输入映射到输出参数空间的 torch.nn.Module

  • in_keys (str 的可迭代对象, 可选) – 要从输入 tensordict 读取并传递到模块的键。如果它包含多个元素,则值将按照 in_keys 可迭代对象给出的顺序传递。默认为 ["observation"]

  • out_keys (str 的可迭代对象) – 要写入输入 tensordict 的键。out_keys 的长度必须与嵌入模块返回的张量数量匹配。使用“_”作为键可以避免将张量写入输出。默认为 ["state_value"]["state_action_value"],如果 "action"in_keys 的一部分。

示例

>>> import torch
>>> from tensordict import TensorDict
>>> from torch import nn
>>> from torchrl.data import UnboundedContinuousTensorSpec
>>> from torchrl.modules import ValueOperator
>>> td = TensorDict({"observation": torch.randn(3, 4), "action": torch.randn(3, 2)}, [3,])
>>> class CustomModule(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.linear = torch.nn.Linear(6, 1)
...     def forward(self, obs, action):
...         return self.linear(torch.cat([obs, action], -1))
>>> module = CustomModule()
>>> td_module = ValueOperator(
...    in_keys=["observation", "action"], module=module
... )
>>> td = td_module(td)
>>> print(td)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        state_action_value: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源