padded_collate_sft¶
- torchtune.data.padded_collate_sft(batch: List[Dict[str, List[int]]], padding_idx: int = 0, ignore_idx: int = - 100) Dict[str, Tensor] [source]¶
将一批序列填充到批次中最长的序列长度,并将整数列表转换为张量。
- 参数::
- 返回值::
合并的输入和标签张量。
- 返回类型::
Dict[str, torch.Tensor]
示例
>>> token_pairs = [ >>> {"tokens": [1, 2, 3], "labels": [4, 5, 6]}, >>> {"tokens": [7,], "labels": [10,]}, >>> ] >>> collated = padded_collate( >>> batch=token_pairs, >>> padding_idx=padding_idx, >>> ignore_idx=ignore_idx, >>> ) >>> collated["tokens"] >>> tensor([[1, 2, 3], [7, 0, 0]]) >>> collated["labels"] >>> tensor([[4, 5, 6], [10, -100, -100]])