快捷方式

padded_collate

torchtune.data.padded_collate(batch: List[Dict[str, List[int]]], *, pad_direction: str, keys_to_pad: List[str], padding_idx: Union[int, Dict[str, int]])[来源]

一个通用的填充整理函数,它根据给定的 pad_direction 对批量序列中的 keys_to_pad 条目进行填充,直到达到批量中每个条目的最大序列长度。

注意

此函数假定所有不在 keys_to_pad 中的批量元素不需要任何整理(参见下面的示例)。

参数:
  • batch (List[Dict[str, List[int]]]) – 包含输入的字典列表。

  • pad_direction (str) – 是从左侧还是右侧填充条目。如果 pad_direction="right",我们使用 torch.nn.utils.rnn.pad_sequence();否则,如果 pad_direction="left",我们使用 torchtune.data.left_pad_sequence()

  • keys_to_pad (List[str]) – 应用填充的批量元素键。应该是批量中键的子集。

  • padding_idx (Union[int, Dict[str, int]]) – 用于应用于所有 keys_to_pad 元素的单个整数填充值,或键与 keys_to_pad 相同、包含按键填充值的映射。

返回值:

填充后的输入 ID 张量,形状为 [batch_size, max_seq_len]

返回类型:

torch.Tensor

抛出异常:

ValueError – 如果 pad_direction 不是“left”或“right”之一,**或者**如果 keys_to_pad 为空或不是列表,**或者**如果 keys_to_pad 不是批量中键的子集,**或者**如果 padding_idx 以字典形式提供,但键与 keys_to_pad 不完全相同

示例

>>> a = [1, 2, 3]
>>> b = [4, 5, 6, 7]
>>> c = [8, 9, 10, 11, 12]
>>> batch = [
>>>     {"tokens": a, "labels": 1},
>>>     {"tokens": b, "labels": 3},
>>>     {"tokens": c, "labels": 0},
>>> ]
>>> padded_collate(
>>>     batch,
>>>     pad_direction="left",
>>>     keys_to_pad=["tokens"],
>>>     padding_idx=-10
>>> )
{
    'labels': tensor([1, 3, 0]),
    'tokens': tensor([[-10, -10,   1,   2,   3],
                      [-10,   4,   5,   6,   7],
                      [  8,   9,  10,  11,  12]])
}

文档

查阅 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并解答你的疑问

查看资源