快捷方式

tensorclass

@tensorclass 装饰器帮助您构建自定义类,这些类继承自 TensorDict 的行为,同时能够将可能的条目限制为预定义的集合或为您的类实现自定义方法。

TensorDict 类似,@tensorclass 支持嵌套、索引、重塑、项目赋值。它还支持张量操作,如 clonesqueezetorch.catsplit 等等。@tensorclass 允许非张量条目,但是所有张量操作都严格限制为张量属性。

需要为非张量数据实现自定义方法。重要的是要注意,@tensorclass 不强制执行严格的类型匹配

>>> from __future__ import annotations
>>> from tensordict.prototype import tensorclass
>>> import torch
>>> from torch import nn
>>> from typing import Optional
>>>
>>> @tensorclass
... class MyData:
...     floatdata: torch.Tensor
...     intdata: torch.Tensor
...     non_tensordata: str
...     nested: Optional[MyData] = None
...
...     def check_nested(self):
...         assert self.nested is not None
>>>
>>> data = MyData(
...   floatdata=torch.randn(3, 4, 5),
...   intdata=torch.randint(10, (3, 4, 1)),
...   non_tensordata="test",
...   batch_size=[3, 4]
... )
>>> print("data:", data)
data: MyData(
  floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
  intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
  non_tensordata='test',
  nested=None,
  batch_size=torch.Size([3, 4]),
  device=None,
  is_shared=False)
>>> data.nested = MyData(
...     floatdata = torch.randn(3, 4, 5),
...     intdata=torch.randint(10, (3, 4, 1)),
...     non_tensordata="nested_test",
...     batch_size=[3, 4]
... )
>>> print("nested:", data)
nested: MyData(
  floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
  intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
  non_tensordata='test',
  nested=MyData(
      floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
      intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
      non_tensordata='nested_test',
      nested=None,
      batch_size=torch.Size([3, 4]),
      device=None,
      is_shared=False),
  batch_size=torch.Size([3, 4]),
  device=None,
  is_shared=False)

TensorDict 的情况一样,从 v0.4 开始,如果省略批次大小,则将其视为空。

如果提供了非空的批次大小,@tensorclass 支持索引。在内部,张量对象会被索引,但是非张量数据保持不变

>>> print("indexed:", data[:2])
indexed: MyData(
   floatdata=Tensor(shape=torch.Size([2, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
   intdata=Tensor(shape=torch.Size([2, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
   non_tensordata='test',
   nested=MyData(
      floatdata=Tensor(shape=torch.Size([2, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
      intdata=Tensor(shape=torch.Size([2, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
      non_tensordata='nested_test',
      nested=None,
      batch_size=torch.Size([2, 4]),
      device=None,
      is_shared=False),
   batch_size=torch.Size([2, 4]),
   device=None,
   is_shared=False)

@tensorclass 还支持设置和重置属性,即使对于嵌套对象也是如此。

>>> data.non_tensordata = "test_changed"
>>> print("data.non_tensordata: ", repr(data.non_tensordata))
data.non_tensordata: 'test_changed'

>>> data.floatdata = torch.ones(3, 4, 5)
>>> print("data.floatdata:", data.floatdata)
data.floatdata: tensor([[[1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.]],

      [[1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.]],

      [[1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.]]])

>>> # Changing nested tensor data
>>> data.nested.non_tensordata = "nested_test_changed"
>>> print("data.nested.non_tensordata:", repr(data.nested.non_tensordata))
data.nested.non_tensordata: 'nested_test_changed'

@tensorclass 支持对其内容的形状和设备进行多项 torch 操作,例如 stackcatreshapeto(device)。要获取支持操作的完整列表,请查看 tensordict 文档。

这是一个例子

>>> data2 = data.clone()
>>> cat_tc = torch.cat([data, data2], 0)
>>> print("Concatenated data:", catted_tc)
Concatenated data: MyData(
   floatdata=Tensor(shape=torch.Size([6, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
   intdata=Tensor(shape=torch.Size([6, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
   non_tensordata='test_changed',
   nested=MyData(
       floatdata=Tensor(shape=torch.Size([6, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
       intdata=Tensor(shape=torch.Size([6, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
       non_tensordata='nested_test_changed',
       nested=None,
       batch_size=torch.Size([6, 4]),
       device=None,
       is_shared=False),
   batch_size=torch.Size([6, 4]),
   device=None,
   is_shared=False)

序列化

可以使用 memmap 方法实现保存 tensorclass 实例。保存策略如下:张量数据将使用内存映射张量保存,可以使用 json 格式序列化的非张量数据将以这种方式保存。其他数据类型将使用 save() 保存,这依赖于 pickle

可以通过 load_memmap() 反序列化 tensorclass。创建的实例将具有与保存的实例相同的类型,前提是 tensorclass 在工作环境中可用

>>> data.memmap("path/to/saved/directory")
>>> data_loaded = TensorDict.load_memmap("path/to/saved/directory")
>>> assert isinstance(data_loaded, type(data))

边缘情况

@tensorclass 支持相等和不等运算符,即使对于嵌套对象也是如此。请注意,非张量/元数据未经过验证。这将返回一个张量类对象,其中张量属性具有布尔值,非张量属性具有 None 值

这是一个例子

>>> print(data == data2)
MyData(
   floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.bool, is_shared=False),
   intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
   non_tensordata=None,
   nested=MyData(
       floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.bool, is_shared=False),
       intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
       non_tensordata=None,
       nested=None,
       batch_size=torch.Size([3, 4]),
       device=None,
       is_shared=False),
   batch_size=torch.Size([3, 4]),
   device=None,
   is_shared=False)

@tensorclass 支持设置项目。但是,在设置项目时,会执行非张量/元数据的身份检查,而不是相等性检查,以避免性能问题。用户需要确保项目的非张量数据与对象匹配,以避免差异。

这是一个例子

当使用不同的 non_tensor 数据设置项目时,将抛出 UserWarning

>>> data2.non_tensordata = "test_new"
>>> data[0] = data2[0]
UserWarning: Meta data at 'non_tensordata' may or may not be equal, this may result in undefined behaviours

即使 @tensorclass 支持 torch 函数,如 cat()stack(),非张量/元数据也未经过验证。torch 操作在张量数据上执行,并且在返回输出时,会考虑第一个张量类对象的非张量/元数据。用户需要确保所有张量类对象的列表都具有相同的非张量数据,以避免差异

这是一个例子

>>> data2.non_tensordata = "test_new"
>>> stack_tc = torch.cat([data, data2], dim=0)
>>> print(stack_tc)
MyData(
    floatdata=Tensor(shape=torch.Size([2, 3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
    intdata=Tensor(shape=torch.Size([2, 3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
    non_tensordata='test',
    nested=MyData(
        floatdata=Tensor(shape=torch.Size([2, 3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        intdata=Tensor(shape=torch.Size([2, 3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        non_tensordata='nested_test',
        nested=None,
        batch_size=torch.Size([2, 3, 4]),
        device=None,
        is_shared=False),
    batch_size=torch.Size([2, 3, 4]),
    device=None,
    is_shared=False)

@tensorclass 也支持预分配,您可以初始化对象,使其属性为 None,然后在稍后设置它们。请注意,在初始化时,内部会将 None 属性保存为非张量/元数据,而在重置时,将根据属性值的类型将其保存为张量数据或非张量/元数据

这是一个例子

>>> @tensorclass
... class MyClass:
...   X: Any
...   y: Any

>>> data = MyClass(X=None, y=None, batch_size = [3,4])
>>> data.X = torch.ones(3, 4, 5)
>>> data.y = "testing"
>>> print(data)
MyClass(
   X=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
   y='testing',
   batch_size=torch.Size([3, 4]),
   device=None,
   is_shared=False)

tensorclass([cls, autocast, frozen])

用于创建 tensorclass 类的装饰器。

NonTensorData(data[, _metadata, ...])

NonTensorStack(*args, **kwargs)

LazyStackedTensorDict 的一个薄包装器,使非张量数据的堆叠易于识别。

自动类型转换

警告

自动类型转换是一项实验性功能,未来可能会发生变化。与 python<=3.9 的兼容性有限。

@tensorclass 部分支持自动类型转换作为一项实验性功能。诸如 __setattr__updateupdate_from_dict 等方法将尝试将类型注释的条目转换为所需的 TensorDict/tensorclass 实例(除非在下面详述的情况下)。例如,以下代码会将 td 字典转换为 TensorDict,并将 tc 条目转换为 MyClass 实例

>>> @tensorclass
... class MyClass:
...     tensor: torch.Tensor
...     td: TensorDict
...     tc: MyClass
...
>>> obj = MyClass(
...     tensor=torch.randn(()),
...     td={"a": torch.randn(())},
...     tc={"tensor": torch.randn(()), "td": None, "tc": None})
>>> assert isinstance(obj.tensor, torch.Tensor)
>>> assert isinstance(obj.tc, TensorDict)
>>> assert isinstance(obj.td, MyClass)

注意

包含 typing.Optionaltyping.Union 的类型注释项目将与自动类型转换不兼容,但 tensorclass 中的其他项目将兼容

>>> @tensorclass
... class MyClass:
...     tensor: torch.Tensor
...     tc_autocast: MyClass = None
...     tc_not_autocast: Optional[MyClass] = None
>>> obj = MyClass(
...     tensor=torch.randn(()),
...     tc_autocast={"tensor": torch.randn(())},
...     tc_not_autocast={"tensor": torch.randn(())},
... )
>>> assert isinstance(obj.tc_autocast, MyClass)
>>> # because the type is Optional or Union, auto-casting is disabled for
>>> # that variable.
>>> assert not isinstance(obj.tc_not_autocast, MyClass)

如果类中至少有一个项目使用 type0 | type1 语义进行注释,则整个类的自动类型转换功能将被禁用。因为 tensorclass 支持非张量叶节点,所以在这些情况下设置字典将导致将其设置为普通字典,而不是张量集合子类(TensorDicttensorclass

>>> @tensorclass
... class MyClass:
...     tensor: torch.Tensor
...     td: TensorDict
...     tc: MyClass | None
...
>>> obj = MyClass(
...     tensor=torch.randn(()),
...     td={"a": torch.randn(())},
...     tc={"tensor": torch.randn(()), "td": None, "tc": None})
>>> assert isinstance(obj.tensor, torch.Tensor)
>>> # tc and td have not been cast
>>> assert isinstance(obj.tc, dict)
>>> assert isinstance(obj.td, dict)

注意

自动类型转换不适用于叶节点(张量)。原因是此功能与包含 type0 | type1 类型提示语义的类型注释不兼容,而这种语义很普遍。如果类型注释仅略有不同,则允许自动类型转换将导致非常相似的代码具有截然不同的行为。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得解答

查看资源