RMSNorm¶
- 类 torch.nn.RMSNorm(normalized_shape, eps=None, elementwise_affine=True, device=None, dtype=None)[源文件][源文件]¶
在输入的小批量数据上应用均方根层归一化 (Root Mean Square Layer Normalization)。
此层实现了论文 Root Mean Square Layer Normalization 中描述的操作。
均方根(RMS)是在最后
D
个维度上计算的,其中D
是normalized_shape
的维度。例如,如果normalized_shape
是(3, 5)
(一个二维形状),则均方根是在输入的最后 2 个维度上计算的。- 参数
- 形状
输入:
输出: (与输入形状相同)
示例
>>> rms_norm = nn.RMSNorm([2, 3]) >>> input = torch.randn(2, 2, 3) >>> rms_norm(input)