快捷方式

tensorclass

class tensordict.tensorclass(cls=None, /, *, autocast: bool = False, frozen: bool = False)

用于创建 tensorclass 类的装饰器。

tensorclass 类是专门的 dataclasses.dataclass() 实例,可以开箱即用地执行一些预定义的张量操作,例如索引、项目赋值、重塑、转换为设备或存储以及许多其他操作。

参数:
  • autocast (bool, 可选) – 如果 True,则在设置参数时将强制执行指示的类型。默认为 False

  • frozen (bool, 可选) – 如果 True,则无法修改 tensorclass 的内容。提供此参数是为了与 dataclass 兼容,可以通过类构造函数中的 lock 参数获得类似的行为。默认为 False

tensorclass 可以带参数或不带参数使用: .. rubric:: 示例

>>> @tensorclass
... class X:
...     y: torch.Tensor
>>> X(1).y
1
>>> @tensorclass(autocast=False)
... class X:
...     y: torch.Tensor
>>> X(1).y
1
>>> @tensorclass(autocast=True)
... class X:
...     y: torch.Tensor
>>> X(1).y
torch.tensor(1)

示例

>>> from tensordict import tensorclass
>>> import torch
>>> from typing import Optional
>>>
>>> @tensorclass
... class MyData:
...     X: torch.Tensor
...     y: torch.Tensor
...     z: str
...     def expand_and_mask(self):
...         X = self.X.unsqueeze(-1).expand_as(self.y)
...         X = X[self.y]
...         return X
...
>>> data = MyData(
...     X=torch.ones(3, 4, 1),
...     y=torch.zeros(3, 4, 2, 2, dtype=torch.bool),
...     z="test"
...     batch_size=[3, 4])
>>> print(data)
MyData(
    X=Tensor(torch.Size([3, 4, 1]), dtype=torch.float32),
    y=Tensor(torch.Size([3, 4, 2, 2]), dtype=torch.bool),
    z="test"
    batch_size=[3, 4],
    device=None,
    is_shared=False)
>>> print(data.expand_and_mask())
tensor([])
也可以在彼此内部嵌套 tensorclass 实例

示例: >>> from tensordict import tensorclass >>> import torch >>> from typing import Optional >>> >>> @tensorclass … class NestingMyData: … nested: MyData … >>> nesting_data = NestingMyData(nested=data, batch_size=[3, 4]) >>> # 虽然数据存储为 TensorDict,但类型提示帮助我们 >>> # 将数据正确地转换为正确的类型 >>> assert isinstance(nesting_data.nested, type(data))

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源