tensordict.nn.TensorDictParams¶
- class tensordict.nn.TensorDictParams(parameters: TensorDictBase, *, no_convert=False, lock: bool = False)¶
保存一个包含参数的 TensorDictBase 实例。
此类将包含的参数公开给父 nn.Module,以便迭代模块的参数也会迭代 tensordict 的叶子节点。
索引与包装 tensordict 的索引方式完全一致。参数名称将在本模块内使用
flatten_keys("_")()
注册。因此,named_parameters()
的结果和 tensordict 的内容在键名称方面会有所不同。在 tensordict 中设置张量的任何操作都将通过
torch.nn.Parameter
转换进行增强。- 参数::
parameters (TensorDictBase) – 要表示为参数的 tensordict。除非
no_convert=True
,否则值将转换为参数。- 关键字参数::
示例
>>> from torch import nn >>> from tensordict import TensorDict >>> module = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 4)) >>> params = TensorDict.from_module(module) >>> params.lock_() >>> p = TensorDictParams(params) >>> print(p) TensorDictParams(params=TensorDict( fields={ 0: TensorDict( fields={ bias: Parameter(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False), weight: Parameter(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False), 1: TensorDict( fields={ bias: Parameter(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False), weight: Parameter(shape=torch.Size([4, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)) >>> class CustomModule(nn.Module): ... def __init__(self, params): ... super().__init__() ... self.params = params >>> m = CustomModule(p) >>> # the wrapper supports assignment and values are turned in Parameter >>> m.params['other'] = torch.randn(3) >>> assert isinstance(m.params['other'], nn.Parameter)