快捷方式

旋转位置嵌入

class torchtune.modules.RotaryPositionalEmbeddings(dim: int, max_seq_len: int = 4096, base: int =10000)[源代码]

此类实现了 https://arxiv.org/abs/2104.09864 中提出的旋转位置嵌入 (RoPE)。

参考实现(用于正确性验证)可以在这里找到:https://github.com/meta-llama/llama/blob/main/llama/model.py#L80

在此实现中,我们在初始化期间计算并缓存了直至 max_seq_len 的每个位置的嵌入。

参数:
  • dim (int) – 嵌入维度。这通常设置为注意力模块中每个头的维度,计算方式为 embed_dim // num_heads

  • max_seq_len (int) – 模型的最大期望序列长度,如果超出此长度,缓存的频率将被重新计算

  • base (int) – 用于计算旋转角度的几何级数的底数

forward(x: Tensor, *, input_pos: Optional[Tensor] = None) Tensor[源代码]
参数:
  • x (torch.Tensor) – 输入张量,形状为 [b, s, n_h, h_d]

  • input_pos (可选[torch.Tensor]) – 包含每个 token 位置 ID 的可选张量。在训练期间,打包时此参数用于指示每个 token 相对于其样本的位置,形状为 [b, s]。在推理期间,此参数指示当前 token 的位置。如果为 None,则假定 token 的索引即为其位置 ID。默认为 None。

返回:

输出张量,形状为 [b, s, n_h, h_d]

返回类型:

torch.Tensor

张量形状的符号表示
  • b: 批大小

  • s: 序列长度

  • n_h: 头数量

  • h_d: 头维度

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源