padded_collate¶
- torchtune.data.padded_collate(batch: List[Dict[str, List[int]]], *, pad_direction: str, keys_to_pad: List[str], padding_idx: Union[int, Dict[str, int]])[来源]¶
一个通用的填充整理函数,它根据给定的
pad_direction
对批量序列中的keys_to_pad
条目进行填充,直到达到批量中每个条目的最大序列长度。注意
此函数假定所有不在
keys_to_pad
中的批量元素不需要任何整理(参见下面的示例)。- 参数:
pad_direction (str) – 是从左侧还是右侧填充条目。如果
pad_direction="right"
,我们使用torch.nn.utils.rnn.pad_sequence()
;否则,如果pad_direction="left"
,我们使用torchtune.data.left_pad_sequence()
。keys_to_pad (List[str]) – 应用填充的批量元素键。应该是批量中键的子集。
padding_idx (Union[int, Dict[str, int]]) – 用于应用于所有
keys_to_pad
元素的单个整数填充值,或键与keys_to_pad
相同、包含按键填充值的映射。
- 返回值:
填充后的输入 ID 张量,形状为
[batch_size, max_seq_len]
。- 返回类型:
- 抛出异常:
ValueError – 如果
pad_direction
不是“left”或“right”之一,**或者**如果keys_to_pad
为空或不是列表,**或者**如果keys_to_pad
不是批量中键的子集,**或者**如果padding_idx
以字典形式提供,但键与keys_to_pad
不完全相同
示例
>>> a = [1, 2, 3] >>> b = [4, 5, 6, 7] >>> c = [8, 9, 10, 11, 12] >>> batch = [ >>> {"tokens": a, "labels": 1}, >>> {"tokens": b, "labels": 3}, >>> {"tokens": c, "labels": 0}, >>> ] >>> padded_collate( >>> batch, >>> pad_direction="left", >>> keys_to_pad=["tokens"], >>> padding_idx=-10 >>> ) { 'labels': tensor([1, 3, 0]), 'tokens': tensor([[-10, -10, 1, 2, 3], [-10, 4, 5, 6, 7], [ 8, 9, 10, 11, 12]]) }