注意
2024 年 6 月状态更新:移除 DataPipes 和 DataLoader V2
我们正在重新聚焦 torchdata 仓库,使其成为 torch.utils.data.DataLoader 的迭代增强。我们不计划继续开发或维护 [DataPipes] 和 [DataLoaderV2] 解决方案,它们将从 torchdata 仓库中移除。我们还将重新审视 pytorch/pytorch 中的 DataPipes 引用。在 torchdata==0.8.0 版本(2024 年 7 月)中,它们将被标记为已弃用,并在 0.10.0 版本(2024 年末)中被删除。建议现有用户锁定到 torchdata<=0.9.0 或更旧版本,直到他们能够迁移为止。后续版本将不包含 DataPipes 或 DataLoaderV2。如果您有建议或意见,请联系我们(请使用 此 issue 进行反馈)
TorchData.nodes (beta) 入门¶
使用 pip 安装 torchdata。
pip install torchdata>=0.10.0
生成器示例¶
包装一个生成器(或任何可迭代对象)以将其转换为 BaseNode 并开始使用
from torchdata.nodes import IterableWrapper, ParallelMapper, Loader
node = IterableWrapper(range(10))
node = ParallelMapper(node, map_fn=lambda x: x**2, num_workers=3, method="thread")
loader = Loader(node)
result = list(loader)
print(result)
# [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
Sampler 示例¶
仍然支持 Sampler,您可以继续使用现有的 torch.utils.data.Dataset
。有关深入示例,请参阅 从 torch.utils.data 迁移到 torchdata.nodes。
import torch.utils.data
from torchdata.nodes import SamplerWrapper, ParallelMapper, Loader
class SquaredDataset(torch.utils.data.Dataset):
def __getitem__(self, i: int) -> int:
return i**2
def __len__(self):
return 10
dataset = SquaredDataset()
sampler = RandomSampler(dataset)
# For fine-grained control of iteration order, define your own sampler
node = SamplerWrapper(sampler)
# Simply apply dataset's __getitem__ as a map function to the indices generated from sampler
node = ParallelMapper(node, map_fn=dataset.__getitem__, num_workers=3, method="thread")
# Loader is used to convert a node (iterator) into an Iterable that may be reused for multi epochs
loader = Loader(node)
print(list(loader))
# [25, 36, 9, 49, 0, 81, 4, 16, 64, 1]
print(list(loader))
# [0, 4, 1, 64, 49, 25, 9, 16, 81, 36]