FullSync¶
- class torchdata.datapipes.iter.FullSync(datapipe: IterDataPipe, timeout=1800)¶
Synchronizes data across distributed processes to prevent hanging during training, which is caused by unevenly sharded data (functional name: fullsync). It stops when the shortest distributed shard is exhausted. It is appended to the end of the DataPipe graph automatically by DistributedReadingService.
- Parameters:
  - datapipe – the IterDataPipe that needs to be synchronized
  - timeout – timeout in seconds for prefetching data. Default value equals 30 minutes (1800 seconds)
Example
>>> import torch
>>> from torchdata.datapipes.iter import IterableWrapper
>>> # Distributed training with world size 2
>>> world_size = 2
>>> dp = IterableWrapper(list(range(23))).sharding_filter()
>>> torch.utils.data.graph_settings.apply_sharding(dp, world_size, rank)
>>> # Rank 0 has 12 elements; Rank 1 has 11 elements
>>> for d in dp:
...     model(d)  # Hanging at the end of epoch due to uneven sharding
>>> dp = dp.fullsync()
>>> # Both ranks have 11 elements
>>> for d in dp:
...     model(d)  # Not hanging anymore
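As noted above, DistributedReadingService appends fullsync to the end of the DataPipe graph automatically, so when iterating through DataLoader2 with that reading service there is no need to call fullsync() by hand. The following is a minimal sketch of that setup; it assumes the distributed process group has already been initialized (e.g. via torchrun), and the pipeline contents and the model(d) call are illustrative placeholders, not part of this API.
>>> from torchdata.datapipes.iter import IterableWrapper
>>> from torchdata.dataloader2 import DataLoader2, DistributedReadingService
>>> dp = IterableWrapper(list(range(23))).sharding_filter()
>>> rs = DistributedReadingService()
>>> dl = DataLoader2(dp, reading_service=rs)
>>> for d in dl:
...     model(d)  # fullsync is attached to the graph by the reading service
>>> dl.shutdown()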