BucketBatcher¶
- class torchdata.datapipes.iter.BucketBatcher(datapipe: IterDataPipe[T_co], batch_size: int, drop_last: bool = False, batch_num: int = 100, bucket_num: int = 1, sort_key: Optional[Callable] = None, use_in_batch_shuffle: bool = True)¶
从排序的桶中创建数据的小批量(函数名称:
bucketbatch
)。如果drop_last
设置为True
,则会添加一个外部维度作为batch_size
,或者如果drop_last
设置为False
,则为最后一批添加length % batch_size
。此 DataPipe 的目的是根据传递的排序函数对具有相似性的样本进行批处理。例如,在文本领域中,它可能将具有类似令牌数的示例进行批处理,以最大程度地减少填充并提高吞吐量。
- 参数::
datapipe – 要进行批处理的可迭代 DataPipe
batch_size – 每个批次的大小
drop_last – 如果最后一批不完整,则选择丢弃最后一批
batch_num – 桶内的批次数量(即 bucket_size = batch_size * batch_num)
bucket_num – 用于洗牌的池中包含的桶的数量(即 pool_size = bucket_size * bucket_num)
sort_key – 用于对桶(列表)进行排序的可调用对象
use_in_batch_shuffle – 如果为 True,则进行批内洗牌;如果为 False,则进行缓冲洗牌
示例
>>> from torchdata.datapipes.iter import IterableWrapper >>> source_dp = IterableWrapper(range(10)) >>> batch_dp = source_dp.bucketbatch(batch_size=3, drop_last=True) >>> list(batch_dp) [[5, 6, 7], [9, 0, 1], [4, 3, 2]] >>> def sort_bucket(bucket): >>> return sorted(bucket) >>> batch_dp = source_dp.bucketbatch( >>> batch_size=3, drop_last=True, batch_num=100, >>> bucket_num=1, use_in_batch_shuffle=False, sort_key=sort_bucket >>> ) >>> list(batch_dp) [[3, 4, 5], [6, 7, 8], [0, 1, 2]]