快捷方式

MultiAgentNetBase

class torchrl.modules.MultiAgentNetBase(*, n_agents: int, centralized: bool | None = None, share_params: bool | None = None, agent_dim: int | None = None, vmap_randomness: str = 'different', use_td_params: bool = True, **kwargs)[source]

多智能体网络的基类。

注意

要使用 torch.nn.init 模块初始化 MARL 模块参数,请参考 get_stateful_net()from_stateful_net() 方法。

forward(*inputs: Tuple[Tensor]) Tensor[source]

定义每次调用时执行的计算。

所有子类都应覆盖此方法。

注意

虽然前向传播的逻辑需要在此函数中定义,但之后应该调用 Module 实例而不是直接调用此函数,因为前者会处理注册的钩子,而后者会默默忽略它们。

from_stateful_net(stateful_net: Module)[source]

根据网络的有状态版本填充参数。

有关如何获取网络的有状态版本的详细信息,请参阅 get_stateful_net()

参数:

stateful_net (nn.Module) – 应从中获取参数的有状态网络。

get_stateful_net(copy: bool = True)[source]

返回网络的有状态版本。

这可用于初始化参数。

此类网络通常无法直接调用,需要通过 vmap 调用才能执行。

参数:

copy (bool, 可选) – 如果为 True,则对网络进行深拷贝。默认为 True

如果参数是原地修改的(推荐),则无需将参数复制回 MARL 模块。有关如何使用非原地(out-of-place)重新初始化的参数重新填充 MARL 模型,请参阅 from_stateful_net() 的详细信息。

示例

>>> from torchrl.modules import MultiAgentMLP
>>> import torch
>>> n_agents = 6
>>> n_agent_inputs=3
>>> n_agent_outputs=2
>>> batch = 64
>>> obs = torch.zeros(batch, n_agents, n_agent_inputs)
>>> mlp = MultiAgentMLP(
...     n_agent_inputs=n_agent_inputs,
...     n_agent_outputs=n_agent_outputs,
...     n_agents=n_agents,
...     centralized=False,
...     share_params=False,
...     depth=2,
... )
>>> snet = mlp.get_stateful_net()
>>> def init(module):
...     if hasattr(module, "weight"):
...         torch.nn.init.kaiming_normal_(module.weight)
>>> snet.apply(init)
>>> # If the module has been updated out-of-place (not the case here) we can reset the params
>>> mlp.from_stateful_net(snet)
reset_parameters()[source]

重置模型的参数。

文档

获取 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源