快捷方式

BatchAsyncMapper

class torchdata.datapipes.iter.BatchAsyncMapper(source_datapipe, async_fn: Callable, batch_size: int, input_col=None, output_col=None, max_concurrency: int = 32, flatten: bool = True)

将来自源 DataPipe 的元素组合成批次,并对每个批次中的每个元素并发应用协程函数,然后将输出扁平化为一个单一的、未嵌套的 IterDataPipe(函数名:async_map_batches)。

参数:
  • source_datapipe – 源 IterDataPipe

  • async_fn – 要应用于每个数据批次的协程函数

  • batch_size – 从 source_datapipe 聚合的批次大小

  • input_col

    应用 fn 的数据索引或索引,例如

    • None 作为默认值,直接将 fn 应用于数据。

    • 整数用于列表/元组。

    • 键用于字典。

  • output_col

    放置 fn 结果的数据索引。只有当 input_col 不为 None 时才能指定 output_col

    • None 作为默认值,以替换 input_col 指定的索引;对于具有多个索引的 input_col,使用最左边的索引,其他索引将被删除。

    • 整数用于列表/元组。-1 表示将结果附加到末尾。

    • 键用于字典。新键是可以接受的。

  • max_concurrency – 调用异步函数的最大并发度。(默认:32

  • flatten – 确定批次是否在最后被扁平化(默认:True)如果为 False,则输出将以 batch_size 大小的批次形式存在

示例

>>> from torchdata.datapipes.iter import IterableWrapper
>>> async def mul_ten(x):
...     await asyncio.sleep(1)
...     return x * 10
>>> dp = IterableWrapper(range(50))
>>> dp = dp.async_map_batches(mul_ten, 16)
>>> list(dp)
[0, 10, 20, 30, ...]
>>> dp = IterableWrapper([(i, i) for i in range(50)])
>>> dp = dp.async_map_batches(mul_ten, 16, input_col=1)
>>> list(dp)
[(0, 0), (1, 10), (2, 20), (3, 30), ...]
>>> dp = IterableWrapper([(i, i) for i in range(50)])
>>> dp = dp.async_map_batches(mul_ten, 16, input_col=1, output_col=-1)
>>> list(dp)
[(0, 0, 0), (1, 1, 10), (2, 2, 20), (3, 3, 30), ...]
# Async fetching html from remote
>>> from aiohttp import ClientSession
>>> async def fetch_html(url: str, **kwargs):
...     async with ClientSession() as session:
...         resp = await session.request(method="GET", url=url, **kwargs)
...         resp.raise_for_status()
...         html = await resp.text()
...     return html
>>> dp = IterableWrapper(urls)
>>> dp = dp.async_map_batches(fetch_html, 16)

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源