Grouper¶
- class torchdata.datapipes.iter.Grouper(datapipe: IterDataPipe[T_co], group_key_fn: Callable[[T_co], Any], *, keep_key: bool = False, buffer_size: int = 10000, group_size: Optional[int] = None, guaranteed_group_size: Optional[int] = None, drop_remaining: bool = False)¶
Groups data from the input IterDataPipe by keys from group_key_fn, and yields a DataChunk with a batch size of up to group_size (functional name: groupby).

Samples are read sequentially from the source datapipe, and a batch of samples belonging to the same group is yielded as soon as its size reaches group_size. When the buffer is full, the DataPipe yields the largest batch with the same key, provided that its size is larger than guaranteed_group_size. If its size is smaller, the batch is dropped if drop_remaining=True. After the entirety of the source datapipe has been iterated, everything not dropped due to the buffer capacity is yielded from the buffer, even if the group sizes are smaller than guaranteed_group_size.

- Parameters:
datapipe – Iterable datapipe to be grouped
group_key_fn – Function used to generate the group key from the data in the source datapipe
keep_key – Option to yield the matching key along with the items in a tuple, resulting in (key, [items]); otherwise, only [items] is returned
buffer_size – The size of the buffer for ungrouped data
group_size – The max size of each group; a batch is yielded as soon as it reaches this size
guaranteed_group_size – The guaranteed minimum group size to be yielded in case the buffer is full
drop_remaining – Specifies whether groups smaller than guaranteed_group_size will be dropped from the buffer when the buffer is full
Example
>>> import os
>>> # xdoctest: +SKIP
>>> from torchdata.datapipes.iter import IterableWrapper
>>> def group_fn(file):
...     return os.path.basename(file).split(".")[0]
>>> source_dp = IterableWrapper(["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"])
>>> dp0 = source_dp.groupby(group_key_fn=group_fn)
>>> list(dp0)
[['a.png', 'a.json', 'a.jpg'], ['b.png', 'b.json'], ['c.json']]
>>> # A group is yielded as soon as its size equals `group_size`
>>> dp1 = source_dp.groupby(group_key_fn=group_fn, group_size=2)
>>> list(dp1)
[['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
>>> # Scenario where `buffer` is full, and group 'a' needs to be yielded since its size > `guaranteed_group_size`
>>> dp2 = source_dp.groupby(group_key_fn=group_fn, buffer_size=3, group_size=3, guaranteed_group_size=2)
>>> list(dp2)
[['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
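To make the grouping and `keep_key` semantics concrete without depending on torchdata, the core behavior can be sketched in plain Python. This is a simplified illustration, not the library's implementation: `simple_groupby` is a hypothetical helper, and it omits the `buffer_size` / `guaranteed_group_size` / `drop_remaining` handling, modeling only arrival-order buffering, early emission at `group_size`, and the final flush.

```python
import os
from collections import OrderedDict

def simple_groupby(iterable, group_key_fn, keep_key=False, group_size=None):
    """Sketch of Grouper's core semantics (buffer-limit handling omitted):
    collect items per key in arrival order, emit a group early once it
    reaches group_size, and flush all remaining groups at the end."""
    buffer = OrderedDict()  # key -> list of items, in first-seen order
    for item in iterable:
        key = group_key_fn(item)
        buffer.setdefault(key, []).append(item)
        if group_size is not None and len(buffer[key]) == group_size:
            group = buffer.pop(key)  # yield a full group immediately
            yield (key, group) if keep_key else group
    for key, group in buffer.items():  # flush partial groups at the end
        yield (key, group) if keep_key else group

files = ["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"]
key_fn = lambda f: os.path.basename(f).split(".")[0]

# Mirrors dp0 above, but with keep_key=True so each group carries its key:
print(list(simple_groupby(files, key_fn, keep_key=True)))
# [('a', ['a.png', 'a.json', 'a.jpg']), ('b', ['b.png', 'b.json']), ('c', ['c.json'])]

# Mirrors dp1: group 'a' is emitted early once it reaches group_size=2,
# so the later 'a.jpg' starts a new group that is flushed at the end:
print(list(simple_groupby(files, key_fn, group_size=2)))
# [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
```

Note that with `keep_key=True` the real Grouper likewise yields `(key, [items])` tuples, which is convenient when the group key (here, the file stem) is needed downstream.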