from_dataclass¶
- class tensordict.from_dataclass(obj: Any, *, auto_batch_size: bool = False, batch_dims: Optional[int] = None, batch_size: Optional[Size] = None, frozen: bool = False, autocast: bool = False, nocast: bool = False, inplace: bool = False, shadow: bool = False, device: Optional[device] = None)¶
将 dataclass 实例或类型分别转换为 tensorclass 实例或类型。
此函数接收 dataclass 实例或 dataclass 类型,并将其转换为张量兼容的类,同时可选择应用自动批处理、不变性和类型转换等各种配置。
- 参数:
obj (Any) – 要转换的 dataclass 实例或类型。如果提供的是类型,则返回一个新的类。
- 关键字参数:
auto_batch_size (bool, 可选) – 如果为
True
,则自动确定并将批次大小应用于结果对象。默认为False
。batch_dims (int, 可选) – 如果 auto_batch_size 为
True
,则定义输出 tensordict 应具有的维度数。默认为None
(每层全批次大小)。batch_size (torch.Size, 可选) – TensorDict 的批次大小。默认为
None
。frozen (bool, 可选) – 如果为
True
,则结果类或实例将是不可变的。默认为False
。autocast (bool, 可选) – 如果为
True
,则为结果类或实例启用自动类型转换。默认为False
。nocast (bool, 可选) – 如果为
True
,则禁用结果类或实例的任何类型转换。默认为False
。inplace (bool, 可选) – 如果为
True
,则传入的 dataclass 类型将被原地修改。默认为False
。如果提供的是实例,则此参数无效。device (torch.device, 可选) – 创建 TensorDict 的设备。默认为
None
。shadow (bool, 可选) – 禁用字段名与 TensorDict 保留属性的验证。请谨慎使用,这可能会导致意外后果。默认为 False。
- 返回:
从提供的 dataclass 派生的张量兼容类或实例。
- 抛出:
TypeError – 如果提供的输入不是 dataclass 实例或类型。
示例
>>> from dataclasses import dataclass >>> import torch >>> from tensordict.tensorclass import from_dataclass >>> >>> @dataclass >>> class X: ... a: int ... b: torch.Tensor ... >>> x = X(0, 0) >>> x2 = from_dataclass(x) >>> print(x2) X( a=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), b=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), batch_size=torch.Size([]), device=None, is_shared=False) >>> X2 = from_dataclass(X, autocast=True) >>> print(X2(a=0, b=0)) X( a=NonTensorData(data=0, batch_size=torch.Size([]), device=None), b=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), batch_size=torch.Size([]), device=None, is_shared=False)
警告
尽管
from_dataclass()
默认返回一个TensorDict
实例,但此方法将返回一个 tensorclass 实例或类型。