TensorClass¶
- class tensordict.TensorClass¶
TensorClass 是 @tensorclass 装饰器的基于继承的版本。
TensorClass 允许您编写类型检查更佳、更符合 Python 风格的数据类,相比于使用 @tensorclass 装饰器构建的数据类。
示例
>>> from typing import Any >>> import torch >>> from tensordict import TensorClass >>> class Foo(TensorClass): ... tensor: torch.Tensor ... non_tensor: Any ... nested: Any = None >>> foo = Foo(tensor=torch.randn(3), non_tensor="a string!", nested=None, batch_size=[3]) >>> print(foo) Foo( non_tensor=NonTensorData(data=a string!, batch_size=torch.Size([3]), device=None), tensor=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), nested=None, batch_size=torch.Size([3]), device=None, is_shared=False)
您可以通过两种方式传递关键字参数:使用方括号或直接使用关键字参数。
示例
>>> class Foo(TensorClass["autocast"]): ... integer: int >>> Foo(integer=torch.ones(())).integer 1 >>> class Foo(TensorClass, autocast=True): # equivalent ... integer: int >>> Foo(integer=torch.ones(())).integer 1 >>> class Foo(TensorClass["nocast"]): ... integer: int >>> Foo(integer=1).integer 1 >>> class Foo(TensorClass["nocast", "frozen"]): # multiple keywords can be used ... integer: int >>> Foo(integer=1).integer 1 >>> class Foo(TensorClass, nocast=True): # equivalent ... integer: int >>> Foo(integer=1).integer 1 >>> class Foo(TensorClass): ... integer: int >>> Foo(integer=1).integer tensor(1)
警告
TensorClass 本身没有被装饰为 tensorclass,但其子类会。这是因为我们无法预知 frozen 参数是否会被设置,如果设置了,它可能与父类冲突(如果父类未被冻结,子类也不能被冻结)。