快捷方式

随机拆分器

class torchdata.datapipes.iter.RandomSplitter(source_datapipe: IterDataPipe, weights: Dict[T, Union[int, float]], seed, total_length: Optional[int] = None, target: Optional[T] = None)

将来自源数据管道的样本随机拆分成组(函数名:random_split)。由于没有缓冲,因此一次只能迭代一组样本(即一个子数据管道)。尝试同时迭代多个数据管道将失败。

请注意,默认情况下,此数据管道的多次迭代将产生相同的拆分,以确保跨 epoch 的一致性。您可以在输出上调用 override_seed 以根据需要更新种子(例如,每个 epoch 更新一次以获得每个 epoch 的不同拆分)。

参数:
  • source_datapipe – 被拆分的可迭代数据管道

  • weights – 权重的字典;此列表的长度决定了将生成多少个输出数据管道。建议提供加起来为 total_length 的整数权重,这使得能够提前知道结果数据管道的长度值。

  • seed – 用于确定拆分随机性的随机种子

  • total_lengthsource_datapipe 的长度,可选,但强烈建议提供一个整数,因为并非所有 IterDataPipe 都具有 len,尤其是那些可以轻松提前知道的。

  • target – 可选键(必须存在于 weights 中),用于指示要返回的特定组。如果设置为默认值 None,则返回 List[IterDataPipe]。如果指定了目标,则返回 IterDataPipe

示例

>>> from torchdata.datapipes.iter import IterableWrapper
>>> dp = IterableWrapper(range(10))
>>> train, valid = dp.random_split(total_length=10, weights={"train": 0.5, "valid": 0.5}, seed=0)
>>> list(train)
[2, 3, 5, 7, 8]
>>> list(valid)
[0, 1, 4, 6, 9]
>>> # You can also specify a target key if you only need a specific group of samples
>>> train = dp.random_split(total_length=10, weights={"train": 0.5, "valid": 0.5}, seed=0, target='train')
>>> list(train)
[2, 3, 5, 7, 8]
>>> # Be careful to use the same seed as before when specifying `target` to get the correct split.
>>> valid = dp.random_split(total_length=10, weights={"train": 0.5, "valid": 0.5}, seed=0, target='valid')
>>> list(valid)
[0, 1, 4, 6, 9]

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源