快捷方式

tensorclass

@tensorclass 装饰器可帮助您构建自定义类,这些类继承了 TensorDict 的行为,同时能够将可能的条目限制为预定义的集合或为其类实现自定义方法。

TensorDict 类似,@tensorclass 支持嵌套、索引、重塑、项目赋值。它还支持 clonesqueezetorch.catsplit 等许多张量操作。@tensorclass 允许非张量条目,但所有张量操作都严格限制于张量属性。

需要为非张量数据实现自定义方法。重要的是要注意 @tensorclass 不强制执行严格的类型匹配。

>>> from __future__ import annotations
>>> from tensordict.prototype import tensorclass
>>> import torch
>>> from torch import nn
>>> from typing import Optional
>>>
>>> @tensorclass
... class MyData:
...     floatdata: torch.Tensor
...     intdata: torch.Tensor
...     non_tensordata: str
...     nested: Optional[MyData] = None
...
...     def check_nested(self):
...         assert self.nested is not None
>>>
>>> data = MyData(
...   floatdata=torch.randn(3, 4, 5),
...   intdata=torch.randint(10, (3, 4, 1)),
...   non_tensordata="test",
...   batch_size=[3, 4]
... )
>>> print("data:", data)
data: MyData(
  floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
  intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
  non_tensordata='test',
  nested=None,
  batch_size=torch.Size([3, 4]),
  device=None,
  is_shared=False)
>>> data.nested = MyData(
...     floatdata = torch.randn(3, 4, 5),
...     intdata=torch.randint(10, (3, 4, 1)),
...     non_tensordata="nested_test",
...     batch_size=[3, 4]
... )
>>> print("nested:", data)
nested: MyData(
  floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
  intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
  non_tensordata='test',
  nested=MyData(
      floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
      intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
      non_tensordata='nested_test',
      nested=None,
      batch_size=torch.Size([3, 4]),
      device=None,
      is_shared=False),
  batch_size=torch.Size([3, 4]),
  device=None,
  is_shared=False)

TensorDict 一样,从 v0.4 开始,如果省略批量大小,则将其视为空。

如果提供了非空批量大小,@tensorclass 支持索引。在内部,张量对象会被索引,但非张量数据保持不变。

>>> print("indexed:", data[:2])
indexed: MyData(
   floatdata=Tensor(shape=torch.Size([2, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
   intdata=Tensor(shape=torch.Size([2, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
   non_tensordata='test',
   nested=MyData(
      floatdata=Tensor(shape=torch.Size([2, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
      intdata=Tensor(shape=torch.Size([2, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
      non_tensordata='nested_test',
      nested=None,
      batch_size=torch.Size([2, 4]),
      device=None,
      is_shared=False),
   batch_size=torch.Size([2, 4]),
   device=None,
   is_shared=False)

@tensorclass 还支持设置和重置属性,甚至对于嵌套对象也是如此。

>>> data.non_tensordata = "test_changed"
>>> print("data.non_tensordata: ", repr(data.non_tensordata))
data.non_tensordata: 'test_changed'

>>> data.floatdata = torch.ones(3, 4, 5)
>>> print("data.floatdata:", data.floatdata)
data.floatdata: tensor([[[1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.]],

      [[1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.]],

      [[1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.]]])

>>> # Changing nested tensor data
>>> data.nested.non_tensordata = "nested_test_changed"
>>> print("data.nested.non_tensordata:", repr(data.nested.non_tensordata))
data.nested.non_tensordata: 'nested_test_changed'

@tensorclass 支持对其内容的形状和设备进行多种 torch 操作,例如 stackcatreshapeto(device)。要获取支持的操作的完整列表,请查看 tensordict 文档。

这是一个例子

>>> data2 = data.clone()
>>> cat_tc = torch.cat([data, data2], 0)
>>> print("Concatenated data:", catted_tc)
Concatenated data: MyData(
   floatdata=Tensor(shape=torch.Size([6, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
   intdata=Tensor(shape=torch.Size([6, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
   non_tensordata='test_changed',
   nested=MyData(
       floatdata=Tensor(shape=torch.Size([6, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
       intdata=Tensor(shape=torch.Size([6, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
       non_tensordata='nested_test_changed',
       nested=None,
       batch_size=torch.Size([6, 4]),
       device=None,
       is_shared=False),
   batch_size=torch.Size([6, 4]),
   device=None,
   is_shared=False)

序列化

保存 tensorclass 实例可以使用 memmap 方法实现。保存策略如下:张量数据将使用内存映射张量保存,可以使用 json 格式序列化的非张量数据将以这种方式保存。其他数据类型将使用 save() 保存,该方法依赖于 pickle

反序列化 tensorclass 可以通过 load_memmap() 完成。创建的实例将与保存的实例具有相同的类型,前提是 tensorclass 在工作环境中可用。

>>> data.memmap("path/to/saved/directory")
>>> data_loaded = TensorDict.load_memmap("path/to/saved/directory")
>>> assert isinstance(data_loaded, type(data))

边界情况

@tensorclass 支持相等和不等运算符,甚至对于嵌套对象也是如此。请注意,非张量/元数据未经验证。这将返回一个 tensor class 对象,其中张量属性具有布尔值,非张量属性具有 None。

这是一个例子

>>> print(data == data2)
MyData(
   floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.bool, is_shared=False),
   intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
   non_tensordata=None,
   nested=MyData(
       floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.bool, is_shared=False),
       intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
       non_tensordata=None,
       nested=None,
       batch_size=torch.Size([3, 4]),
       device=None,
       is_shared=False),
   batch_size=torch.Size([3, 4]),
   device=None,
   is_shared=False)

@tensorclass 支持设置项目。但是,在设置项目时,为了避免性能问题,对非张量/元数据进行身份检查而不是相等性检查。用户需要确保项目的非张量数据与对象匹配,以避免差异。

这是一个例子

在设置具有不同 non_tensor 数据的项目时,将引发 UserWarning

>>> data2.non_tensordata = "test_new"
>>> data[0] = data2[0]
UserWarning: Meta data at 'non_tensordata' may or may not be equal, this may result in undefined behaviours

尽管 @tensorclass 支持 cat()stack() 等 torch 函数,但非张量/元数据未经验证。torch 操作是在张量数据上执行的,并且在返回输出时,考虑第一个 tensor class 对象的非张量/元数据。用户需要确保所有 tensor class 对象列表具有相同的非张量数据,以避免差异。

这是一个例子

>>> data2.non_tensordata = "test_new"
>>> stack_tc = torch.cat([data, data2], dim=0)
>>> print(stack_tc)
MyData(
    floatdata=Tensor(shape=torch.Size([2, 3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
    intdata=Tensor(shape=torch.Size([2, 3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
    non_tensordata='test',
    nested=MyData(
        floatdata=Tensor(shape=torch.Size([2, 3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        intdata=Tensor(shape=torch.Size([2, 3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        non_tensordata='nested_test',
        nested=None,
        batch_size=torch.Size([2, 3, 4]),
        device=None,
        is_shared=False),
    batch_size=torch.Size([2, 3, 4]),
    device=None,
    is_shared=False)

@tensorclass 还支持预分配,您可以使用属性为 None 初始化对象,稍后设置它们。请注意,在初始化时,内部 None 属性将保存为非张量/元数据,而在重置时,根据属性值的类型,它将保存为张量数据或非张量/元数据。

这是一个例子

>>> @tensorclass
... class MyClass:
...   X: Any
...   y: Any

>>> data = MyClass(X=None, y=None, batch_size = [3,4])
>>> data.X = torch.ones(3, 4, 5)
>>> data.y = "testing"
>>> print(data)
MyClass(
   X=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
   y='testing',
   batch_size=torch.Size([3, 4]),
   device=None,
   is_shared=False)

tensorclass([cls, autocast, frozen, nocast, ...])

一个创建 tensorclass 类的装饰器。

TensorClass()

TensorClass 是 @tensorclass 装饰器的基于继承的版本。

NonTensorData(data[, _metadata, ...])

NonTensorStack(*args, **kwargs)

LazyStackedTensorDict 的一个轻量级包装器,用于使非张量数据上的堆叠易于识别。

from_dataclass(obj, *[, auto_batch_size, ...])

将 dataclass 实例或类型分别转换为 tensorclass 实例或类型。

自动类型转换

警告

自动类型转换是一个实验性功能,将来可能会有所更改。与 python<=3.9 的兼容性有限。

@tensorclass 部分支持自动类型转换,作为实验性功能。诸如 __setattr__updateupdate_from_dict 之类的方法将尝试将类型标注的条目转换为所需的 TensorDict / tensorclass 实例(除了下述情况外)。例如,以下代码会将 td 字典转换为 TensorDict,将 tc 条目转换为 MyClass 实例。

>>> @tensorclass
... class MyClass:
...     tensor: torch.Tensor
...     td: TensorDict
...     tc: MyClass
...
>>> obj = MyClass(
...     tensor=torch.randn(()),
...     td={"a": torch.randn(())},
...     tc={"tensor": torch.randn(()), "td": None, "tc": None})
>>> assert isinstance(obj.tensor, torch.Tensor)
>>> assert isinstance(obj.tc, TensorDict)
>>> assert isinstance(obj.td, MyClass)

注意

包含 typing.Optionaltyping.Union 的类型标注项目将不兼容自动类型转换,但 tensorclass 中的其他项目将兼容。

>>> @tensorclass
... class MyClass:
...     tensor: torch.Tensor
...     tc_autocast: MyClass = None
...     tc_not_autocast: Optional[MyClass] = None
>>> obj = MyClass(
...     tensor=torch.randn(()),
...     tc_autocast={"tensor": torch.randn(())},
...     tc_not_autocast={"tensor": torch.randn(())},
... )
>>> assert isinstance(obj.tc_autocast, MyClass)
>>> # because the type is Optional or Union, auto-casting is disabled for
>>> # that variable.
>>> assert not isinstance(obj.tc_not_autocast, MyClass)

如果类中至少有一个项目使用 type0 | type1 语义进行标注,则整个类的自动类型转换功能将被禁用。因为 tensorclass 支持非张量叶节点,在这些情况下设置字典会导致将其设置为普通字典,而不是张量集合子类(TensorDicttensorclass)。

>>> @tensorclass
... class MyClass:
...     tensor: torch.Tensor
...     td: TensorDict
...     tc: MyClass | None
...
>>> obj = MyClass(
...     tensor=torch.randn(()),
...     td={"a": torch.randn(())},
...     tc={"tensor": torch.randn(()), "td": None, "tc": None})
>>> assert isinstance(obj.tensor, torch.Tensor)
>>> # tc and td have not been cast
>>> assert isinstance(obj.tc, dict)
>>> assert isinstance(obj.td, dict)

注意

叶节点(张量)未启用自动类型转换。原因是此功能与包含 type0 | type1 类型提示语义的类型标注不兼容,而这种语义很普遍。如果类型标注仅略有不同,允许自动类型转换将导致非常相似的代码具有截然不同的行为。

文档

访问 PyTorch 的完整开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源