快捷方式

RMSNorm

class torch.nn.RMSNorm(normalized_shape, eps=None, elementwise_affine=True, device=None, dtype=None)[源代码]

对输入的小批量应用根均方层归一化。

此层实现的操作如论文 根均方层归一化 中所述。

y=xRMS[x]+ϵγy = \frac{x}{\sqrt{\mathrm{RMS}[x] + \epsilon}} * \gamma

根均方范数是在最后 D 个维度上计算的,其中 Dnormalized_shape 的维度。例如,如果 normalized_shape(3, 5)(一个二维形状),则根均方范数是在输入的最后 2 个维度上计算的。

参数
  • normalized_shape (intlisttorch.Size) –

    来自大小为的预期输入的输入形状

    [×normalized_shape[0]×normalized_shape[1]××normalized_shape[1]][* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1] \times \ldots \times \text{normalized\_shape}[-1]]

    如果使用单个整数,则将其视为单元素列表,并且此模块将在最后一个维度上进行归一化,该维度预计具有该特定大小。

  • eps (可选[浮点数]) – 添加到分母中的值,用于数值稳定性。默认值:torch.finfo(x.dtype).eps()

  • elementwise_affine (布尔值) – 布尔值,设置为True时,此模块具有可学习的每个元素仿射参数,初始化为 1(权重)和 0(偏差)。默认值:True

形状
  • 输入:(N,)(N, *)

  • 输出:(N,)(N, *)(与输入相同的形状)

示例

>>> rms_norm = nn.RMSNorm([2, 3])
>>> input = torch.randn(2, 2, 3)
>>> rms_norm(input)
extra_repr()[source]

有关模块的额外信息。

返回类型

字符串

forward(x)[source]

运行前向传递。

返回类型

张量

reset_parameters()[source]

根据其在 __init__ 中使用的初始化重置参数。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源