快捷方式

Actor (执行器)

class torchrl.modules.tensordict_module.Actor(*args, **kwargs)[source]

RL 中确定性执行器的通用类。

Actor 类带有 out_keys ([\"action\"]) 的默认值,如果提供了 spec 但不是 Composite 对象,它将自动转换为 spec = Composite(action=spec)。

参数:
  • module (nn.Module) – 用于将输入映射到输出参数空间的 Module。

  • in_keys (str 的可迭代对象,可选) – 从输入 tensordict 读取并传递给模块的键。如果它包含多个元素,则值将按照 in_keys 可迭代对象给出的顺序传递。默认为 [\"observation\"]。

  • out_keys (str 的可迭代对象) – 要写入输入 tensordict 的键。out_keys 的长度必须与嵌入式模块返回的张量数量匹配。使用 \"_\" 作为键可以避免将张量写入输出。默认为 [\"action\"]。

关键字参数:
  • spec (TensorSpec,可选) – 仅关键字参数。输出张量的规范。如果模块输出多个输出张量,则 spec 表征第一个输出张量的空间。

  • safe (bool) – 仅关键字参数。如果为 True,则根据输入规范检查输出值。由于探索策略或数值下溢/溢出问题,可能会发生域外采样。如果此值超出范围,则使用 project() 方法将其投影回所需的空间。默认为 False。

示例

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import Unbounded
>>> from torchrl.modules import Actor
>>> torch.manual_seed(0)
>>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,])
>>> action_spec = Unbounded(4)
>>> module = torch.nn.Linear(4, 4)
>>> td_module = Actor(
...    module=module,
...    spec=action_spec,
...    )
>>> td_module(td)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> print(td.get("action"))
tensor([[-1.3635, -0.0340,  0.1476, -1.3911],
        [-0.1664,  0.5455,  0.2247, -0.4583],
        [-0.2916,  0.2160,  0.5337, -0.5193]], grad_fn=<AddmmBackward0>)

© 版权所有 2022, Meta。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得解答

查看资源