快捷方式

padded_collate_sft

torchtune.data.padded_collate_sft(batch: List[Dict[str, List[int]]], padding_idx: int = 0, ignore_idx: int = - 100) Dict[str, Tensor][source]

将一批序列填充到批次中最长的序列长度,并将整数列表转换为张量。

参数::
  • batch (List[Dict[str, List[int]]]) – 包含输入、标签对的字典列表。

  • padding_idx (int) – 输入 ID 的填充索引。默认为 0。

  • ignore_idx (int) – 标签的填充索引。默认为 -100。

返回值::

合并的输入和标签张量。

返回类型::

Dict[str, torch.Tensor]

示例

>>> token_pairs = [
>>>    {"tokens": [1, 2, 3], "labels": [4, 5, 6]},
>>>    {"tokens": [7,], "labels": [10,]},
>>> ]
>>> collated = padded_collate(
>>>    batch=token_pairs,
>>>    padding_idx=padding_idx,
>>>    ignore_idx=ignore_idx,
>>> )
>>> collated["tokens"]
>>> tensor([[1, 2, 3], [7, 0, 0]])
>>> collated["labels"]
>>> tensor([[4, 5, 6], [10, -100, -100]])

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获得针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源