Collator¶
- class torchdata.datapipes.iter.Collator(datapipe: ~IterDataPipe, conversion: ~Optional[~Union[~Callable[[...], ~Any], ~Dict[~Union[str, ~Any], ~Union[~Callable, ~Any]]]] = <function default_collate>, collate_fn: ~Optional[~Callable] = None)¶
通过自定义合并函数(函数名称:
collate
)将 DataPipe 中的样本合并为 Tensor(s)。默认情况下,它使用
torch.utils.data.default_collate()
。注意
在编写自定义合并函数时,您可以导入
torch.utils.data.default_collate()
以获取默认行为,并使用 functools.partial 指定任何其他参数。- 参数:
datapipe – 正在合并的 Iterable DataPipe
collate_fn – 用于收集和组合数据或一批数据的自定义合并函数。默认函数根据数据类型合并为 Tensor(s)。
示例
>>> # xdoctest: +SKIP >>> # Convert integer data to float Tensor >>> class MyIterDataPipe(torch.utils.data.IterDataPipe): ... def __init__(self, start, end): ... super(MyIterDataPipe).__init__() ... assert end > start, "this example code only works with end >= start" ... self.start = start ... self.end = end ... ... def __iter__(self): ... return iter(range(self.start, self.end)) ... ... def __len__(self): ... return self.end - self.start ... >>> ds = MyIterDataPipe(start=3, end=7) >>> print(list(ds)) [3, 4, 5, 6] >>> def collate_fn(batch): ... return torch.tensor(batch, dtype=torch.float) ... >>> collated_ds = CollateIterDataPipe(ds, collate_fn=collate_fn) >>> print(list(collated_ds)) [tensor(3.), tensor(4.), tensor(5.), tensor(6.)]