TransformerCrossAttentionLayer¶
- class torchtune.modules.TransformerCrossAttentionLayer(attn: MultiHeadAttention, mlp: Module, *, ca_norm: Optional[Module] = None, mlp_norm: Optional[Module] = None, ca_scale: Optional[Module] = None, mlp_scale: Optional[Module] = None)[source]¶
遵循与 TransformerSelfAttentionLayer 相同约定的交叉注意力 Transformer 层。在注意力 **和** FF 层之前应用归一化。
- 参数:
attn (MultiHeadAttention) – 注意力模块。
mlp (nn.Module) – 前馈模块。
ca_norm (Optional[nn.Module]) – 在交叉注意力之前应用的归一化。
mlp_norm (Optional[nn.Module]) – 在前馈层之前应用的归一化。
ca_scale (Optional[nn.Module]) – 用于缩放交叉注意力输出的模块。
mlp_scale (Optional[nn.Module]) – 用于缩放前馈输出的模块。
- 引发:
AssertionError – 如果 attn.pos_embeddings 已设置。
- forward(x: Tensor, *, encoder_input: Optional[Tensor] = None, encoder_mask: Optional[Tensor] = None, **kwargs: Dict) Tensor [source]¶
- 参数:
x (torch.Tensor) – 形状为 [batch_size x seq_length x embed_dim] 的输入张量
encoder_input (Optional[torch.Tensor]) – 来自编码器的可选输入嵌入。形状 [batch_size x token_sequence x embed_dim]
encoder_mask (Optional[torch.Tensor]) – 定义令牌和编码器嵌入之间关系矩阵的布尔张量。在位置 i,j 处的 True 值表示令牌 i 可以关注解码器中的嵌入 j。掩码的形状为 [batch_size x token_sequence x embed_sequence]。默认值为 None。
**kwargs (Dict) – 与自我注意力无关的 transformer 层输入。
- 返回:
- 输出张量与输入张量形状相同。
[batch_size x seq_length x embed_dim]
- 返回类型: