快捷方式

torch.nn.utils.rnn.pad_packed_sequence

torch.nn.utils.rnn.pad_packed_sequence(sequence, batch_first=False, padding_value=0.0, total_length=None)[source]

填充一批可变长度序列的打包数据。

它是 pack_padded_sequence() 的逆操作。

返回的 Tensor 的数据大小为 T x B x *(如果 batch_firstFalse)或 B x T x *(如果 batch_firstTrue),其中 T 是最长序列的长度,B 是批次大小。

示例

>>> from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
>>> seq = torch.tensor([[1, 2, 0], [3, 0, 0], [4, 5, 6]])
>>> lens = [2, 1, 3]
>>> packed = pack_padded_sequence(seq, lens, batch_first=True, enforce_sorted=False)
>>> packed
PackedSequence(data=tensor([4, 1, 3, 5, 2, 6]), batch_sizes=tensor([3, 2, 1]),
               sorted_indices=tensor([2, 0, 1]), unsorted_indices=tensor([1, 2, 0]))
>>> seq_unpacked, lens_unpacked = pad_packed_sequence(packed, batch_first=True)
>>> seq_unpacked
tensor([[1, 2, 0],
        [3, 0, 0],
        [4, 5, 6]])
>>> lens_unpacked
tensor([2, 1, 3])

注意

total_length 用于在包装在 Module 中的 DataParallel 中实现 pack sequence -> recurrent network -> unpack sequence 模式。有关详细信息,请参阅 此常见问题解答部分

参数
  • sequence (PackedSequence) – 要填充的批次

  • batch_first (bool, 可选) – 如果 True,则输出将采用 B x T x * 格式,否则为 T x B x *

  • padding_value (float, 可选) – 填充元素的值。

  • total_length (int, 可选) – 如果不为 None,则输出将被填充为长度 total_length。如果 total_length 小于 sequence 中的最大序列长度,则此方法将引发 ValueError

返回

包含填充序列的 Tensor 和包含批次中每个序列长度列表的 Tensor 的元组。批次元素将重新排序,就像最初将批次传递给 pack_padded_sequencepack_sequence 时一样。

返回类型

Tuple[Tensor, Tensor]

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源