tensordict.nn.make_functional¶
- tensordict.nn.make_functional(module: nn.Module, funs_to_decorate: Iterable[str] | None = None, keep_params: bool = False, return_params: bool = True) TensorDict ¶
将 nn.Module 就地转换为函数式模块,并返回其参数。
- 参数:
module (torch.nn.Module) – 要转换为函数式模块的模块。
funs_to_decorate (字符串的可迭代对象, 可选) – 每个字符串必须对应于属于模块的函数。对于嵌套模块,
torch.nn.Module.forward()
方法将被装饰。默认为"forward"
。keep_params (bool, 可选) – 如果为
True
,则模块将保留其参数。默认为False
。return_params (bool, 可选) – 如果为
True
,则参数将被收集到嵌套的 tensordict 中并返回。如果为False
,则模块将被转换为函数式模块,但仍保持有状态。