快捷方式

Flatten

class torch.nn.Flatten(start_dim=1, end_dim=-1)[源代码]

将连续的维度范围展平为一个张量。

有关在 Sequential 中的使用,请参阅 torch.flatten() 以获取详细信息。

形状
  • 输入:(,Sstart,...,Si,...,Send,)(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *), 其中 SiS_{i} 是维度 ii 上的大小,而 * 表示任意数量的维度,包括无维度。

  • 输出:(,i=startendSi,)(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *).

参数
  • start_dim (int) – 要展平的第一个维度(默认值为 1)。

  • end_dim (int) – 要展平的最后一个维度(默认值为 -1)。

示例:
>>> input = torch.randn(32, 1, 5, 5)
>>> # With default parameters
>>> m = nn.Flatten()
>>> output = m(input)
>>> output.size()
torch.Size([32, 25])
>>> # With non-default parameters
>>> m = nn.Flatten(0, 2)
>>> output = m(input)
>>> output.size()
torch.Size([160, 5])

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取适合初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获取问题的解答

查看资源