from_namedtuple¶
- 类 tensordict.from_namedtuple(named_tuple, *, auto_batch_size: bool = False)¶
递归地将 namedtuple 转换为 TensorDict。
- 关键字参数:
auto_batch_size (bool, 可选) – 如果
True
,将自动计算批大小。默认为False
。
示例
>>> from tensordict import TensorDict, from_namedtuple >>> import torch >>> data = TensorDict({ ... "a_tensor": torch.zeros((3)), ... "nested": {"a_tensor": torch.zeros((3)), "a_string": "zero!"}}, [3]) >>> nt = data.to_namedtuple() >>> print(nt) GenericDict(a_tensor=tensor([0., 0., 0.]), nested=GenericDict(a_tensor=tensor([0., 0., 0.]), a_string='zero!')) >>> from_namedtuple(nt, auto_batch_size=True) TensorDict( fields={ a_tensor: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), nested: TensorDict( fields={ a_string: NonTensorData(data=zero!, batch_size=torch.Size([3]), device=None), a_tensor: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False)