torch.nn.utils.rnn.pad_packed_sequence¶
- torch.nn.utils.rnn.pad_packed_sequence(sequence, batch_first=False, padding_value=0.0, total_length=None)[源代码][源代码]¶
填充打包的可变长度序列批次。
它是
pack_padded_sequence()
的逆运算。返回的 Tensor 的数据大小将为
T x B x *
(如果batch_first
为False
)或B x T x *
(如果batch_first
为True
),其中T
是最长序列的长度,B
是批次大小。示例
>>> from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence >>> seq = torch.tensor([[1, 2, 0], [3, 0, 0], [4, 5, 6]]) >>> lens = [2, 1, 3] >>> packed = pack_padded_sequence(seq, lens, batch_first=True, enforce_sorted=False) >>> packed PackedSequence(data=tensor([4, 1, 3, 5, 2, 6]), batch_sizes=tensor([3, 2, 1]), sorted_indices=tensor([2, 0, 1]), unsorted_indices=tensor([1, 2, 0])) >>> seq_unpacked, lens_unpacked = pad_packed_sequence(packed, batch_first=True) >>> seq_unpacked tensor([[1, 2, 0], [3, 0, 0], [4, 5, 6]]) >>> lens_unpacked tensor([2, 1, 3])
注意
total_length
可用于在Module
中实现pack sequence -> recurrent network -> unpack sequence
模式,该模块包装在DataParallel
中。有关详细信息,请参阅 此 FAQ 部分。- 参数
sequence (PackedSequence) – 要填充的批次
batch_first (bool, 可选) – 如果为
True
,则输出将为B x T x *
格式,否则为T x B x *
。padding_value (float, 可选) – 填充元素的值。
total_length (int, 可选) – 如果不为
None
,则输出将被填充为长度为total_length
。如果total_length
小于sequence
中的最大序列长度,则此方法将抛出ValueError
。
- 返回值
包含填充序列的 Tensor 元组,以及包含批次中每个序列长度列表的 Tensor。批次元素将按照批次传递给
pack_padded_sequence
或pack_sequence
时的原始顺序重新排序。- 返回类型