tensordict.nn.make_tensordict
- tensordict.nn.make_tensordict(input_dict: dict[str, CompatibleType] | None = None, batch_size: Sequence[int] | torch.Size | int | None = None, device: DeviceType | None = None, **kwargs: CompatibleType) → TensorDict
Returns a TensorDict created from the keyword arguments or an input dictionary.
If batch_size is not specified, the largest possible batch size is inferred. This function also works on nested dictionaries, and can be used to determine the batch size of a nested tensordict.
- Parameters:
input_dict (dict, optional) – a dictionary to use as a data source (nested keys are supported).
**kwargs (TensorDict or torch.Tensor) – keyword arguments as a data source (incompatible with nested keys).
batch_size (iterable of int, optional) – a batch size for the tensordict.
device (torch.device or compatible type, optional) – a device for the TensorDict.
Examples
>>> input_dict = {"a": torch.randn(3, 4), "b": torch.randn(3)}
>>> print(make_tensordict(input_dict))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> # alternatively
>>> td = make_tensordict(**input_dict)
>>> # nested dict: the nested TensorDict can have a different batch-size
>>> # as long as its leading dims match.
>>> input_dict = {"a": torch.randn(3), "b": {"c": torch.randn(3, 4)}}
>>> print(make_tensordict(input_dict))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3, 4]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> # we can also use this to work out the batch size of a tensordict
>>> input_td = TensorDict({"a": torch.randn(3), "b": {"c": torch.randn(3, 4)}}, [])
>>> print(make_tensordict(input_td))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3, 4]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
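The examples above let batch_size and device be inferred; the sketch below (not part of the original docs) shows both passed explicitly, assuming the import path tensordict.make_tensordict:

>>> # hedged sketch: pinning batch_size and device instead of inferring them
>>> import torch
>>> from tensordict import make_tensordict
>>> td = make_tensordict(
...     {"a": torch.randn(3, 4), "b": torch.randn(3)},
...     batch_size=[3],
...     device="cpu",
... )
>>> td.batch_size
torch.Size([3])
>>> td.device
device(type='cpu')

An explicit batch_size must be a prefix of every leaf tensor's shape; here [3] is the only valid choice, since "b" has a single dimension of size 3.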