快捷方式

旋转位置嵌入

class torchtune.modules.RotaryPositionalEmbeddings(dim: int, max_seq_len: int = 4096, base: int = 10000)[source]

此类实现了旋转位置嵌入 (RoPE),该嵌入在 https://arxiv.org/abs/2104.09864 中提出。

参考实现(用于正确性验证)可以在这里找到:https://github.com/meta-llama/llama/blob/main/llama/model.py#L80

在此实现中,我们通过在初始化期间计算,缓存了每个位置的嵌入,直到 max_seq_len

参数::
  • dim (int) – 嵌入维度。这通常设置为注意力模块中每个头的维度,计算为 embed_dim // num_heads

  • max_seq_len (int) – 模型的预期最大序列长度,如果超过,则重新计算缓存的频率

  • base (int) – 用于计算旋转角度的几何级数的底数

forward(x: Tensor, *, input_pos: Optional[Tensor] = None) Tensor[source]
参数::
  • x (torch.Tensor) – 形状为 [b, s, n_h, h_d] 的输入张量

  • input_pos (Optional[torch.Tensor]) – 包含每个标记位置 ID 的可选张量。在训练期间,这用于指示每个标记在其样本打包时的相对位置,形状 [b, s]。在推理期间,这指示当前标记的位置。如果没有,则假设标记的索引是其位置 ID。默认值为 None。

返回::

形状为 [b, s, n_h, h_d] 的输出张量

返回类型::

torch.Tensor

用于张量形状的符号
  • b: 批量大小

  • s: 序列长度

  • n_h: 头数

  • h_d: 头维度

文档

Access comprehensive developer documentation for PyTorch

View Docs

教程

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources