tensorclass¶
- class tensordict.tensorclass(autocast: bool = False)¶
一个用于创建
tensorclass
类别的装饰器。tensorclass
类别是专门的dataclass
实例,可以开箱即用地执行一些预定义的张量操作,例如索引、项目赋值、重塑、转换为设备或存储以及许多其他操作。示例
>>> from tensordict import tensorclass >>> import torch >>> from typing import Optional >>> >>> @tensorclass ... class MyData: ... X: torch.Tensor ... y: torch.Tensor ... z: str ... def expand_and_mask(self): ... X = self.X.unsqueeze(-1).expand_as(self.y) ... X = X[self.y] ... return X ... >>> data = MyData( ... X=torch.ones(3, 4, 1), ... y=torch.zeros(3, 4, 2, 2, dtype=torch.bool), ... z="test" ... batch_size=[3, 4]) >>> print(data) MyData( X=Tensor(torch.Size([3, 4, 1]), dtype=torch.float32), y=Tensor(torch.Size([3, 4, 2, 2]), dtype=torch.bool), z="test" batch_size=[3, 4], device=None, is_shared=False) >>> print(data.expand_and_mask()) tensor([])
- 也可以将 tensorclasses 实例嵌套在彼此之中
示例: >>> from tensordict import tensorclass >>> import torch >>> from typing import Optional >>> >>> @tensorclass … class NestingMyData: … nested: MyData … >>> nesting_data = NestingMyData(nested=data, batch_size=[3, 4]) >>> # 尽管数据存储为 TensorDict,但类型提示帮助我们 >>> # 将数据适当地转换为正确的类型 >>> assert isinstance(nesting_data.nested, type(data))