快捷方式

padded_collate

torchtune.data.padded_collate(batch: List[Dict[str, List[int]]], *, pad_direction: str, keys_to_pad: List[str], padding_idx: Union[int, Dict[str, int]])[源代码]

一个通用的填充整理函数,它从给定的 pad_direction 对批次中序列的 keys_to_pad 条目进行填充,以填充批次中每个条目的最大序列长度。

注意

此函数假设所有不在 keys_to_pad 中的批次元素不需要任何整理(请参见下面的示例)。

参数::
  • batch (List[Dict[str, List[int]]]) – 包含输入的字典列表。

  • pad_direction (str) – 是否从左侧或右侧填充条目。如果 pad_direction="right",我们使用 torch.nn.utils.rnn.pad_sequence(),否则如果 pad_direction="left",我们使用 torchtune.data.left_pad_sequence().

  • keys_to_pad (List[str]) – 要应用填充的批次元素键。应为批次中键的子集。

  • padding_idx (Union[int, Dict[str, int]]) – 要应用于所有 keys_to_pad 元素的单个整数填充值,或一个键与 keys_to_pad 相同的映射,包含每个键的填充值。

返回值::

形状为 [batch_size, max_seq_len] 的填充输入 ID 张量。

返回类型::

torch.Tensor

引发::
  • ValueError – 如果 pad_direction 不是 “left” 或 “right” 之一。

  • ValueError – 如果 keys_to_pad 为空,或不是列表,或不是批次中键的子集。

  • ValueError – 如果 padding_idx 以字典形式提供,但键与 keys_to_pad 不相同。

示例

>>> a = [1, 2, 3]
>>> b = [4, 5, 6, 7]
>>> c = [8, 9, 10, 11, 12]
>>> batch = [
>>>     {"tokens": a, "labels": 1},
>>>     {"tokens": b, "labels": 3},
>>>     {"tokens": c, "labels": 0},
>>> ]
>>> padded_collate(
>>>     batch,
>>>     pad_direction="left",
>>>     keys_to_pad=["tokens"],
>>>     padding_idx=-10
>>> )
{
    'labels': tensor([1, 3, 0]),
    'tokens': tensor([[-10, -10,   1,   2,   3],
                      [-10,   4,   5,   6,   7],
                      [  8,   9,  10,  11,  12]])
}

文档

访问 PyTorch 的全面的开发者文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源