RMSNorm¶
- class torch.nn.modules.normalization.RMSNorm(normalized_shape, eps=None, elementwise_affine=True, device=None, dtype=None)[源代码]¶
对输入的小批量应用根均方层归一化。
此层实现的操作如论文 根均方层归一化 中所述。
根均方范数是在最后的
D
个维度上计算的,其中D
是normalized_shape
的维度。例如,如果normalized_shape
为(3, 5)
(一个二维形状),则根均方范数是在输入的最后两个维度上计算的。- 参数
- 形状
输入:
输出:(与输入形状相同)
示例
>>> rms_norm = nn.RMSNorm([2, 3]) >>> input = torch.randn(2, 2, 3) >>> rms_norm(input)