快捷方式

get_primers_from_module

class torchrl.modules.utils.get_primers_from_module(module)[源代码]

从模块的所有子模块中获取所有 tensordict 引导程序。

此方法可用于从包含在父模块中的模块中检索引导程序。

参数:

**module** (torch.nn.Module) – 父模块。

返回:

一个 TensorDictPrimer 变换。

返回类型:

TensorDictPrimer

示例

>>> from torchrl.modules.utils import get_primers_from_module
>>> from torchrl.modules import GRUModule, MLP
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> # Define a GRU module
>>> gru_module = GRUModule(
...     input_size=10,
...     hidden_size=10,
...     num_layers=1,
...     in_keys=["input", "recurrent_state", "is_init"],
...     out_keys=["features", ("next", "recurrent_state")],
... )
>>> # Define a head module
>>> head = TensorDictModule(
...     MLP(
...         in_features=10,
...         out_features=10,
...         num_cells=[],
...     ),
...     in_keys=["features"],
...     out_keys=["output"],
... )
>>> # Create a sequential model
>>> model = TensorDictSequential(gru_module, head)
>>> # Retrieve primers from the model
>>> primers = get_primers_from_module(model)
>>> print(primers)
TensorDictPrimer(primers=CompositeSpec(
recurrent_state: UnboundedContinuousTensorSpec(

shape=torch.Size([1, 10]), space=None, device=cpu, dtype=torch.float32, domain=continuous), device=None, shape=torch.Size([])), default_value={‘recurrent_state’: 0.0}, random=None)

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源