快捷方式

Collator

class torchdata.datapipes.iter.Collator(datapipe: ~IterDataPipe, conversion: ~Optional[~Union[~Callable[[...], ~Any], ~Dict[~Union[str, ~Any], ~Union[~Callable, ~Any]]]] = <function default_collate>, collate_fn: ~Optional[~Callable] = None)

通过自定义合并函数(函数名称:collate)将 DataPipe 中的样本合并为 Tensor(s)。

默认情况下,它使用 torch.utils.data.default_collate()

注意

在编写自定义合并函数时,您可以导入 torch.utils.data.default_collate() 以获取默认行为,并使用 functools.partial 指定任何其他参数。

参数:
  • datapipe – 正在合并的 Iterable DataPipe

  • collate_fn – 用于收集和组合数据或一批数据的自定义合并函数。默认函数根据数据类型合并为 Tensor(s)。

示例

>>> # xdoctest: +SKIP
>>> # Convert integer data to float Tensor
>>> class MyIterDataPipe(torch.utils.data.IterDataPipe):
...     def __init__(self, start, end):
...         super(MyIterDataPipe).__init__()
...         assert end > start, "this example code only works with end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         return iter(range(self.start, self.end))
...
...     def __len__(self):
...         return self.end - self.start
...
>>> ds = MyIterDataPipe(start=3, end=7)
>>> print(list(ds))
[3, 4, 5, 6]
>>> def collate_fn(batch):
...     return torch.tensor(batch, dtype=torch.float)
...
>>> collated_ds = CollateIterDataPipe(ds, collate_fn=collate_fn)
>>> print(list(collated_ds))
[tensor(3.), tensor(4.), tensor(5.), tensor(6.)]

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源