SampleMultiplexer¶
- class torchdata.datapipes.iter.SampleMultiplexer(pipes_to_weights_dict: Dict[IterDataPipe[T_co], float], seed: Optional[int] = None)¶
接收一个 (IterDataPipe, Weight) 的 Dict,并根据其权重从这些 DataPipes 中采样生成项目。当各个 DataPipes 用尽时,会继续根据剩余 DataPipes 的相对权重进行采样。如果您希望无限期地保持相同的权重比率,则需要确保输入永不耗尽,例如,通过对它们应用
cycle
。采样由提供的随机
seed
控制。如果您未提供它,则采样将不是确定性的。- 参数:
pipes_to_weights_dict – 一个包含 IterDataPipes 和权重的 Dict。出于采样目的,将对未用尽的 DataPipes 的总权重归一化为 1。
seed – 初始化随机数生成器的随机种子
示例
>>> from torchdata.datapipes.iter import IterableWrapper, SampleMultiplexer >>> source_dp1 = IterableWrapper([0] * 5) >>> source_dp2 = IterableWrapper([1] * 5) >>> d = {source_dp1: 99999999, source_dp2: 0.0000001} >>> sample_mul_dp = SampleMultiplexer(pipes_to_weights_dict=d, seed=0) >>> list(sample_mul_dp) [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]