• 文档 >
  • 使用 TensorDict 简化 PyTorch 内存管理
快捷方式

使用 TensorDict 简化 PyTorch 内存管理

作者: Tom Begley

在本教程中,您将学习如何控制 TensorDict 的内容在内存中的存储位置,可以通过将这些内容发送到设备,或者通过利用内存映射。

设备

当您创建一个 TensorDict 时,您可以使用 device 关键字参数指定设备。如果设置了 device,则 TensorDict 的所有条目都将放置在该设备上。如果未设置 device,则不要求 TensorDict 中的条目必须在同一设备上。

在本示例中,我们使用 device="cuda:0" 实例化一个 TensorDict。当我们打印内容时,我们可以看到它们已被移动到设备上。

>>> import torch
>>> from tensordict import TensorDict
>>> tensordict = TensorDict({"a": torch.rand(10)}, [10], device="cuda:0")
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cuda:0, dtype=torch.float32, is_shared=True)},
    batch_size=torch.Size([10]),
    device=cuda:0,
    is_shared=True)

如果 TensorDict 的设备不是 None,则新条目也会被移动到设备上。

>>> tensordict["b"] = torch.rand(10, 10)
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cuda:0, dtype=torch.float32, is_shared=True),
        b: Tensor(shape=torch.Size([10, 10]), device=cuda:0, dtype=torch.float32, is_shared=True)},
    batch_size=torch.Size([10]),
    device=cuda:0,
    is_shared=True)

您可以使用 device 属性检查 TensorDict 的当前设备。

>>> print(tensordict.device)
cuda:0

可以将 TensorDict 的内容发送到设备,就像 PyTorch 张量一样,可以使用 TensorDict.cuda()TensorDict.device(device),其中 device 是所需的设备。

>>> tensordict.to(torch.device("cpu"))
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([10, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10]),
    device=cpu,
    is_shared=False)
>>> tensordict.cuda()
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cuda:0, dtype=torch.float32, is_shared=True),
        b: Tensor(shape=torch.Size([10, 10]), device=cuda:0, dtype=torch.float32, is_shared=True)},
    batch_size=torch.Size([10]),
    device=cuda:0,
    is_shared=True)

TensorDict.device 方法需要传递有效的设备作为参数。如果您想从 TensorDict 中移除设备,以允许使用不同设备的值,则应使用 TensorDict.clear_device 方法。

>>> tensordict.clear_device()
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cuda:0, dtype=torch.float32, is_shared=True),
        b: Tensor(shape=torch.Size([10, 10]), device=cuda:0, dtype=torch.float32, is_shared=True)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

内存映射张量

tensordict 提供了一个类 MemoryMappedTensor,它允许我们将张量的内容存储在磁盘上,同时仍然支持快速索引和批量加载内容。有关实际应用示例,请参阅 ImageNet 教程

要将 TensorDict 转换为内存映射张量的集合,请使用 TensorDict.memmap_

tensordict = TensorDict({"a": torch.rand(10), "b": {"c": torch.rand(10)}}, [10])
tensordict.memmap_()

print(tensordict)
TensorDict(
    fields={
        a: MemoryMappedTensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: MemoryMappedTensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([10]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=cpu,
    is_shared=False)

或者,可以使用 TensorDict.memmap_like 方法。这将创建一个新的 TensorDict,其结构相同,但值是 MemoryMappedTensor,但它不会将原始张量的内容复制到内存映射张量。这允许您创建内存映射的 TensorDict,然后缓慢地填充它,因此通常应优先于 memmap_

tensordict = TensorDict({"a": torch.rand(10), "b": {"c": torch.rand(10)}}, [10])
mm_tensordict = tensordict.memmap_like()

print(mm_tensordict["a"].contiguous())
MemoryMappedTensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

默认情况下,TensorDict 的内容将保存到磁盘上的临时位置,但是,如果您想控制它们的保存位置,可以使用关键字参数 prefix="/path/to/root"

TensorDict 的内容保存在目录结构中,该结构模仿了 TensorDict 本身的结构。张量的内容保存在 NumPy 内存映射中,元数据保存在关联的 PyTorch 保存文件中。例如,上面的 TensorDict 保存如下

├── a.memmap
├── a.meta.pt
├── b
│ ├── c.memmap
│ ├── c.meta.pt
│ └── meta.pt
└── meta.pt

脚本总运行时间: (0 分 0.004 秒)

图库由 Sphinx-Gallery 生成

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源