Note
June 2024 Status Update: Removing DataPipes and DataLoader V2
We are refocusing the torchdata repo to be an iterative enhancement of torch.utils.data.DataLoader. We do not plan on continuing development or maintaining the [DataPipes] and [DataLoaderV2] solutions, and they will be removed from the torchdata repo. We'll also be revisiting the DataPipes references in pytorch/pytorch. In release torchdata==0.8.0 (July 2024) they will be marked as deprecated, and in 0.10.0 (late 2024) they will be deleted. Existing users are advised to pin to torchdata<=0.9.0 or an older version until they are able to migrate away. Subsequent releases will not include DataPipes or DataLoaderV2. Please reach out if you have suggestions or comments (please use this issue for feedback).
Migrating to torchdata.nodes from torch.utils.data
This guide is meant to help users familiar with torch.utils.data or StatefulDataLoader get started with torchdata.nodes, and to provide a starting point for defining your own dataloading pipelines.
We'll demonstrate how to achieve the most commonly used DataLoader features, re-use existing samplers and datasets, and load/save dataloader state. Performance is at least as good as DataLoader and StatefulDataLoader; see How does torchdata.nodes perform?.
Map-Style Datasets
Let's take a look at the DataLoader constructor arguments and go from there.
class DataLoader:
    def __init__(
        self,
        dataset: Dataset[_T_co],
        batch_size: Optional[int] = 1,
        shuffle: Optional[bool] = None,
        sampler: Union[Sampler, Iterable, None] = None,
        batch_sampler: Union[Sampler[List], Iterable[List], None] = None,
        num_workers: int = 0,
        collate_fn: Optional[_collate_fn_t] = None,
        pin_memory: bool = False,
        drop_last: bool = False,
        timeout: float = 0,
        worker_init_fn: Optional[_worker_init_fn_t] = None,
        multiprocessing_context=None,
        generator=None,
        *,
        prefetch_factor: Optional[int] = None,
        persistent_workers: bool = False,
        pin_memory_device: str = "",
        in_order: bool = True,
    ):
        ...
As a refresher, here is roughly how dataloading works in torch.utils.data.DataLoader: DataLoader begins by generating indices from a sampler and creating batches of batch_size indices. If no sampler is provided, a RandomSampler or SequentialSampler is created by default. The indices are passed to Dataset.__getitem__(), and then a collate_fn is applied to the batch of samples. If num_workers > 0, it uses multiprocessing to create subprocesses and passes the batches of indices to the worker processes, which then call Dataset.__getitem__() and apply collate_fn before returning the batches to the main process. At that point, pin_memory may be applied to the tensors in the batch.
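To make that flow concrete, here is a minimal single-process sketch of those steps (sampling, batching, __getitem__, collation), with multiprocessing and pin_memory omitted. naive_dataloader is an illustrative helper, not DataLoader's actual internals.
from torch.utils.data import SequentialSampler, default_collate

def naive_dataloader(dataset, batch_size=3, drop_last=False):
    sampler = SequentialSampler(dataset)           # 1. generate indices
    batch = []
    for idx in sampler:
        batch.append(idx)
        if len(batch) == batch_size:               # 2. group indices into batches of batch_size
            samples = [dataset[i] for i in batch]  # 3. Dataset.__getitem__()
            yield default_collate(samples)         # 4. collate_fn
            batch = []
    if batch and not drop_last:                    # emit the final partial batch
        yield default_collate([dataset[i] for i in batch])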
Now let's look at what an equivalent implementation built with torchdata.nodes might look like.
from typing import List, Callable
import torchdata.nodes as tn
from torch.utils.data import RandomSampler, SequentialSampler, default_collate, Dataset

class MapAndCollate:
    """A simple transform that takes a batch of indices, maps with dataset, and then applies
    collate.
    TODO: make this a standard utility in torchdata.nodes
    """
    def __init__(self, dataset, collate_fn):
        self.dataset = dataset
        self.collate_fn = collate_fn

    def __call__(self, batch_of_indices: List[int]):
        batch = [self.dataset[i] for i in batch_of_indices]
        return self.collate_fn(batch)

# To keep things simple, let's assume that the following args are provided by the caller
def NodesDataLoader(
    dataset: Dataset,
    batch_size: int,
    shuffle: bool,
    num_workers: int,
    collate_fn: Callable | None,
    pin_memory: bool,
    drop_last: bool,
):
    # Assume we're working with a map-style dataset
    assert hasattr(dataset, "__getitem__") and hasattr(dataset, "__len__")
    # Start with a sampler, since caller did not provide one
    sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
    # Sampler wrapper converts a Sampler to a BaseNode
    node = tn.SamplerWrapper(sampler)

    # Now let's batch sampler indices together
    node = tn.Batcher(node, batch_size=batch_size, drop_last=drop_last)

    # Create a Map Function that accepts a list of indices, applies getitem to it, and
    # then collates them
    map_and_collate = MapAndCollate(dataset, collate_fn or default_collate)

    # MapAndCollate is doing most of the heavy lifting, so let's parallelize it. We could
    # choose process or thread workers. Note that if you're not using Free-Threaded
    # Python (eg 3.13t) with -Xgil=0, then multi-threading might result in GIL contention,
    # and slow down training.
    node = tn.ParallelMapper(
        node,
        map_fn=map_and_collate,
        num_workers=num_workers,
        method="process",  # Set this to "thread" for multi-threading
        in_order=True,
    )

    # Optionally apply pin-memory, and we usually do some pre-fetching
    if pin_memory:
        node = tn.PinMemory(node)
    node = tn.Prefetcher(node, prefetch_factor=num_workers * 2)

    # Note that node is an iterator, and once it's exhausted, you'll need to call .reset()
    # on it to start a new Epoch.
    # Instead, we wrap the node in a Loader, which is an iterable and handles reset. It
    # also provides state_dict and load_state_dict methods.
    return tn.Loader(node)
Now let's test it out with a trivial dataset, and demonstrate how state management works.
class SquaredDataset(Dataset):
    def __init__(self, len: int):
        self.len = len
    def __len__(self):
        return self.len
    def __getitem__(self, i: int) -> int:
        return i**2

loader = NodesDataLoader(
    dataset=SquaredDataset(14),
    batch_size=3,
    shuffle=False,
    num_workers=2,
    collate_fn=None,
    pin_memory=False,
    drop_last=False,
)

batches = []
for idx, batch in enumerate(loader):
    if idx == 2:
        state_dict = loader.state_dict()
        # Saves the state_dict after batch 2 has been returned
    batches.append(batch)

loader.load_state_dict(state_dict)
batches_after_loading = list(loader)
print(batches[3:])
# [tensor([ 81, 100, 121]), tensor([144, 169])]
print(batches_after_loading)
# [tensor([ 81, 100, 121]), tensor([144, 169])]
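One more detail worth highlighting: because Loader handles reset internally, the same object can simply be iterated once per epoch. A small sketch, where fresh_loader is just a new instance of the NodesDataLoader defined above:
fresh_loader = NodesDataLoader(
    dataset=SquaredDataset(14),
    batch_size=3,
    shuffle=False,
    num_workers=2,
    collate_fn=None,
    pin_memory=False,
    drop_last=False,
)
for epoch in range(2):
    num_batches = sum(1 for _ in fresh_loader)  # no manual .reset() needed between epochs
    print(f"epoch {epoch}: {num_batches} batches")
# epoch 0: 5 batches
# epoch 1: 5 batches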
Let's also compare it against torch.utils.data.DataLoader as a sanity check.
import torch

loaderv1 = torch.utils.data.DataLoader(
    dataset=SquaredDataset(14),
    batch_size=3,
    shuffle=False,
    num_workers=2,
    collate_fn=None,
    pin_memory=False,
    drop_last=False,
    persistent_workers=False,  # Coming soon to torchdata.nodes!
)
print(list(loaderv1))
# [tensor([0, 1, 4]), tensor([ 9, 16, 25]), tensor([36, 49, 64]), tensor([ 81, 100, 121]), tensor([144, 169])]
print(batches)
# [tensor([0, 1, 4]), tensor([ 9, 16, 25]), tensor([36, 49, 64]), tensor([ 81, 100, 121]), tensor([144, 169])]
IterableDatasets
Coming soon! While you can already plug your IterableDataset into a tn.IterableWrapper, some functions (such as get_worker_info) are not yet supported. However, we believe that sharding work between multiprocess workers is often not actually necessary: you can keep some form of indexing in the main process while parallelizing only some of the heavier transforms, similar to how the Map-Style Datasets above work; see the sketch below.
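For illustration, here is a hedged sketch of that pattern under simple assumptions: RangeDataset and decode are hypothetical stand-ins for your IterableDataset and a heavier per-sample transform. Iteration stays in the main process via tn.IterableWrapper, and only the transform is parallelized.
import torchdata.nodes as tn
from torch.utils.data import IterableDataset

class RangeDataset(IterableDataset):
    def __init__(self, n: int):
        self.n = n
    def __iter__(self):
        return iter(range(self.n))

def decode(x: int) -> int:
    # stand-in for a heavier transform (decoding, augmentation, ...)
    return x * 10

node = tn.IterableWrapper(RangeDataset(8))  # iteration happens in the main process
node = tn.ParallelMapper(node, map_fn=decode, num_workers=4, method="thread", in_order=True)
node = tn.Batcher(node, batch_size=4, drop_last=False)
iterable_loader = tn.Loader(node)
print(list(iterable_loader))
# [[0, 10, 20, 30], [40, 50, 60, 70]]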