快捷方式

MaxTokenBucketizer

class torchdata.datapipes.iter.MaxTokenBucketizer(datapipe: ~IterDataPipe[~T_co], max_token_count: int, len_fn: ~Callable = <function _default_len_fn>, min_len: int = 0, max_len: ~Optional[int] = None, buffer_size: int = 1000, include_padding: bool = False)

从大小有限的最小堆中创建数据的小批量,并且每个批次中由 len_fn 返回的样本的总长度将受 max_token_count 限制(功能名称: max_token_bucketize)。如果设置了 min_lenmax_len,则长度超出 [min_len, max_len] 的样本将被过滤掉。

此 DataPipe 的目的是根据 len_fn 对长度相似的样本进行批处理。此处使用最小堆以确保样本根据长度递增排序。并且,保证每个批次中样本的总长度小于 max_token_count。例如,在音频领域,它可能是对长度相似的样本进行批处理。然后,给定 max_token_count,每个批次可以连接到大小相同且填充最少的张量。

如果将 include_padding 设置为 True,则每个批次的令牌计数包括后续 DataPipe 可能添加的填充。这保证了即使在批次填充后,也不会超过 max_token_count。这可以防止数据长度差异较大时出现内存不足问题。

请注意,批次是从缓冲区中最小的尺寸开始进行分桶的。如果 buffer_size 很大,这可能会限制批次的变异性。要增加变异性,请在此 DataPipe 之前和之后应用 torchdata.datapipes.iter.Shuffler,并将 buffer_size 保持较小。

参数:
  • datapipe – 正在进行批处理的可迭代 DataPipe

  • max_token_count – 每个批次中数据总长度的最大长度

  • len_fn – 应用于每个元素以获取长度的函数。默认情况下使用 len(data)

  • min_len – 要包含在每个批次中的可选最小长度

  • max_len – 要包含在每个批次中的可选最大长度。

  • buffer_size – 这限制了从先前 DataPipe 中获取多少样本进行分桶

  • include_padding – 如果为 True,则每个批次的大小包括批次中最大长度的额外填充。

示例

>>> from torchdata.datapipes.iter import IterableWrapper
>>> source_dp = IterableWrapper(['1', '11', '1', '1111', '111', '1', '11', '11', '111'])
>>> # Using default len_fn to sort samples based on length (string length in this case)
>>> batch_dp = source_dp.max_token_bucketize(max_token_count=5)
>>> list(batch_dp)
[['1', '1', '1', '11'], ['11', '11'], ['111'], ['111'], ['1111']]
>>> batch_dp = source_dp.max_token_bucketize(max_token_count=4, buffer_size=4)
>>> list(batch_dp)
[['1', '1', '1'], ['11', '11'], ['11'], ['111'], ['111'], ['1111']]

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取适合初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源