get_primers_from_module
- torchrl.modules.utils.get_primers_from_module(module)
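Get all tensordict primers from all submodules of a module.
This method can be used to retrieve the primers of modules that are nested inside a parent module.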
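The collection step can be pictured as a walk over every nested submodule that asks each one for its primer. The sketch below is only an illustration under stated assumptions, not TorchRL's implementation: it assumes that primer-producing submodules expose a `make_tensordict_primer()` method (as TorchRL's recurrent modules do), and it uses a hypothetical `ToyRecurrent` module and a hypothetical `collect_primers` helper that merge plain tensors into a dict instead of building a real `TensorDictPrimer`.

```python
import torch
from torch import nn


class ToyRecurrent(nn.Module):
    """Hypothetical stand-in for a recurrent module such as GRUModule.

    Assumption: primer-producing modules expose ``make_tensordict_primer``.
    Here it returns a plain dict instead of a real TensorDictPrimer.
    """

    def __init__(self, key: str, num_layers: int, hidden_size: int):
        super().__init__()
        self.key = key
        self.num_layers = num_layers
        self.hidden_size = hidden_size

    def make_tensordict_primer(self):
        # Zero-initialised recurrent state of shape (num_layers, hidden_size).
        return {self.key: torch.zeros(self.num_layers, self.hidden_size)}


def collect_primers(module: nn.Module) -> dict:
    """Hypothetical helper: merge the primers exposed by all submodules."""
    primers = {}
    # nn.Module.modules() yields the module itself and every nested submodule.
    for submodule in module.modules():
        make_primer = getattr(submodule, "make_tensordict_primer", None)
        if callable(make_primer):
            primers.update(make_primer())
    return primers


model = nn.Sequential(
    ToyRecurrent("recurrent_state", num_layers=1, hidden_size=10),
    nn.Linear(10, 10),
    # A primer nested one level deeper is still found by the traversal.
    nn.Sequential(ToyRecurrent("other_state", num_layers=2, hidden_size=4)),
)
primers = collect_primers(model)
print({key: tuple(value.shape) for key, value in primers.items()})
```

Walking `modules()` rather than only the direct children is what lets a primer buried inside a nested container (here the inner `nn.Sequential`) be picked up alongside the top-level one.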
- Parameters:
**module** (torch.nn.Module) – the parent module.
- Returns:
A TensorDictPrimer transform.
- Return type:
TensorDictPrimer
Examples
>>> from torchrl.modules.utils import get_primers_from_module
>>> from torchrl.modules import GRUModule, MLP
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> # Define a GRU module
>>> gru_module = GRUModule(
...     input_size=10,
...     hidden_size=10,
...     num_layers=1,
...     in_keys=["input", "recurrent_state", "is_init"],
...     out_keys=["features", ("next", "recurrent_state")],
... )
>>> # Define a head module
>>> head = TensorDictModule(
...     MLP(
...         in_features=10,
...         out_features=10,
...         num_cells=[],
...     ),
...     in_keys=["features"],
...     out_keys=["output"],
... )
>>> # Create a sequential model
>>> model = TensorDictSequential(gru_module, head)
>>> # Retrieve primers from the model
>>> primers = get_primers_from_module(model)
>>> print(primers)
TensorDictPrimer(primers=CompositeSpec(
    recurrent_state: UnboundedContinuousTensorSpec(
        shape=torch.Size([1, 10]), space=None, device=cpu, dtype=torch.float32, domain=continuous), device=None, shape=torch.Size([])), default_value={'recurrent_state': 0.0}, random=None)