BatchAsyncMapper¶
- class torchdata.datapipes.iter.BatchAsyncMapper(source_datapipe, async_fn: Callable, batch_size: int, input_col=None, output_col=None, max_concurrency: int = 32, flatten: bool = True)¶
将来自源 DataPipe 的元素组合成批次,并对每个批次中的每个元素并发应用协程函数,然后将输出扁平化为一个单一的、未嵌套的 IterDataPipe(函数名:
async_map_batches
)。- 参数:
source_datapipe – 源 IterDataPipe
async_fn – 要应用于每个数据批次的协程函数
batch_size – 从
source_datapipe
聚合的批次大小input_col –
应用
fn
的数据索引或索引,例如None
作为默认值,直接将fn
应用于数据。整数用于列表/元组。
键用于字典。
output_col –
放置
fn
结果的数据索引。只有当input_col
不为None
时才能指定output_col
None
作为默认值,以替换input_col
指定的索引;对于具有多个索引的input_col
,使用最左边的索引,其他索引将被删除。整数用于列表/元组。
-1
表示将结果附加到末尾。键用于字典。新键是可以接受的。
max_concurrency – 调用异步函数的最大并发度。(默认:
32
)flatten – 确定批次是否在最后被扁平化(默认:
True
)如果为False
,则输出将以batch_size
大小的批次形式存在
示例
>>> from torchdata.datapipes.iter import IterableWrapper >>> async def mul_ten(x): ... await asyncio.sleep(1) ... return x * 10 >>> dp = IterableWrapper(range(50)) >>> dp = dp.async_map_batches(mul_ten, 16) >>> list(dp) [0, 10, 20, 30, ...] >>> dp = IterableWrapper([(i, i) for i in range(50)]) >>> dp = dp.async_map_batches(mul_ten, 16, input_col=1) >>> list(dp) [(0, 0), (1, 10), (2, 20), (3, 30), ...] >>> dp = IterableWrapper([(i, i) for i in range(50)]) >>> dp = dp.async_map_batches(mul_ten, 16, input_col=1, output_col=-1) >>> list(dp) [(0, 0, 0), (1, 1, 10), (2, 2, 20), (3, 3, 30), ...] # Async fetching html from remote >>> from aiohttp import ClientSession >>> async def fetch_html(url: str, **kwargs): ... async with ClientSession() as session: ... resp = await session.request(method="GET", url=url, **kwargs) ... resp.raise_for_status() ... html = await resp.text() ... return html >>> dp = IterableWrapper(urls) >>> dp = dp.async_map_batches(fetch_html, 16)