快捷方式

torch.nn.utils.rnn.pack_padded_sequence

torch.nn.utils.rnn.pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True)[源代码][源代码]

打包包含变长填充序列的张量。

input 的大小可以是 T x B x * (如果 batch_firstFalse) 或 B x T x * (如果 batch_firstTrue),其中 T 是最长序列的长度,B 是批大小,而 * 是任意数量的维度 (包括 0)。

对于未排序的序列,请使用 enforce_sorted = False。如果 enforce_sortedTrue,则序列应按长度降序排序,即 input[:,0] 应为最长序列,而 input[:,B-1] 应为最短序列。enforce_sorted = True 仅在 ONNX 导出时是必要的。

它是 pad_packed_sequence() 的逆运算,因此 pad_packed_sequence() 可用于恢复在 PackedSequence 中打包的基础张量。

注意

此函数接受至少具有两个维度的任何输入。您可以应用它来打包标签,并将 RNN 的输出与它们一起使用以直接计算损失。可以通过访问 PackedSequence 对象的 .data 属性从该对象中检索张量。

参数
  • input (Tensor) – 变长序列的填充批次。

  • lengths (Tensorlist(int)) – 每个批次元素的序列长度列表(如果作为张量提供,则必须在 CPU 上)。

  • batch_first (bool, 可选) – 如果 True,则输入应为 B x T x * 格式,否则为 T x B x *

  • enforce_sorted (bool, 可选) – 如果 True,则输入应包含按长度降序排序的序列。如果 False,则输入将被无条件排序。默认值:True

返回

一个 PackedSequence 对象

返回类型

PackedSequence

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源