快捷方式

tensordict.nn.TensorDictParams

class tensordict.nn.TensorDictParams(parameters: TensorDictBase, *, no_convert=False, lock: bool = False)

保存一个包含参数的 TensorDictBase 实例。

此类将包含的参数公开给父 nn.Module,以便迭代模块的参数也会迭代 tensordict 的叶子节点。

索引与包装 tensordict 的索引方式完全一致。参数名称将在本模块内使用 flatten_keys("_")() 注册。因此,named_parameters() 的结果和 tensordict 的内容在键名称方面会有所不同。

在 tensordict 中设置张量的任何操作都将通过 torch.nn.Parameter 转换进行增强。

参数::

parameters (TensorDictBase) – 要表示为参数的 tensordict。除非 no_convert=True,否则值将转换为参数。

关键字参数::
  • no_convert (bool) – 如果为 True,则在构造期间和之后不会进行 nn.Parameter 转换(除非更改了 no_convert 属性)。如果 no_convertTrue 并且存在非参数,则它们将注册为缓冲区。默认为 False

  • lock (bool) – 如果为 True,则 TensorDictParams 托管的 tensordict 将被锁定。这对于避免意外修改非常有用,但也限制了可以在对象上执行的操作(并且在需要 unlock_() 时会对性能产生重大影响)。默认为 False

示例

>>> from torch import nn
>>> from tensordict import TensorDict
>>> module = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 4))
>>> params = TensorDict.from_module(module)
>>> params.lock_()
>>> p = TensorDictParams(params)
>>> print(p)
TensorDictParams(params=TensorDict(
    fields={
        0: TensorDict(
            fields={
                bias: Parameter(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
                weight: Parameter(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        1: TensorDict(
            fields={
                bias: Parameter(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
                weight: Parameter(shape=torch.Size([4, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False))
>>> class CustomModule(nn.Module):
...     def __init__(self, params):
...         super().__init__()
...         self.params = params
>>> m = CustomModule(p)
>>> # the wrapper supports assignment and values are turned in Parameter
>>> m.params['other'] = torch.randn(3)
>>> assert isinstance(m.params['other'], nn.Parameter)

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源