torch.unflatten¶
- torch.unflatten(input, dim, sizes) Tensor ¶
将输入张量的某个维度展开为多个维度。
另请参阅
torch.flatten()
是此函数的逆操作。它将多个维度合并为一个。- 参数
- 返回值
输入张量的一个视图,其中指定的维度已被展开。
- 示例:
>>> torch.unflatten(torch.randn(3, 4, 1), 1, (2, 2)).shape torch.Size([3, 2, 2, 1]) >>> torch.unflatten(torch.randn(3, 4, 1), 1, (-1, 2)).shape torch.Size([3, 2, 2, 1]) >>> torch.unflatten(torch.randn(5, 12, 3), -2, (2, 2, 3, 1, 1)).shape torch.Size([5, 2, 2, 3, 1, 1, 3])