RMSNorm¶
- class torch.nn.modules.normalization.RMSNorm(normalized_shape, eps=None, elementwise_affine=True, device=None, dtype=None)[source][source]¶
对输入 mini-batch 应用均方根层归一化。
此层实现了论文 Root Mean Square Layer Normalization 中描述的操作
均方根 (RMS) 计算是基于最后
D
个维度进行的,其中D
是normalized_shape
的维度。例如,如果normalized_shape
为(3, 5)
(一个二维形状),则均方根计算将应用于输入的最后 2 个维度。- 参数
- 形状
输入:
输出: (与输入形状相同)
示例
>>> rms_norm = nn.RMSNorm([2, 3]) >>> input = torch.randn(2, 2, 3) >>> rms_norm(input)