RMSNorm¶
- class torch.nn.RMSNorm(normalized_shape, eps=None, elementwise_affine=True, device=None, dtype=None)[源代码][源代码]¶
对小批量输入应用均方根层归一化。
此层实现的操作如论文 均方根层归一化 中所述
RMS 是在最后的
D
维度上计算的,其中D
是normalized_shape
的维度。例如,如果normalized_shape
是(3, 5)
(2 维形状),则 RMS 在输入的最后 2 个维度上计算。- 参数
normalized_shape (int 或 list 或 torch.Size) –
来自预期大小的输入的输入形状
如果使用单个整数,则将其视为单例列表,并且此模块将在最后一个维度上进行归一化,该维度预计为该特定大小。
eps (Optional[float]) – 为数值稳定性添加到分母的值。默认值:
torch.finfo(x.dtype).eps()
elementwise_affine (bool) – 一个布尔值,当设置为
True
时,此模块具有可学习的逐元素仿射参数,初始化为 1(对于权重)。默认值:True
。
- 形状
输入:
输出: (与输入形状相同)
示例
>>> rms_norm = nn.RMSNorm([2, 3]) >>> input = torch.randn(2, 2, 3) >>> rms_norm(input)