from_pytree¶
- class tensordict.from_pytree(pytree, *, batch_size: Optional[Size] = None, auto_batch_size: bool = False, batch_dims: Optional[int] = None)¶
将 pytree 转换为 TensorDict 实例。
此方法旨在尽可能保留 pytree 的嵌套结构。
添加额外的非张量键以跟踪每个级别的标识,从而提供内置的 pytree 到 tensordict 的双射变换 API。
目前接受的类包括列表、元组、命名元组和字典。
注意
对于字典,非 NestedKey 键将作为
NonTensorData
实例单独注册。注意
可转换为张量的类型(例如 int、float 或 np.ndarray)将转换为 torch.Tensor 实例。请注意,此转换是满射的:将 tensordict 转换回 pytree 将不会恢复原始类型。
示例
>>> # Create a pytree with tensor leaves, and one "weird"-looking dict key >>> class WeirdLookingClass: ... pass ... >>> weird_key = WeirdLookingClass() >>> # Make a pytree with tuple, lists, dict and namedtuple >>> pytree = ( ... [torch.randint(10, (3,)), torch.zeros(2)], ... { ... "tensor": torch.randn( ... 2, ... ), ... "td": TensorDict({"one": 1}), ... weird_key: torch.randint(10, (2,)), ... "list": [1, 2, 3], ... }, ... {"named_tuple": TensorDict({"two": torch.ones(1) * 2}).to_namedtuple()}, ... ) >>> # Build a TensorDict from that pytree >>> td = from_pytree(pytree) >>> # Recover the pytree >>> pytree_recon = td.to_pytree() >>> # Check that the leaves match >>> def check(v1, v2): >>> assert (v1 == v2).all() >>> >>> torch.utils._pytree.tree_map(check, pytree, pytree_recon) >>> assert weird_key in pytree_recon[1]