RMSNorm¶
- class torchtune.modules.RMSNorm(dim: int, eps: float = 1e-06)[source]¶
fp32 中的均方根归一化 (Root Mean Square Normalization)。
参见: https://pytorch.ac.cn/docs/stable/generated/torch.nn.RMSNorm.html
- forward(x: Tensor) Tensor [source]¶
- 参数:
x (torch.Tensor) – 待归一化的输入张量
- 返回:
归一化并缩放后的张量,形状与
x
相同。- 返回类型: