快捷方式

tensorclass

@tensorclass 装饰器可以帮助您构建自定义类,这些类继承自 TensorDict 的行为,同时能够将可能的条目限制为预定义的集合或为您的类实现自定义方法。

TensorDict 一样,@tensorclass 支持嵌套、索引、重塑和项目赋值。它还支持张量操作,如 clonesqueezetorch.catsplit 等。 @tensorclass 允许非张量条目,但是所有张量操作都严格限制在张量属性上。

需要为非张量数据实现自定义方法。需要注意的是,@tensorclass 并不强制执行严格的类型匹配。

>>> from __future__ import annotations
>>> from tensordict.prototype import tensorclass
>>> import torch
>>> from torch import nn
>>> from typing import Optional
>>>
>>> @tensorclass
... class MyData:
...     floatdata: torch.Tensor
...     intdata: torch.Tensor
...     non_tensordata: str
...     nested: Optional[MyData] = None
...
...     def check_nested(self):
...         assert self.nested is not None
>>>
>>> data = MyData(
...   floatdata=torch.randn(3, 4, 5),
...   intdata=torch.randint(10, (3, 4, 1)),
...   non_tensordata="test",
...   batch_size=[3, 4]
... )
>>> print("data:", data)
data: MyData(
  floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
  intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
  non_tensordata='test',
  nested=None,
  batch_size=torch.Size([3, 4]),
  device=None,
  is_shared=False)
>>> data.nested = MyData(
...     floatdata = torch.randn(3, 4, 5),
...     intdata=torch.randint(10, (3, 4, 1)),
...     non_tensordata="nested_test",
...     batch_size=[3, 4]
... )
>>> print("nested:", data)
nested: MyData(
  floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
  intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
  non_tensordata='test',
  nested=MyData(
      floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
      intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
      non_tensordata='nested_test',
      nested=None,
      batch_size=torch.Size([3, 4]),
      device=None,
      is_shared=False),
  batch_size=torch.Size([3, 4]),
  device=None,
  is_shared=False)

TensorDict 一样,从 v0.4 开始,如果省略批大小,则将其视为为空。

如果提供了非空批大小,@tensorclass 支持索引。在内部,张量对象会被索引,但是非张量数据保持不变。

>>> print("indexed:", data[:2])
indexed: MyData(
   floatdata=Tensor(shape=torch.Size([2, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
   intdata=Tensor(shape=torch.Size([2, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
   non_tensordata='test',
   nested=MyData(
      floatdata=Tensor(shape=torch.Size([2, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
      intdata=Tensor(shape=torch.Size([2, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
      non_tensordata='nested_test',
      nested=None,
      batch_size=torch.Size([2, 4]),
      device=None,
      is_shared=False),
   batch_size=torch.Size([2, 4]),
   device=None,
   is_shared=False)

@tensorclass 还支持设置和重置属性,即使对于嵌套对象也是如此。

>>> data.non_tensordata = "test_changed"
>>> print("data.non_tensordata: ", repr(data.non_tensordata))
data.non_tensordata: 'test_changed'

>>> data.floatdata = torch.ones(3, 4, 5)
>>> print("data.floatdata:", data.floatdata)
data.floatdata: tensor([[[1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.]],

      [[1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.]],

      [[1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.]]])

>>> # Changing nested tensor data
>>> data.nested.non_tensordata = "nested_test_changed"
>>> print("data.nested.non_tensordata:", repr(data.nested.non_tensordata))
data.nested.non_tensordata: 'nested_test_changed'

@tensorclass 支持对其内容的形状和设备进行多种 torch 操作,例如 stackcatreshapeto(device)。要获取支持的操作的完整列表,请查看 tensordict 文档。

以下是一个示例

>>> data2 = data.clone()
>>> cat_tc = torch.cat([data, data2], 0)
>>> print("Concatenated data:", catted_tc)
Concatenated data: MyData(
   floatdata=Tensor(shape=torch.Size([6, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
   intdata=Tensor(shape=torch.Size([6, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
   non_tensordata='test_changed',
   nested=MyData(
       floatdata=Tensor(shape=torch.Size([6, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
       intdata=Tensor(shape=torch.Size([6, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
       non_tensordata='nested_test_changed',
       nested=None,
       batch_size=torch.Size([6, 4]),
       device=None,
       is_shared=False),
   batch_size=torch.Size([6, 4]),
   device=None,
   is_shared=False)

序列化

可以使用 memmap 方法保存 tensorclass 实例。保存策略如下:张量数据将使用内存映射张量保存,可以使用 json 格式序列化的非张量数据将按此方式保存。其他数据类型将使用 save() 保存,该方法依赖于 pickle

可以通过 load_memmap() 反序列化 tensorclass。创建的实例将与保存的实例具有相同的类型,前提是在工作环境中提供了 tensorclass

>>> data.memmap("path/to/saved/directory")
>>> data_loaded = TensorDict.load_memmap("path/to/saved/directory")
>>> assert isinstance(data_loaded, type(data))

极端情况

@tensorclass 支持相等和不相等运算符,即使对于嵌套对象也是如此。请注意,不会验证非张量/元数据。这将返回一个 tensor 类对象,该对象对张量属性为布尔值,对非张量属性为 None。

以下是一个示例

>>> print(data == data2)
MyData(
   floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.bool, is_shared=False),
   intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
   non_tensordata=None,
   nested=MyData(
       floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.bool, is_shared=False),
       intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
       non_tensordata=None,
       nested=None,
       batch_size=torch.Size([3, 4]),
       device=None,
       is_shared=False),
   batch_size=torch.Size([3, 4]),
   device=None,
   is_shared=False)

@tensorclass 支持设置项目。但是,在设置项目时,会执行非张量/元数据的身份检查而不是相等性检查,以避免性能问题。用户需要确保项目的非张量数据与对象匹配,以避免出现差异。

以下是一个示例

在使用不同的 non_tensor 数据设置项目时,将抛出 UserWarning

>>> data2.non_tensordata = "test_new"
>>> data[0] = data2[0]
UserWarning: Meta data at 'non_tensordata' may or may not be equal, this may result in undefined behaviours

尽管 @tensorclass 支持 cat()stack() 等 torch 函数,但不会验证非张量/元数据。torch 操作是在张量数据上执行的,在返回输出时,会考虑第一个 tensor 类对象的非张量/元数据。用户需要确保所有 tensor 类对象的列表都具有相同的非张量数据,以避免出现差异。

以下是一个示例

>>> data2.non_tensordata = "test_new"
>>> stack_tc = torch.cat([data, data2], dim=0)
>>> print(stack_tc)
MyData(
    floatdata=Tensor(shape=torch.Size([2, 3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
    intdata=Tensor(shape=torch.Size([2, 3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
    non_tensordata='test',
    nested=MyData(
        floatdata=Tensor(shape=torch.Size([2, 3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        intdata=Tensor(shape=torch.Size([2, 3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        non_tensordata='nested_test',
        nested=None,
        batch_size=torch.Size([2, 3, 4]),
        device=None,
        is_shared=False),
    batch_size=torch.Size([2, 3, 4]),
    device=None,
    is_shared=False)

@tensorclass 还支持预分配,您可以使用属性为 None 初始化对象,然后稍后设置它们。请注意,在初始化时,内部将 None 属性保存为非张量/元数据,而在重置时,根据属性值的类型,它将保存为张量数据或非张量/元数据。

以下是一个示例

>>> @tensorclass
... class MyClass:
...   X: Any
...   y: Any

>>> data = MyClass(X=None, y=None, batch_size = [3,4])
>>> data.X = torch.ones(3, 4, 5)
>>> data.y = "testing"
>>> print(data)
MyClass(
   X=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
   y='testing',
   batch_size=torch.Size([3, 4]),
   device=None,
   is_shared=False)

tensorclass([autocast])

一个用于创建 tensorclass 类的装饰器。

NonTensorData(data[, _metadata, ...])

NonTensorStack(*args, **kwargs)

一个围绕 LazyStackedTensorDict 的薄包装器,使非张量数据上的堆叠易于识别。

自动转换

警告

自动转换是一个实验性功能,将来可能会发生变化。与 python<=3.9 的兼容性有限。

@tensorclass 作为实验性功能部分支持自动转换。诸如 __setattr__updateupdate_from_dict 之类的方法将尝试将类型注释的条目转换为所需的 TensorDict/tensorclass 实例(除了下面详细说明的情况)。例如,以下代码将 td 字典转换为 TensorDict,并将 tc 条目转换为 MyClass 实例。

>>> @tensorclass
... class MyClass:
...     tensor: torch.Tensor
...     td: TensorDict
...     tc: MyClass
...
>>> obj = MyClass(
...     tensor=torch.randn(()),
...     td={"a": torch.randn(())},
...     tc={"tensor": torch.randn(()), "td": None, "tc": None})
>>> assert isinstance(obj.tensor, torch.Tensor)
>>> assert isinstance(obj.tc, TensorDict)
>>> assert isinstance(obj.td, MyClass)

注意

包含 typing.Optionaltyping.Union 的类型注释项目将与自动转换不兼容,但 tensorclass 中的其他项目将兼容。

>>> @tensorclass
... class MyClass:
...     tensor: torch.Tensor
...     tc_autocast: MyClass = None
...     tc_not_autocast: Optional[MyClass] = None
>>> obj = MyClass(
...     tensor=torch.randn(()),
...     tc_autocast={"tensor": torch.randn(())},
...     tc_not_autocast={"tensor": torch.randn(())},
... )
>>> assert isinstance(obj.tc_autocast, MyClass)
>>> # because the type is Optional or Union, auto-casting is disabled for
>>> # that variable.
>>> assert not isinstance(obj.tc_not_autocast, MyClass)

如果类中的至少一个项目使用 type0 | type1 语义进行注释,则整个类的自动转换功能将被禁用。因为 tensorclass 支持非张量叶子,所以在这些情况下设置字典将导致将其设置为普通字典而不是张量集合子类(TensorDicttensorclass)。

>>> @tensorclass
... class MyClass:
...     tensor: torch.Tensor
...     td: TensorDict
...     tc: MyClass | None
...
>>> obj = MyClass(
...     tensor=torch.randn(()),
...     td={"a": torch.randn(())},
...     tc={"tensor": torch.randn(()), "td": None, "tc": None})
>>> assert isinstance(obj.tensor, torch.Tensor)
>>> # tc and td have not been cast
>>> assert isinstance(obj.tc, dict)
>>> assert isinstance(obj.td, dict)

注意

叶子(张量)未启用自动转换。这样做的原因是此功能与包含 type0 | type1 类型提示语义的类型注释不兼容,而这种注释非常普遍。如果类型注释仅略有不同,允许自动转换会导致非常相似的代码具有截然不同的行为。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取适合初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源