快捷方式

padded_collate

torchtune.data.padded_collate(batch: List[Dict[str, List[int]]], *, pad_direction: str, keys_to_pad: List[str], padding_idx: Union[int, Dict[str, int]])[source]

一个通用的填充整理函数,用于从给定的 pad_direction 方向填充批次中 keys_to_pad 条目的序列,使其达到批次中每个条目的最大序列长度。

注意

此函数假定所有不在 keys_to_pad 中的批次元素都不需要任何整理(参见下面的示例)。

参数:
  • batch (List[Dict[str, List[int]]]) – 包含输入的字典列表。

  • pad_direction (str) – 从左侧还是右侧填充条目。如果 pad_direction="right",我们使用 torch.nn.utils.rnn.pad_sequence(),否则如果 pad_direction="left",我们使用 torchtune.data.left_pad_sequence()

  • keys_to_pad (List[str]) – 要应用填充的批次元素键。应为批次中键的子集。

  • padding_idx (Union[int, Dict[str, int]]) – 要应用于所有 keys_to_pad 元素的单个整数填充值,或与 keys_to_pad 相同的键的映射,带有每个键的填充值。

返回:

填充后的输入 id 张量,形状为 [batch_size, max_seq_len]。

返回类型:

torch.Tensor

引发:
  • ValueError – 如果 pad_direction 不是 “left” 或 “right” 之一。

  • ValueError – 如果 keys_to_pad 为空,或不是列表,或不是批次中键的子集。

  • ValueError – 如果 padding_idx 以字典形式提供,但键与 keys_to_pad 不相同。

示例

>>> a = [1, 2, 3]
>>> b = [4, 5, 6, 7]
>>> c = [8, 9, 10, 11, 12]
>>> batch = [
>>>     {"tokens": a, "labels": 1},
>>>     {"tokens": b, "labels": 3},
>>>     {"tokens": c, "labels": 0},
>>> ]
>>> padded_collate(
>>>     batch,
>>>     pad_direction="left",
>>>     keys_to_pad=["tokens"],
>>>     padding_idx=-10
>>> )
{
    'labels': tensor([1, 3, 0]),
    'tokens': tensor([[-10, -10,   1,   2,   3],
                      [-10,   4,   5,   6,   7],
                      [  8,   9,  10,  11,  12]])
}

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源