快捷方式

tensordict.nn.make_tensordict

tensordict.nn.make_tensordict(input_dict: dict[str, CompatibleType] | None = None, batch_size: Sequence[int] | torch.Size | int | None = None, device: DeviceType | None = None, **kwargs: CompatibleType) TensorDict

返回一个由关键字参数或输入字典创建的 TensorDict。

如果未指定 batch_size,则返回尽可能大的批量大小。

此函数也适用于嵌套字典,或可用于确定嵌套 tensordict 的批量大小。

参数:
  • input_dict (字典可选) – 用作数据源的字典(支持嵌套键)。

  • **kwargs (TensorDicttorch.Tensor) – 关键字参数作为数据源(与嵌套键不兼容)。

  • batch_size (int 的可迭代对象可选) – tensordict 的批量大小。

  • device (torch.device兼容类型可选) – TensorDict 的设备。

示例

>>> input_dict = {"a": torch.randn(3, 4), "b": torch.randn(3)}
>>> print(make_tensordict(input_dict))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> # alternatively
>>> td = make_tensordict(**input_dict)
>>> # nested dict: the nested TensorDict can have a different batch-size
>>> # as long as its leading dims match.
>>> input_dict = {"a": torch.randn(3), "b": {"c": torch.randn(3, 4)}}
>>> print(make_tensordict(input_dict))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3, 4]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> # we can also use this to work out the batch sie of a tensordict
>>> input_td = TensorDict({"a": torch.randn(3), "b": {"c": torch.randn(3, 4)}}, [])
>>> print(make_tensordict(input_td))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3, 4]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获取问题的解答

查看资源