FeedForward¶
- 类 torchtune.modules.FeedForward(*, gate_proj: Module, down_proj: Module, up_proj: Optional[Module] = None, activation: Module = SiLU())[source]¶
这个类实现了源自 Llama2 的前馈网络。
- 参数:
gate_proj (nn.Module) – 从输入维度投影到隐藏维度,通过激活函数处理后与 up_proj 相乘。
down_proj (nn.Module) – 最终投影到输出维度。
up_proj (Optional[nn.Module]) – 从输入维度投影到隐藏维度,与 activation(gate_proj) 相乘。
activation (nn.Module) – 使用的激活函数。默认是 nn.SiLU()。
- 前向传播(x: Tensor) Tensor [source]¶
- 参数:
x (torch.Tensor) – 形状为
(..., in_dim)
的输入张量,其中in_dim
是gate_proj
和up_proj
的输入维度。- 返回值:
形状为
(..., out_dim)
的输出张量,其中out_dim
是down_proj
的输出维度。- 返回值类型: