TensorDict 在分布式设置中¶
TensorDict 可用于分布式设置,以便在不同节点之间传递张量。如果两个节点可以访问共享物理存储,则可以使用内存映射张量高效地在正在运行的不同进程之间传递数据。本文档提供了一些关于如何在分布式 RPC 环境中实现这一目标的详细信息。有关分布式 RPC 的更多详细信息,请查阅 官方 PyTorch 文档。
创建内存映射 TensorDict¶
内存映射张量(和数组)的一大优势在于它们可以存储大量数据,并允许快速访问数据切片,而无需将整个文件读入内存。TensorDict 在内存映射数组和 torch.Tensor
类之间提供了接口,该接口名为 MemmapTensor
。MemmapTensor
实例可以存储在 TensorDict
对象中,从而允许 tensordict 表示存储在磁盘上的大型数据集,并且可以轻松地跨节点进行批处理访问。
创建内存映射 tensordict 非常简单,可以通过 (1) 用内存映射张量填充 TensorDict 或 (2) 调用 tensordict.memmap_()
将其放置到物理存储上。可以通过查询 tensordict.is_memmap() 轻松检查 tensordict 是否已放置在物理存储上。
创建内存映射张量本身可以通过多种方式完成。首先,可以简单地创建一个空张量
>>> shape = torch.Size([3, 4, 5])
>>> tensor = Memmaptensor(*shape, prefix="/tmp")
>>> tensor[:2] = torch.randn(2, 4, 5)
prefix
属性指示临时文件应存储的位置。至关重要的是,张量必须存储在每个节点都可以访问的目录中!
另一种选择是表示磁盘上的现有张量
>>> tensor = torch.randn(3)
>>> tensor = Memmaptensor(tensor, prefix="/tmp")
前一种方法将在张量很大或不适合内存时更受欢迎:它适用于非常大的张量,并作为跨节点的通用存储。例如,可以创建一个数据集,该数据集可以由单个或不同节点轻松访问,比每个文件都必须独立加载到内存中要快得多
>>> dataset = TensorDict({
... "images": MemmapTensor(50000, 480, 480, 3),
... "masks": MemmapTensor(50000, 480, 480, 3, dtype=torch.bool),
... "labels": MemmapTensor(50000, 1, dtype=torch.uint8),
... }, batch_size=[50000], device="cpu")
>>> idx = [1, 5020, 34572, 11200]
>>> batch = dataset[idx].clone()
TensorDict(
fields={
images: Tensor(torch.Size([4, 480, 480, 3]), dtype=torch.float32),
labels: Tensor(torch.Size([4, 1]), dtype=torch.uint8),
masks: Tensor(torch.Size([4, 480, 480, 3]), dtype=torch.bool)},
batch_size=torch.Size([4]),
device=cpu,
is_shared=False)
请注意,我们已经指明了 MemmapTensor
的设备。这种语法糖允许在需要时直接将查询到的张量加载到设备上。
需要考虑的另一个问题是,目前 MemmapTensor
不兼容 autograd 操作。
跨节点操作内存映射张量¶
我们提供了一个简单的分布式脚本示例,其中一个进程创建一个内存映射张量,并将其引用发送给另一个负责更新它的工作进程。您可以在 benchmark 目录中找到此示例。
简而言之,我们的目标是展示当节点可以访问共享物理存储时,如何处理大型张量的读写操作。步骤包括
在磁盘上创建空张量;
设置要执行的本地和远程操作;
使用 RPC 在工作进程之间传递命令以读写共享数据。
这个示例首先编写一个函数,该函数使用一个填充了 1 的张量在特定索引处更新 TensorDict 实例
>>> def fill_tensordict(tensordict, idx):
... tensordict[idx] = TensorDict(
... {"memmap": torch.ones(5, 640, 640, 3, dtype=torch.uint8)}, [5]
... )
... return tensordict
>>> fill_tensordict_cp = CloudpickleWrapper(fill_tensordict)
CloudpickleWrapper
确保函数是可序列化的。接下来,我们创建一个相当大尺寸的 tensordict,以说明如果必须通过常规 tensorpipe 传递,这将很难在工作进程之间传递。
>>> tensordict = TensorDict(
... {"memmap": MemmapTensor(1000, 640, 640, 3, dtype=torch.uint8, prefix="/tmp/")}, [1000]
... )
最后,仍在主节点上,我们在远程节点上调用该函数,然后检查数据是否已写入所需位置
>>> idx = [4, 5, 6, 7, 998]
>>> t0 = time.time()
>>> out = rpc.rpc_sync(
... worker_info,
... fill_tensordict_cp,
... args=(tensordict, idx),
... )
>>> print("time elapsed:", time.time() - t0)
>>> print("check all ones", out["memmap"][idx, :1, :1, :1].clone())
尽管调用 rpc.rpc_sync
涉及传递整个 tensordict,更新此对象的特定索引并将其返回给原始工作进程,但此代码段的执行速度非常快(如果事先已传递内存位置的引用,速度更快,请参阅 torchrl 的分布式重放缓冲区文档了解更多信息)。
该脚本包含本文档目的之外的其他 RPC 配置步骤。