Unflatten¶
- 类 torch.nn.Unflatten(dim, unflattened_size)[源][源]¶
展开 tensor 的一个维度,将其扩展到所需的形状。用于
Sequential
。dim
指定要展开的输入 tensor 的维度,使用 Tensor 或 NamedTensor 时,它可以分别是 int 或 str。unflattened_size
是展开后的 tensor 维度的新形状,对于 Tensor 输入,它可以是 tuple、list 或 torch.Size 的整数;对于 NamedTensor 输入,它可以是 NamedShape(即 (名称, 大小) 元组的元组)。
- 形状
输入: , 其中 是维度
dim
的大小, 表示任意数量的维度(包括零个)。输出: , 其中 =
unflattened_size
且 。
- 参数
unflattened_size (Union[torch.Size, Tuple, List, NamedShape]) – 展开后的维度的新形状
示例
>>> input = torch.randn(2, 50) >>> # With tuple of ints >>> m = nn.Sequential( >>> nn.Linear(50, 50), >>> nn.Unflatten(1, (2, 5, 5)) >>> ) >>> output = m(input) >>> output.size() torch.Size([2, 2, 5, 5]) >>> # With torch.Size >>> m = nn.Sequential( >>> nn.Linear(50, 50), >>> nn.Unflatten(1, torch.Size([2, 5, 5])) >>> ) >>> output = m(input) >>> output.size() torch.Size([2, 2, 5, 5]) >>> # With namedshape (tuple of tuples) >>> input = torch.randn(2, 50, names=('N', 'features')) >>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5))) >>> output = unflatten(input) >>> output.size() torch.Size([2, 2, 5, 5])
- NamedShape¶
tuple[tuple[str, int]] 的别名