• 文档 >
  • torchdata.nodes 入门 (beta)
快捷方式

注意

2024 年 6 月状态更新:移除 DataPipes 和 DataLoader V2

我们正在重新调整 torchdata 仓库的重点,使其成为 torch.utils.data.DataLoader 的迭代增强。我们不计划继续开发或维护 [DataPipes] 和 [DataLoaderV2] 解决方案,它们将从 torchdata 仓库中移除。我们还将重新检查 pytorch/pytorch 中对 DataPipes 的引用。在 torchdata==0.8.0 版本(2024 年 7 月)中,它们将被标记为已弃用,并在 0.10.0 版本(2024 年末)中被删除。建议现有用户在能够迁移之前,将版本固定到 torchdata<=0.9.0 或更旧的版本。后续版本将不包括 DataPipes 或 DataLoaderV2。如果您有任何建议或意见,请联系我们(请使用 此议题 提供反馈)

torchdata.nodes 入门 (beta)

使用 pip 安装 torchdata。

pip install torchdata>=0.10.0

生成器示例

包装一个生成器(或任何可迭代对象)将其转换为 BaseNode 并开始使用

from torchdata.nodes import IterableWrapper, ParallelMapper, Loader

node = IterableWrapper(range(10))
node = ParallelMapper(node, map_fn=lambda x: x**2, num_workers=3, method="thread")
loader = Loader(node)
result = list(loader)
print(result)
# [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

采样器示例

采样器仍受支持,您可以使用现有的 torch.utils.data.Dataset。有关深入示例,请参见 从 torch.utils.data 迁移到 torchdata.nodes

import torch.utils.data
from torchdata.nodes import SamplerWrapper, ParallelMapper, Loader


class SquaredDataset(torch.utils.data.Dataset):
    def __getitem__(self, i: int) -> int:
        return i**2
    def __len__(self):
        return 10

dataset = SquaredDataset()
sampler = RandomSampler(dataset)

# For fine-grained control of iteration order, define your own sampler
node = SamplerWrapper(sampler)
# Simply apply dataset's __getitem__ as a map function to the indices generated from sampler
node = ParallelMapper(node, map_fn=dataset.__getitem__, num_workers=3, method="thread")
# Loader is used to convert a node (iterator) into an Iterable that may be reused for multi epochs
loader = Loader(node)
print(list(loader))
# [25, 36, 9, 49, 0, 81, 4, 16, 64, 1]
print(list(loader))
# [0, 4, 1, 64, 49, 25, 9, 16, 81, 36]

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源