快捷方式

tensorclass

tensordict.tensorclass(cls: 可选[T] = , /, *, autocast: 布尔值 = False, frozen: 布尔值 = False, nocast: 布尔值 = False, shadow: 布尔值 = False)

一个用于创建 tensorclass 类的装饰器。

tensorclass 类是专门化的 dataclasses.dataclass() 实例,它们可以即时执行一些预定义的张量操作,例如索引、项赋值、重塑、转换为设备或存储等许多操作。

关键字参数:
  • autocast (布尔值, 可选) – 如果为 True,则在设置参数时将强制执行指定的类型。此参数与 nocast 互斥(两者不能同时为 True)。默认为 False

  • frozen (布尔值, 可选) – 如果为 True,则 tensorclass 的内容无法修改。提供此参数是为了与 dataclass 兼容,通过类构造函数中的 lock 参数可以获得类似的行为。默认为 False

  • nocast (布尔值, 可选) – 如果为 True,则 Tensor 兼容的类型,如 intnp.ndarray 等,将不会被转换为张量类型。此参数与 autocast 互斥(两者不能同时为 True)。默认为 False

  • shadow (布尔值, 可选) – 禁用对字段名与 TensorDict 保留属性的验证。请谨慎使用,因为这可能导致意外后果。默认为 False。

tensorclass 可以带或不带参数使用

示例

>>> @tensorclass
... class X:
...     y: int
>>> X(torch.ones(())).y
tensor(1.)
>>> @tensorclass(autocast=False)
... class X:
...     y: int
>>> X(torch.ones(())).y
tensor(1.)
>>> @tensorclass(autocast=True)
... class X:
...     y: int
>>> X(torch.ones(())).y
1
>>> @tensorclass(nocast=True)
... class X:
...     y: Any
>>> X(1).y
1
>>> @tensorclass(nocast=False)
... class X:
...     y: Any
>>> X(1).y
tensor(1)

示例

>>> from tensordict import tensorclass
>>> import torch
>>> from typing import Optional
>>>
>>> @tensorclass
... class MyData:
...     X: torch.Tensor
...     y: torch.Tensor
...     z: str
...     def expand_and_mask(self):
...         X = self.X.unsqueeze(-1).expand_as(self.y)
...         X = X[self.y]
...         return X
...
>>> data = MyData(
...     X=torch.ones(3, 4, 1),
...     y=torch.zeros(3, 4, 2, 2, dtype=torch.bool),
...     z="test"
...     batch_size=[3, 4])
>>> print(data)
MyData(
    X=Tensor(torch.Size([3, 4, 1]), dtype=torch.float32),
    y=Tensor(torch.Size([3, 4, 2, 2]), dtype=torch.bool),
    z="test"
    batch_size=[3, 4],
    device=None,
    is_shared=False)
>>> print(data.expand_and_mask())
tensor([])
也可以将 tensorclass 实例互相嵌套

示例: >>> from tensordict import tensorclass >>> import torch >>> from typing import Optional >>> >>> @tensorclass … class NestingMyData: … nested: MyData … >>> nesting_data = NestingMyData(nested=data, batch_size=[3, 4]) >>> # 尽管数据存储为 TensorDict,但类型提示有助于我们将数据适当地转换为正确的类型 >>> assert isinstance(nesting_data.nested, type(data))

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源