MaxTokenBucketizer¶
- class torchdata.datapipes.iter.MaxTokenBucketizer(datapipe: ~IterDataPipe[~T_co], max_token_count: int, len_fn: ~Callable = <function _default_len_fn>, min_len: int = 0, max_len: ~Optional[int] = None, buffer_size: int = 1000, include_padding: bool = False)¶
从大小有限的最小堆中创建数据的小批量,并且每个批次中由
len_fn
返回的样本的总长度将受max_token_count
限制(功能名称:max_token_bucketize
)。如果设置了min_len
或max_len
,则长度超出[min_len, max_len]
的样本将被过滤掉。此 DataPipe 的目的是根据
len_fn
对长度相似的样本进行批处理。此处使用最小堆以确保样本根据长度递增排序。并且,保证每个批次中样本的总长度小于max_token_count
。例如,在音频领域,它可能是对长度相似的样本进行批处理。然后,给定max_token_count
,每个批次可以连接到大小相同且填充最少的张量。如果将
include_padding
设置为True
,则每个批次的令牌计数包括后续 DataPipe 可能添加的填充。这保证了即使在批次填充后,也不会超过max_token_count
。这可以防止数据长度差异较大时出现内存不足问题。请注意,批次是从缓冲区中最小的尺寸开始进行分桶的。如果
buffer_size
很大,这可能会限制批次的变异性。要增加变异性,请在此 DataPipe 之前和之后应用torchdata.datapipes.iter.Shuffler
,并将buffer_size
保持较小。- 参数:
datapipe – 正在进行批处理的可迭代 DataPipe
max_token_count – 每个批次中数据总长度的最大长度
len_fn – 应用于每个元素以获取长度的函数。默认情况下使用
len(data)
。min_len – 要包含在每个批次中的可选最小长度
max_len – 要包含在每个批次中的可选最大长度。
buffer_size – 这限制了从先前 DataPipe 中获取多少样本进行分桶
include_padding – 如果为 True,则每个批次的大小包括批次中最大长度的额外填充。
示例
>>> from torchdata.datapipes.iter import IterableWrapper >>> source_dp = IterableWrapper(['1', '11', '1', '1111', '111', '1', '11', '11', '111']) >>> # Using default len_fn to sort samples based on length (string length in this case) >>> batch_dp = source_dp.max_token_bucketize(max_token_count=5) >>> list(batch_dp) [['1', '1', '1', '11'], ['11', '11'], ['111'], ['111'], ['1111']] >>> batch_dp = source_dp.max_token_bucketize(max_token_count=4, buffer_size=4) >>> list(batch_dp) [['1', '1', '1'], ['11', '11'], ['11'], ['111'], ['111'], ['1111']]