快捷方式

from_dataclass

class tensordict.from_dataclass(obj: Any, *, auto_batch_size: bool = False, batch_dims: Optional[int] = None, batch_size: Optional[Size] = None, frozen: bool = False, autocast: bool = False, nocast: bool = False, inplace: bool = False, shadow: bool = False, device: Optional[device] = None)

将 dataclass 实例或类型分别转换为 tensorclass 实例或类型。

此函数接收 dataclass 实例或 dataclass 类型,并将其转换为张量兼容的类,同时可选择应用自动批处理、不变性和类型转换等各种配置。

参数:

obj (Any) – 要转换的 dataclass 实例或类型。如果提供的是类型,则返回一个新的类。

关键字参数:
  • auto_batch_size (bool, 可选) – 如果为 True,则自动确定并将批次大小应用于结果对象。默认为 False

  • batch_dims (int, 可选) – 如果 auto_batch_size 为 True,则定义输出 tensordict 应具有的维度数。默认为 None(每层全批次大小)。

  • batch_size (torch.Size, 可选) – TensorDict 的批次大小。默认为 None

  • frozen (bool, 可选) – 如果为 True,则结果类或实例将是不可变的。默认为 False

  • autocast (bool, 可选) – 如果为 True,则为结果类或实例启用自动类型转换。默认为 False

  • nocast (bool, 可选) – 如果为 True,则禁用结果类或实例的任何类型转换。默认为 False

  • inplace (bool, 可选) – 如果为 True,则传入的 dataclass 类型将被原地修改。默认为 False。如果提供的是实例,则此参数无效。

  • device (torch.device, 可选) – 创建 TensorDict 的设备。默认为 None

  • shadow (bool, 可选) – 禁用字段名与 TensorDict 保留属性的验证。请谨慎使用,这可能会导致意外后果。默认为 False。

返回:

从提供的 dataclass 派生的张量兼容类或实例。

抛出:

TypeError – 如果提供的输入不是 dataclass 实例或类型。

示例

>>> from dataclasses import dataclass
>>> import torch
>>> from tensordict.tensorclass import from_dataclass
>>>
>>> @dataclass
>>> class X:
...     a: int
...     b: torch.Tensor
...
>>> x = X(0, 0)
>>> x2 = from_dataclass(x)
>>> print(x2)
X(
    a=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
    b=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> X2 = from_dataclass(X, autocast=True)
>>> print(X2(a=0, b=0))
X(
    a=NonTensorData(data=0, batch_size=torch.Size([]), device=None),
    b=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

警告

尽管 from_dataclass() 默认返回一个 TensorDict 实例,但此方法将返回一个 tensorclass 实例或类型。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源