快捷方式

分组器

class torchdata.datapipes.iter.Grouper(datapipe: IterDataPipe[T_co], group_key_fn: Callable[[T_co], Any], *, keep_key: bool = False, buffer_size: int = 10000, group_size: Optional[int] = None, guaranteed_group_size: Optional[int] = None, drop_remaining: bool = False)

通过来自 group_key_fn 的键对来自 IterDataPipe 的数据进行分组,生成一个具有不超过 group_size 批量大小的 DataChunk

(函数名: groupby)。

样本从源 datapipe 中按顺序读取,当批次大小达到 group_size 时,将生成属于同一组的样本批次。当缓冲区已满时,数据管道将生成具有相同键的最大批次,前提是其大小大于 guaranteed_group_size。如果其大小较小,则如果 drop_remaining=True,它将被丢弃。

在遍历源 datapipe 的全部内容后,由于缓冲区容量而未被丢弃的所有内容都将从缓冲区生成,即使组大小小于 guaranteed_group_size

参数:
  • datapipe – 要分组的可迭代数据管道

  • group_key_fn – 用于从源数据管道的数据生成组键的函数

  • keep_key – 选项是在元组中将匹配的键与项目一起生成,结果为 (key, [items]),否则返回 [items]

  • buffer_size – 未分组数据的缓冲区大小

  • group_size – 每组的最大大小,一旦达到此大小,就会生成一个批次

  • guaranteed_group_size – 缓冲区已满时,保证的最小组大小将被生成

  • drop_remaining – 指定如果缓冲区已满时,小于 guaranteed_group_size 的组是否将从缓冲区中丢弃

示例

>>> import os
>>> # xdoctest: +SKIP
>>> from torchdata.datapipes.iter import IterableWrapper
>>> def group_fn(file):
...     return os.path.basename(file).split(".")[0]
>>> source_dp = IterableWrapper(["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"])
>>> dp0 = source_dp.groupby(group_key_fn=group_fn)
>>> list(dp0)
[['a.png', 'a.json', 'a.jpg'], ['b.png', 'b.json'], ['c.json']]
>>> # A group is yielded as soon as its size equals to `group_size`
>>> dp1 = source_dp.groupby(group_key_fn=group_fn, group_size=2)
>>> list(dp1)
[['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
>>> # Scenario where `buffer` is full, and group 'a' needs to be yielded since its size > `guaranteed_group_size`
>>> dp2 = source_dp.groupby(group_key_fn=group_fn, buffer_size=3, group_size=3, guaranteed_group_size=2)
>>> list(dp2)
[['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得答案

查看资源