RMSNorm¶
- class torch.nn.modules.normalization.RMSNorm(normalized_shape, eps=None, elementwise_affine=True, device=None, dtype=None)[source][source]¶
对小批量输入应用均方根层归一化。
此层实现的操作如论文 Root Mean Square Layer Normalization 中所述
RMS 是在最后
D
维度上计算的,其中D
是normalized_shape
的维度。例如,如果normalized_shape
是(3, 5)
(2 维形状),则 RMS 是在输入的最后 2 个维度上计算的。- 参数
- 形状
输入:
输出: (与输入形状相同)
示例
>>> rms_norm = nn.RMSNorm([2, 3]) >>> input = torch.randn(2, 2, 3) >>> rms_norm(input)