随机拆分器¶
- class torchdata.datapipes.iter.RandomSplitter(source_datapipe: IterDataPipe, weights: Dict[T, Union[int, float]], seed, total_length: Optional[int] = None, target: Optional[T] = None)¶
将来自源数据管道的样本随机拆分成组(函数名:
random_split
)。由于没有缓冲,因此一次只能迭代一组样本(即一个子数据管道)。尝试同时迭代多个数据管道将失败。请注意,默认情况下,此数据管道的多次迭代将产生相同的拆分,以确保跨 epoch 的一致性。您可以在输出上调用
override_seed
以根据需要更新种子(例如,每个 epoch 更新一次以获得每个 epoch 的不同拆分)。- 参数:
source_datapipe – 被拆分的可迭代数据管道
weights – 权重的字典;此列表的长度决定了将生成多少个输出数据管道。建议提供加起来为
total_length
的整数权重,这使得能够提前知道结果数据管道的长度值。seed – 用于确定拆分随机性的随机种子
total_length –
source_datapipe
的长度,可选,但强烈建议提供一个整数,因为并非所有IterDataPipe
都具有len
,尤其是那些可以轻松提前知道的。target – 可选键(必须存在于
weights
中),用于指示要返回的特定组。如果设置为默认值None
,则返回List[IterDataPipe]
。如果指定了目标,则返回IterDataPipe
。
示例
>>> from torchdata.datapipes.iter import IterableWrapper >>> dp = IterableWrapper(range(10)) >>> train, valid = dp.random_split(total_length=10, weights={"train": 0.5, "valid": 0.5}, seed=0) >>> list(train) [2, 3, 5, 7, 8] >>> list(valid) [0, 1, 4, 6, 9] >>> # You can also specify a target key if you only need a specific group of samples >>> train = dp.random_split(total_length=10, weights={"train": 0.5, "valid": 0.5}, seed=0, target='train') >>> list(train) [2, 3, 5, 7, 8] >>> # Be careful to use the same seed as before when specifying `target` to get the correct split. >>> valid = dp.random_split(total_length=10, weights={"train": 0.5, "valid": 0.5}, seed=0, target='valid') >>> list(valid) [0, 1, 4, 6, 9]