洗牌器¶
- class torchdata.datapipes.iter.Shuffler(datapipe: IterDataPipe[T_co], *, buffer_size: int = 10000, unbatch_level: int = 0)¶
使用缓冲区对输入数据管道进行洗牌(函数名:
shuffle
)。首先,使用
buffer_size
填充来自数据管道的元素的缓冲区。然后,通过迭代器使用蓄水池抽样从缓冲区中生成每个项目。buffer_size
必须大于0
。对于buffer_size == 1
,数据管道不会被洗牌。为了完全对数据管道中的所有元素进行洗牌,buffer_size
必须大于或等于数据管道的大小。当与
torch.utils.data.DataLoader
一起使用时,用于设置随机种子的方法根据num_workers
不同。对于单进程模式(
num_workers == 0
),在主进程中的DataLoader
之前设置随机种子。对于多进程模式(num_worker > 0
),worker_init_fn 用于为每个工作进程设置随机种子。- 参数:
datapipe – 正在洗牌的 IterDataPipe
buffer_size – 洗牌的缓冲区大小(默认为
10000
)unbatch_level – 指定在应用洗牌之前是否需要取消源数据的批次
示例
>>> # xdoctest: +SKIP >>> from torchdata.datapipes.iter import IterableWrapper >>> dp = IterableWrapper(range(10)) >>> shuffle_dp = dp.shuffle() >>> list(shuffle_dp) [0, 4, 1, 6, 3, 2, 9, 5, 7, 8]