torch.nn.utils.rnn.pad_sequence¶
- torch.nn.utils.rnn.pad_sequence(sequences, batch_first=False, padding_value=0.0, padding_side='right')[源][源]¶
使用
padding_value
填充变长 Tensor 列表。pad_sequence
沿新维度堆叠变长 Tensor 列表,并将它们填充到等长。sequences
可以是大小为L x *
的序列列表,其中 L 是序列长度,*
是任意数量的维度(包括0
)。如果batch_first
为False
,输出大小为T x B x *
;否则为B x T x *
,其中B
是批次大小(sequences
中的元素数量),T
是最长序列的长度。示例
>>> from torch.nn.utils.rnn import pad_sequence >>> a = torch.ones(25, 300) >>> b = torch.ones(22, 300) >>> c = torch.ones(15, 300) >>> pad_sequence([a, b, c]).size() torch.Size([25, 3, 300])
注意
此函数返回一个大小为
T x B x *
或B x T x *
的 Tensor,其中 T 是最长序列的长度。此函数假定 sequences 中所有 Tensor 的尾随维度和类型相同。- 参数
- 返回
如果
batch_first
为False
,则返回大小为T x B x *
的 Tensor。否则返回大小为B x T x *
的 Tensor- 返回类型