快捷方式

TensorClass

class tensordict.TensorClass

TensorClass 是 @tensorclass 装饰器的基于继承的版本。

TensorClass 允许您编写类型检查更佳、更符合 Python 风格的数据类,相比于使用 @tensorclass 装饰器构建的数据类。

示例

>>> from typing import Any
>>> import torch
>>> from tensordict import TensorClass
>>> class Foo(TensorClass):
...     tensor: torch.Tensor
...     non_tensor: Any
...     nested: Any = None
>>> foo = Foo(tensor=torch.randn(3), non_tensor="a string!", nested=None, batch_size=[3])
>>> print(foo)
Foo(
    non_tensor=NonTensorData(data=a string!, batch_size=torch.Size([3]), device=None),
    tensor=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
    nested=None,
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)

您可以通过两种方式传递关键字参数:使用方括号或直接使用关键字参数。

示例

>>> class Foo(TensorClass["autocast"]):
...     integer: int
>>> Foo(integer=torch.ones(())).integer
1
>>> class Foo(TensorClass, autocast=True):  # equivalent
...     integer: int
>>> Foo(integer=torch.ones(())).integer
1
>>> class Foo(TensorClass["nocast"]):
...     integer: int
>>> Foo(integer=1).integer
1
>>> class Foo(TensorClass["nocast", "frozen"]):  # multiple keywords can be used
...     integer: int
>>> Foo(integer=1).integer
1
>>> class Foo(TensorClass, nocast=True):  # equivalent
...     integer: int
>>> Foo(integer=1).integer
1
>>> class Foo(TensorClass):
...     integer: int
>>> Foo(integer=1).integer
tensor(1)

警告

TensorClass 本身没有被装饰为 tensorclass,但其子类会。这是因为我们无法预知 frozen 参数是否会被设置,如果设置了,它可能与父类冲突(如果父类未被冻结,子类也不能被冻结)。

文档

获取 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源