MultiAgentNetBase¶
- class torchrl.modules.MultiAgentNetBase(*, n_agents: int, centralized: Optional[bool] = None, share_params: Optional[bool] = None, agent_dim: Optional[int] = None, vmap_randomness: str = 'different', use_td_params: bool = True, **kwargs)[源代码]¶
多智能体网络的基础类。
注意
要使用 torch.nn.init 模块初始化 MARL 模块参数,请参考
get_stateful_net()
和from_stateful_net()
方法。- forward(*inputs: Tuple[Tensor]) Tensor [源代码]¶
定义每次调用时执行的计算。
应由所有子类重写。
注意
尽管前向传递的配方需要在该函数内定义,但应该在此之后调用
Module
实例,而不是调用此函数,因为前者负责运行注册的钩子,而后者会静默地忽略它们。
- from_stateful_net(stateful_net: Module)[源代码]¶
使用网络的有状态版本填充参数。
有关如何收集网络的有状态版本的详细信息,请参阅
get_stateful_net()
。- 参数:
stateful_net (nn.Module) – 应该从中收集参数的有状态网络。
- get_stateful_net(copy: bool = True)[源代码]¶
返回网络的有状态版本。
这可以用于初始化参数。
此类网络通常无法开箱即用,并且需要调用 vmap 才能执行。
- 参数:
copy (bool, optional) – 如果为
True
,则会创建网络的深拷贝。默认为True
。
如果参数是就地修改的(推荐),则无需将参数复制回 MARL 模块。有关如何使用已异地重新初始化的参数重新填充 MARL 模型的详细信息,请参阅
from_stateful_net()
。示例
>>> from torchrl.modules import MultiAgentMLP >>> import torch >>> n_agents = 6 >>> n_agent_inputs=3 >>> n_agent_outputs=2 >>> batch = 64 >>> obs = torch.zeros(batch, n_agents, n_agent_inputs) >>> mlp = MultiAgentMLP( ... n_agent_inputs=n_agent_inputs, ... n_agent_outputs=n_agent_outputs, ... n_agents=n_agents, ... centralized=False, ... share_params=False, ... depth=2, ... ) >>> snet = mlp.get_stateful_net() >>> def init(module): ... if hasattr(module, "weight"): ... torch.nn.init.kaiming_normal_(module.weight) >>> snet.apply(init) >>> # If the module has been updated out-of-place (not the case here) we can reset the params >>> mlp.from_stateful_net(snet)