• 文档 >
  • 分布式设置中的 TensorDict
快捷方式

分布式设置中的 TensorDict

TensorDict 可用于分布式设置中,以将张量从一个节点传递到另一个节点。如果两个节点可以访问共享物理存储,则可以使用内存映射张量来有效地将数据从一个运行进程传递到另一个进程。 在这里,我们提供有关如何在分布式 RPC 设置中实现此目的的一些详细信息。 有关分布式 RPC 的更多详细信息,请查看官方 pytorch 文档

创建内存映射 TensorDict

内存映射张量(和数组)具有很大的优势,它们可以存储大量数据,并允许轻松访问数据切片,而无需将整个文件读入内存。 TensorDict 提供了内存映射数组和名为 MemmapTensortorch.Tensor 类之间的接口。 MemmapTensor 实例可以存储在 TensorDict 对象中,从而允许 tensordict 表示存储在磁盘上的大型数据集,并且可以在节点之间以批量方式轻松访问。

内存映射 tensordict 只是通过(1)使用内存映射张量填充 TensorDict 或(2)调用 tensordict.memmap_() 将其放置在物理存储上来创建的。 可以通过查询 tensordict.is_memmap() 轻松检查 tensordict 是否放置在物理存储上。

创建内存映射张量本身可以通过多种方式完成。 首先,可以简单地创建一个空张量

>>> shape = torch.Size([3, 4, 5])
>>> tensor = Memmaptensor(*shape, prefix="/tmp")
>>> tensor[:2] = torch.randn(2, 4, 5)

prefix 属性指示临时文件必须存储的位置。 重要的是,张量必须存储在每个节点都可以访问的目录中!

另一种选择是在磁盘上表示现有张量

>>> tensor = torch.randn(3)
>>> tensor = Memmaptensor(tensor, prefix="/tmp")

当张量很大或不适合内存时,前者方法将是首选:它适用于非常大的张量,并用作跨节点的公共存储。 例如,可以创建一个数据集,单个节点或不同节点可以轻松访问该数据集,这比每个文件都必须独立加载到内存中要快得多

在磁盘上创建空数据集
>>> dataset = TensorDict({
...      "images": MemmapTensor(50000, 480, 480, 3),
...      "masks": MemmapTensor(50000, 480, 480, 3, dtype=torch.bool),
...      "labels": MemmapTensor(50000, 1, dtype=torch.uint8),
... }, batch_size=[50000], device="cpu")
>>> idx = [1, 5020, 34572, 11200]
>>> batch = dataset[idx].clone()
TensorDict(
    fields={
        images: Tensor(torch.Size([4, 480, 480, 3]), dtype=torch.float32),
        labels: Tensor(torch.Size([4, 1]), dtype=torch.uint8),
        masks: Tensor(torch.Size([4, 480, 480, 3]), dtype=torch.bool)},
    batch_size=torch.Size([4]),
    device=cpu,
    is_shared=False)

请注意,我们已指示 MemmapTensor 的设备。 这种语法糖允许在需要时将查询的张量直接加载到设备上。

需要考虑的另一个因素是,目前 MemmapTensor 与 autograd 操作不兼容。

跨节点操作内存映射张量

我们提供了一个分布式脚本的简单示例,其中一个进程创建一个内存映射张量,并将其引用发送给另一个负责更新它的工作进程。 您将在 benchmark 目录中找到此示例。

简而言之,我们的目标是展示当节点可以访问共享物理存储时,如何处理大型张量的读取和写入操作。 步骤包括

  • 在磁盘上创建空张量;

  • 设置要执行的本地和远程操作;

  • 使用 RPC 将命令从工作进程传递到工作进程,以读取和写入共享数据。

此示例首先编写一个函数,该函数使用一个填充为 1 的张量更新特定索引处的 TensorDict 实例

>>> def fill_tensordict(tensordict, idx):
...     tensordict[idx] = TensorDict(
...         {"memmap": torch.ones(5, 640, 640, 3, dtype=torch.uint8)}, [5]
...     )
...     return tensordict
>>> fill_tensordict_cp = CloudpickleWrapper(fill_tensordict)

CloudpickleWrapper 确保该函数是可序列化的。 接下来,我们创建一个相当大的 tensordict,以说明如果必须通过常规 tensorpipe 传递它,则很难从工作进程传递到工作进程

>>> tensordict = TensorDict(
...     {"memmap": MemmapTensor(1000, 640, 640, 3, dtype=torch.uint8, prefix="/tmp/")}, [1000]
... )

最后,仍然在主节点上,我们在远程节点上调用该函数,然后检查数据是否已写入所需位置

>>> idx = [4, 5, 6, 7, 998]
>>> t0 = time.time()
>>> out = rpc.rpc_sync(
...     worker_info,
...     fill_tensordict_cp,
...     args=(tensordict, idx),
... )
>>> print("time elapsed:", time.time() - t0)
>>> print("check all ones", out["memmap"][idx, :1, :1, :1].clone())

尽管对 rpc.rpc_sync 的调用涉及传递整个 tensordict、更新此对象的特定索引并将其返回给原始工作进程,但此代码段的执行速度非常快(如果内存位置的引用已预先传递,则速度更快,请参阅 torchrl 的分布式重放缓冲区文档以了解更多信息)。

该脚本包含超出本文档目的的其他 RPC 配置步骤。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得解答

查看资源