分布式环境下的 TensorDict¶
TensorDict 可用于分布式环境中,将张量从一个节点传递到另一个节点。如果两个节点可以访问共享的物理存储,则可以使用内存映射张量来有效地将数据从一个正在运行的进程传递到另一个进程。在这里,我们提供了一些关于如何在分布式 RPC 设置中实现此目标的详细信息。有关分布式 RPC 的更多详细信息,请查看官方 PyTorch 文档。
创建内存映射 TensorDict¶
内存映射张量(和数组)具有很大的优势,它们可以存储大量数据,并允许轻松访问数据切片,而无需将整个文件读入内存。TensorDict 在内存映射数组和名为MemmapTensor
的torch.Tensor
类之间提供了一个接口。MemmapTensor
实例可以存储在TensorDict
对象中,允许 tensordict 表示一个大型数据集,存储在磁盘上,并能够轻松地以批处理方式跨节点访问。
内存映射 tensordict 可以通过以下两种方式创建:(1)使用内存映射张量填充 TensorDict,或(2)调用tensordict.memmap_()
将其放置在物理存储上。可以通过查询tensordict.is_memmap() 轻松检查 tensordict 是否已放置在物理存储上。
创建内存映射张量本身可以通过多种方式完成。首先,可以简单地创建一个空张量
>>> shape = torch.Size([3, 4, 5])
>>> tensor = Memmaptensor(*shape, prefix="/tmp")
>>> tensor[:2] = torch.randn(2, 4, 5)
prefix
属性指示临时文件必须存储的位置。至关重要的是,张量必须存储在每个节点都可以访问的目录中!
另一种选择是表示磁盘上的现有张量
>>> tensor = torch.randn(3)
>>> tensor = Memmaptensor(tensor, prefix="/tmp")
当张量很大或无法放入内存时,建议使用前一种方法:它适用于非常大的张量,并作为跨节点的公共存储。例如,可以创建一个数据集,该数据集可以由单个或不同的节点轻松访问,这比每个文件必须独立加载到内存中要快得多
>>> dataset = TensorDict({
... "images": MemmapTensor(50000, 480, 480, 3),
... "masks": MemmapTensor(50000, 480, 480, 3, dtype=torch.bool),
... "labels": MemmapTensor(50000, 1, dtype=torch.uint8),
... }, batch_size=[50000], device="cpu")
>>> idx = [1, 5020, 34572, 11200]
>>> batch = dataset[idx].clone()
TensorDict(
fields={
images: Tensor(torch.Size([4, 480, 480, 3]), dtype=torch.float32),
labels: Tensor(torch.Size([4, 1]), dtype=torch.uint8),
masks: Tensor(torch.Size([4, 480, 480, 3]), dtype=torch.bool)},
batch_size=torch.Size([4]),
device=cpu,
is_shared=False)
请注意,我们已经指示了MemmapTensor
的设备。此语法糖允许在需要时将查询的张量直接加载到设备上。
另一个需要考虑的因素是,目前MemmapTensor
与自动梯度操作不兼容。
跨节点操作内存映射张量¶
我们提供了一个分布式脚本的简单示例,其中一个进程创建一个内存映射张量,并将它的引用发送到另一个负责更新它的工作程序。您将在基准测试目录 中找到此示例。
简而言之,我们的目标是展示如何在节点可以访问共享物理存储时处理大型张量的读写操作。步骤包括
在磁盘上创建空张量;
设置要执行的本地和远程操作;
使用 RPC 将命令从工作程序传递到工作程序以读取和写入共享数据。
此示例首先编写一个函数,该函数使用一个填充为 1 的张量更新 TensorDict 实例的特定索引
>>> def fill_tensordict(tensordict, idx):
... tensordict[idx] = TensorDict(
... {"memmap": torch.ones(5, 640, 640, 3, dtype=torch.uint8)}, [5]
... )
... return tensordict
>>> fill_tensordict_cp = CloudpickleWrapper(fill_tensordict)
CloudpickleWrapper
确保函数可序列化。接下来,我们创建一个相当大的 tensordict,以说明如果必须通过常规 tensorpipe 传递,则很难将它从一个工作程序传递到另一个工作程序
>>> tensordict = TensorDict(
... {"memmap": MemmapTensor(1000, 640, 640, 3, dtype=torch.uint8, prefix="/tmp/")}, [1000]
... )
最后,仍然在主节点上,我们在远程节点上调用该函数,然后检查数据是否已写入所需位置
>>> idx = [4, 5, 6, 7, 998]
>>> t0 = time.time()
>>> out = rpc.rpc_sync(
... worker_info,
... fill_tensordict_cp,
... args=(tensordict, idx),
... )
>>> print("time elapsed:", time.time() - t0)
>>> print("check all ones", out["memmap"][idx, :1, :1, :1].clone())
尽管调用rpc.rpc_sync
涉及传递整个 tensordict,更新此对象的特定索引并将其返回到原始工作程序,但此代码段的执行速度非常快(如果事先已传递对内存位置的引用,则速度更快,请参阅torchrl 的分布式回放缓冲区文档 以了解更多信息)。
该脚本包含其他 RPC 配置步骤,这些步骤超出了本文档的目的。