快捷方式

洗牌器

class torchdata.datapipes.iter.Shuffler(datapipe: IterDataPipe[T_co], *, buffer_size: int = 10000, unbatch_level: int = 0)

使用缓冲区对输入数据管道进行洗牌(函数名:shuffle)。

首先,使用buffer_size填充来自数据管道的元素的缓冲区。然后,通过迭代器使用蓄水池抽样从缓冲区中生成每个项目。

buffer_size 必须大于 0。对于 buffer_size == 1,数据管道不会被洗牌。为了完全对数据管道中的所有元素进行洗牌,buffer_size 必须大于或等于数据管道的大小。

当与 torch.utils.data.DataLoader 一起使用时,用于设置随机种子的方法根据 num_workers 不同。

对于单进程模式(num_workers == 0),在主进程中的 DataLoader 之前设置随机种子。对于多进程模式(num_worker > 0),worker_init_fn 用于为每个工作进程设置随机种子。

参数:
  • datapipe – 正在洗牌的 IterDataPipe

  • buffer_size – 洗牌的缓冲区大小(默认为 10000

  • unbatch_level – 指定在应用洗牌之前是否需要取消源数据的批次

示例

>>> # xdoctest: +SKIP
>>> from torchdata.datapipes.iter import IterableWrapper
>>> dp = IterableWrapper(range(10))
>>> shuffle_dp = dp.shuffle()
>>> list(shuffle_dp)
[0, 4, 1, 6, 3, 2, 9, 5, 7, 8]

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源