快捷方式

torch.nn.utils.rnn.pack_padded_sequence

torch.nn.utils.rnn.pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True)[source][source]

打包包含可变长度填充序列的张量。

input 可以是 T x B x * 大小(如果 batch_firstFalse)或 B x T x * 大小(如果 batch_firstTrue),其中 T 是最长序列的长度,B 是批量大小,* 是任意维度数(包括 0)。

对于未排序的序列,请使用 enforce_sorted = False。如果 enforce_sortedTrue,序列应按长度降序排列,即 input[:,0] 应为最长序列,而 input[:,B-1] 应为最短序列。enforce_sorted = True 仅在 ONNX 导出时是必要的。

它是 pad_packed_sequence() 的逆操作,因此可以使用 pad_packed_sequence() 来恢复打包在 PackedSequence 中的底层张量。

注意

此函数接受至少包含两个维度的任何输入。您可以将其用于打包标签,并使用 RNN 的输出与标签一起直接计算损失。可以通过访问 PackedSequence 对象的 .data 属性来从中检索张量。

参数
  • input (Tensor) – 可变长度序列的填充批次。

  • lengths (Tensorlist(int)) – 每个批次元素的序列长度列表(如果以张量形式提供,则必须位于 CPU 上)。

  • batch_first (bool, 可选) – 如果为 True,则输入应为 B x T x * 格式,否则为 T x B x *。默认值:False

  • enforce_sorted (bool, 可选) – 如果为 True,则输入应包含按长度降序排列的序列。如果为 False,输入将无条件地进行排序。默认值:True

返回类型

PackedSequence

警告

如果 input 张量的长度大于 length 中的对应值,则其维度将被截断。

返回

一个 PackedSequence 对象

返回类型

PackedSequence


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

查阅 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源