快捷方式

torch.unflatten

torch.unflatten(input, dim, sizes) Tensor

在多个维度上扩展输入张量的维度。

另请参见

torch.flatten() 此函数的逆函数。它将多个维度合并为一个维度。

参数
  • input (Tensor) – 输入张量。

  • dim (int) – 要展平的维度,指定为 input.shape 中的索引。

  • sizes (Tuple[int]) – 展平维度的新的形状。它的一个元素可以是 -1,在这种情况下,相应的输出维度将被推断。否则,sizes 的乘积 *必须* 等于 input.shape[dim]

返回值

输入的视图,其中指定维度已展平。

示例:
>>> torch.unflatten(torch.randn(3, 4, 1), 1, (2, 2)).shape
torch.Size([3, 2, 2, 1])
>>> torch.unflatten(torch.randn(3, 4, 1), 1, (-1, 2)).shape
torch.Size([3, 2, 2, 1])
>>> torch.unflatten(torch.randn(5, 12, 3), -2, (2, 2, 3, 1, 1)).shape
torch.Size([5, 2, 2, 3, 1, 1, 3])

文档

访问 PyTorch 的综合开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源