
MultiAgentConvNet

class torchrl.modules.MultiAgentConvNet(n_agents: int, centralized: Optional[bool] = None, share_params: Optional[bool] = None, *, in_features: Optional[int] = None, device: Optional[Union[torch.device, str, int]] = None, num_cells: Optional[Sequence[int]] = None, kernel_sizes: Union[Sequence[Union[int, Sequence[int]]], int] = 5, strides: Union[Sequence, int] = 2, paddings: Union[Sequence, int] = 0, activation_class: Type[torch.nn.Module] = <class 'torch.nn.modules.activation.ELU'>, use_td_params: bool = True, **kwargs)[source]

Multi-agent CNN.

In MARL settings, agents may or may not share the same policy for their actions: we say that the parameters can be shared or not. Similarly, a network may take the entire observation space (across agents) or operate on a per-agent basis to compute its output; we refer to these as "centralized" and "non-centralized", respectively.

It expects inputs of shape (*B, n_agents, channels, x, y).
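For example, a decentralized, parameter-shared instance maps one observation per agent to one feature vector per agent. The following is a minimal sketch (the batch and observation sizes are arbitrary; the expected output assumes the default three-layer, 32-channel, 5x5/stride-2 architecture shown in the Examples below, and the lazy first layer infers the channel count on the first forward pass):

>>> import torch
>>> from torchrl.modules import MultiAgentConvNet
>>> net = MultiAgentConvNet(n_agents=4, centralized=False, share_params=True)
>>> obs = torch.randn(8, 2, 4, 3, 32, 32)  # (*B, n_agents, channels, x, y) with B = (8, 2)
>>> out = net(obs)
>>> print(out.shape)  # (*B, n_agents, F), where F is the flattened conv output size
torch.Size([8, 2, 4, 32])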

Note

To initialize MARL module parameters with the torch.nn.init module, please refer to the get_stateful_net() and from_stateful_net() methods.
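For instance, weights can be re-initialized through the stateful view of the network. This is a minimal sketch: since torch.nn.init operates in place, the final from_stateful_net() call is only strictly required when parameters are replaced out-of-place, and in_features is passed so that no lazy (uninitialized) layers remain:

>>> import torch
>>> from torchrl.modules import MultiAgentConvNet
>>> net = MultiAgentConvNet(n_agents=4, centralized=False, share_params=True, in_features=3)
>>> stateful = net.get_stateful_net()  # stateful view of the network's parameters
>>> for p in stateful.parameters():
...     if p.dim() >= 2:
...         torch.nn.init.orthogonal_(p)  # conv weights
...     else:
...         torch.nn.init.zeros_(p)  # biases
>>> net.from_stateful_net(stateful)  # write the values back into the MARL module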

Parameters:
  • n_agents (int) – Number of agents.

  • centralized (bool) – If True, each agent will use the inputs of all agents to compute its output, resulting in an input of shape (*B, n_agents * channels, x, y). Otherwise, each agent will only use its own data as input.

  • share_params (bool) – If True, the same ConvNet will be used to make the forward pass for all agents (homogeneous policies). Otherwise, each agent will use a different ConvNet to process its input (heterogeneous policies).

Keyword Arguments:
  • in_features (int, optional) – The input feature dimension. If left to None, a lazy module is used.

  • device (str or torch.device, optional) – Device to create the module on.

  • num_cells (int or Sequence[int], optional) – Number of cells of every layer in between the input and output. If an integer is provided, every layer will have the same number of cells. If an iterable is provided, each layer's number of output channels will match the content of num_cells.

  • kernel_sizes (int, Sequence[Union[int, Sequence[int]]]) – Kernel size(s) of the convolutional network. Defaults to 5.

  • strides (int or Sequence[int]) – Stride(s) of the convolutional network. If iterable, the length must match the depth, defined by num_cells or the depth argument. Defaults to 2.

  • paddings (int or Sequence[int]) – Padding(s) of the convolutional network (listed in the signature above). If iterable, the length must match the depth, defined by num_cells or the depth argument. Defaults to 0.

  • activation_class (Type[nn.Module]) – Activation class to be used. Defaults to torch.nn.ELU.

  • use_td_params (bool, optional) – If True, the parameters can be found in self.params, which is a TensorDictParams object (which inherits both from TensorDict and nn.Module). If False, the parameters are contained in self._empty_net. All things considered, these two approaches should be roughly identical but are not interchangeable: for instance, a state_dict created with use_td_params=True cannot be used when use_td_params=False.

  • **kwargs – Can be passed to ConvNet to customize it (see the sketch after this list).
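As a minimal sketch tying these keyword arguments together (the layer sizes below are illustrative choices, not defaults; per the use_td_params entry above, the parameters should be exposed as net.params):

>>> import torch
>>> from torchrl.modules import MultiAgentConvNet
>>> net = MultiAgentConvNet(
...     n_agents=4,
...     centralized=False,
...     share_params=True,
...     in_features=3,  # eager first layer instead of a lazy one
...     num_cells=[16, 32],  # two conv layers with 16 and 32 output channels
...     kernel_sizes=3,  # an int broadcasts to every layer
...     strides=1,
... )
>>> obs = torch.randn(8, 4, 3, 32, 32)
>>> print(net(obs).shape)  # 32 channels * 28 * 28 spatial positions = 25088
torch.Size([8, 4, 25088])
>>> type(net.params).__name__  # use_td_params=True (the default) exposes the parameters here
'TensorDictParams'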

Examples

>>> import torch
>>> from torchrl.modules import MultiAgentConvNet
>>> batch = (3,2)
>>> n_agents = 7
>>> channels, x, y = 4, 100, 100
>>> obs = torch.randn(*batch, n_agents, channels, x, y)
>>> # Let's consider a centralized network with shared parameters.
>>> cnn = MultiAgentConvNet(
...     n_agents,
...     centralized = True,
...     share_params = True
... )
>>> print(cnn)
MultiAgentConvNet(
  (agent_networks): ModuleList(
    (0): ConvNet(
      (0): LazyConv2d(0, 32, kernel_size=(5, 5), stride=(2, 2))
      (1): ELU(alpha=1.0)
      (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
      (3): ELU(alpha=1.0)
      (4): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
      (5): ELU(alpha=1.0)
      (6): SquashDims()
    )
  )
)
>>> result = cnn(obs)
>>> # The final dimension of the output is determined by the layer arguments and the shape of the input 'obs'.
>>> print(result.shape)
torch.Size([3, 2, 7, 2592])
>>> # Since both observations and parameters are shared, we expect all agents to have identical outputs (e.g., for a value function).
>>> print(all(result[0,0,0] == result[0,0,1]))
True
>>> # Alternatively, a local network with parameter sharing (e.g., a decentralized weight-sharing policy).
>>> cnn = MultiAgentConvNet(
...     n_agents,
...     centralized = False,
...     share_params = True
... )
>>> result = cnn(obs)
>>> print(cnn)
MultiAgentConvNet(
  (agent_networks): ModuleList(
    (0): ConvNet(
      (0): Conv2d(4, 32, kernel_size=(5, 5), stride=(2, 2))
      (1): ELU(alpha=1.0)
      (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
      (3): ELU(alpha=1.0)
      (4): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
      (5): ELU(alpha=1.0)
      (6): SquashDims()
    )
  )
)
>>> print(result.shape)
torch.Size([3, 2, 7, 2592])
>>> # Parameters are shared but not observations, hence each agent has a different output.
>>> print(all(result[0,0,0] == result[0,0,1]))
False
>>> # Or multiple local networks identical in structure but with differing weights.
>>> cnn = MultiAgentConvNet(
...     n_agents,
...     centralized = False,
...     share_params = False
... )
>>> result = cnn(obs)
>>> print(cnn)
MultiAgentConvNet(
  (agent_networks): ModuleList(
    (0-6): 7 x ConvNet(
      (0): Conv2d(4, 32, kernel_size=(5, 5), stride=(2, 2))
      (1): ELU(alpha=1.0)
      (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
      (3): ELU(alpha=1.0)
      (4): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
      (5): ELU(alpha=1.0)
      (6): SquashDims()
    )
  )
)
>>> print(result.shape)
torch.Size([3, 2, 7, 2592])
>>> print(all(result[0,0,0] == result[0,0,1]))
False
>>> # Or where inputs are shared but not parameters.
>>> cnn = MultiAgentConvNet(
...     n_agents,
...     centralized = True,
...     share_params = False
... )
>>> result = cnn(obs)
>>> print(cnn)
MultiAgentConvNet(
  (agent_networks): ModuleList(
    (0-6): 7 x ConvNet(
      (0): Conv2d(28, 32, kernel_size=(5, 5), stride=(2, 2))
      (1): ELU(alpha=1.0)
      (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
      (3): ELU(alpha=1.0)
      (4): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
      (5): ELU(alpha=1.0)
      (6): SquashDims()
    )
  )
)
>>> print(result.shape)
torch.Size([3, 2, 7, 2592])
>>> print(all(result[0,0,0] == result[0,0,1]))
False
