快捷方式

BucketBatcher

class torchdata.datapipes.iter.BucketBatcher(datapipe: IterDataPipe[T_co], batch_size: int, drop_last: bool = False, batch_num: int = 100, bucket_num: int = 1, sort_key: Optional[Callable] = None, use_in_batch_shuffle: bool = True)

从排序的桶中创建数据的小批量(函数名称:bucketbatch)。如果 drop_last 设置为 True,则会添加一个外部维度作为 batch_size,或者如果 drop_last 设置为 False,则为最后一批添加 length % batch_size

此 DataPipe 的目的是根据传递的排序函数对具有相似性的样本进行批处理。例如,在文本领域中,它可能将具有类似令牌数的示例进行批处理,以最大程度地减少填充并提高吞吐量。

参数::
  • datapipe – 要进行批处理的可迭代 DataPipe

  • batch_size – 每个批次的大小

  • drop_last – 如果最后一批不完整,则选择丢弃最后一批

  • batch_num – 桶内的批次数量(即 bucket_size = batch_size * batch_num

  • bucket_num – 用于洗牌的池中包含的桶的数量(即 pool_size = bucket_size * bucket_num

  • sort_key – 用于对桶(列表)进行排序的可调用对象

  • use_in_batch_shuffle – 如果为 True,则进行批内洗牌;如果为 False,则进行缓冲洗牌

示例

>>> from torchdata.datapipes.iter import IterableWrapper
>>> source_dp = IterableWrapper(range(10))
>>> batch_dp = source_dp.bucketbatch(batch_size=3, drop_last=True)
>>> list(batch_dp)
[[5, 6, 7], [9, 0, 1], [4, 3, 2]]
>>> def sort_bucket(bucket):
>>>     return sorted(bucket)
>>> batch_dp = source_dp.bucketbatch(
>>>     batch_size=3, drop_last=True, batch_num=100,
>>>     bucket_num=1, use_in_batch_shuffle=False, sort_key=sort_bucket
>>> )
>>> list(batch_dp)
[[3, 4, 5], [6, 7, 8], [0, 1, 2]]

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源