概述
最新的大语言模型(LLM)服务框架和模型越来越多地采用各种注意力变体,例如分组查询注意力(GQA)、多查询注意力(MQA)、PagedAttention 以及滑动窗口,以平衡精度与性能。传统上,每种变体都需要手动重写 FlashAttention 内核,才能在特定场景中获得合理的性能。
PyTorch 的 torch.nn.attention.flex_attention 提供了一种兼顾灵活性与效率的通用设计。它接受用户定义的 score_mod 和 mask_mod 来描述注意力变体及其组合,然后使用 torch.compile 将这些函数降级,从而自动生成高效的 FlashAttention 内核。常见的注意力计算如下:

许多注意力变体(例如 Alibi Bias、相对位置编码或 Tanh Soft-Capping)都可以通过用户定义的函数 score_mod 来表示。
除了与注意力得分相关的变体外,解码器层中的因果掩码(Casual Mask)和锯齿张量(Jagged Tensors)等其他变体是由计算稀疏性引起的,可以通过 mask_mod 来描述。更多关于 score_mod 和 mask_mod 的示例,请参阅 FlexAttention 第 1 部分 和 第 2 部分 博客。
FlexAttention 的优势使其在主流 LLM 生态系统项目中得到广泛采用,包括 HuggingFace、 vLLM 和 SGLang。这种普及显著降低了快速适配最新 LLM 模型所需的工作量。
在 Intel® GPU 上原生支持 FlexAttention
FlexAttention 是 PyTorch 中功能强大且灵活的注意力内核,其行为类似于 FlashAttention,但提供了更大的自由度来修改注意力得分和掩码逻辑。该内核使用 Triton 实现,这是一种允许程序员编写协作线程阵列(CTA 级)GPU 内核(也称为块级内核)的语言。
目前,FlexAttention 内核模板由两个内核组成:flex_attention 和 flex_decoding。flex_attention 内核专为推理和训练的预填充(prefill)阶段而设计,而 flex_decoding 内核则针对具有短查询和长 KV 缓存场景的推理解码阶段。在 PyTorch 2.9 中,所有 FlexAttention 的前向和反向场景都在 Intel GPU 上得到原生支持,并与 PyTorch 的标准 GPU 行为保持一致。这为用户提供了跨不同 GPU 的一致且可移植的性能,使开发人员能够编写一次代码,无需任何修改即可实现显著的注意力机制效率提升。
下图显示了 Intel® Arc™ B580 显卡上不同配置的内核性能。在 BS=1 和 head_dim=128 的条件下,对于 Hq=16, Hkv=16 的多头注意力(MHA)配置,FlexAttention 内核的性能与使用 oneDNN 后端的 PyTorch scaled_dot_product_attention 相当,并显著优于 Math 后端。对于 Hq=16, Hkv=2 的 GQA 配置,FlexAttention 的表现优于上述两种后端。

公开的 Triton 编译器不支持为 Intel GPU 编译 Triton 内核。Triton XPU 是 Intel 维护的 Triton 编译器扩展,除了支持公共 Triton 版本支持的 NVIDIA 和 AMD GPU 外,还增加了对 Intel GPU 的支持。PyTorch 2.9 集成了 Triton XPU,使 Triton 内核能够在 Intel GPU 上运行。在接下来的章节中,我们将介绍在 Intel GPU 上使用 Triton XPU 应用于 FlexAttention 内核的优化。
Intel® GPU 上的 FlexAttention 注意力内核优化
要理解 FlexAttention 中的 CTA 级平铺(tiling)模式,我们可以从不使用 block_mask 的密集注意力情况开始。(稀疏注意力遵循类似的模式,只是只调度有效的块。)FlexAttention API 旨在计算 Q、K 和 V 上的密集注意力,如下图所示。

内核沿 Q、Kᵀ 和 V 矩阵的 H_q、CTX_q 和 CTX_kv 维度对问题进行平铺。输入被划分为更小的块:
- q: BLOCK_M × D_HEAD_qk
- kᵀ: D_HEAD_qk × BLOCK_N
- v: BLOCK_N × D_HEAD_v
每个 CTA 以 BLOCK_N 为块大小,串行遍历 CTX_kv 维度,执行 (CTX_kv / BLOCK_N) 个步骤以覆盖完整的注意力上下文。

当 score_mod 或 block_mask 在内核内引入更复杂的计算或条件逻辑时,FlexAttention 可能会受到计算限制,因为 GPU 每次内存访问执行的算术运算要多得多。然而,在更简单的配置中,性能通常受内存限制,数据移动占用了执行时间。尽管如此,键/值(K/V)块的移动仍然是各种配置下的主要性能瓶颈:K/V 块必须由不同的线程块反复从全局内存中获取,即使在计算密集型场景下也限制了可实现的吞吐量。
为了最大化 FlexAttention 的效率,目标是尽可能重叠矩阵乘法、算术计算和内存访问。
Intel GPU AI 专用功能
Intel 最新的 GPU 架构(包含在 Intel® Arc™ B 系列显卡 中)提供了两个专门用于加速 AI 工作负载的硬件组件:
- Intel XMX (Xe Matrix eXtensions)
XMX 引擎是 Intel 专用的矩阵乘法单元,概念上类似于 NVIDIA 的 Tensor Cores 或 AMD 的 Matrix Cores。它们提供高吞吐量的混合精度计算,这对于深度学习操作(如注意力机制)至关重要。
- 块 I/O (Block I/O)
块 I/O 显著提高了将数据从内存移动到寄存器的效率。当与 XMX 引擎配合使用时,它为大型 AI 内核提供了精简的高带宽数据访问。关键的块 I/O 硬件功能包括:
直接 2D 矩阵加载:从全局内存到寄存器,周期更短。
自动边界保护和零填充:减少边界检查和填充的 ALU 操作。
内置支持 2D 转置和向量神经网络指令 (VNNI) 格式转换:在数据传输过程中进行。
异步预取:从全局内存到缓存,实现更好的流水线重叠。
支持关键 AI 数据类型:如 INT8、BF16 和 FP16 — 对现代推理和混合精度训练至关重要。
Triton XPU 集成了 XMX 和块 I/O 硬件特性,以加速 Intel GPU 上的 Triton 内核,利用这些专用组件减少数据依赖停顿,隐藏内存延迟,并提高整体内核性能。
Warp 级 FlexAttention 平铺与块 I/O 协同
如 FlexAttention 博客 中所述,PyTorch 的 FlexAttention 为注意力算法引入了高级的、与设备后端无关的优化。这些优化在受支持的硬件目标上统一应用。随后,Triton XPU 编译器后端进一步优化了以 Triton 语言生成的底层内核,从而在 Intel GPU 上实现高效执行和高性能。
在 Triton 中,Warp 级平铺在很大程度上由编译器自动处理。在此基础上,Triton XPU 后端进一步将每个 CTA 级切片细分为高效的水平 Warp 切片,以更好地匹配底层硬件。例如,当每个 CTA 使用两个 Warp 时,产生的平铺模式如下所示:

图示说明了一个 Triton 程序实例(线程块)如何计算注意力操作的平铺部分。每个实例处理一个大小为 BLOCK_M × D_HEAD_qk 的查询块,该块保持固定并被重用。键(Key)和值(Value)分别以大小为 D_HEAD_qk × BLOCK_N 和 BLOCK_N × D_HEAD_v 的切片流式传输。对于每个 K 切片,计算一个 BLOCK_M × BLOCK_N 的得分切片 (Q × Kᵀ),进行水平方向的 Softmax,然后与对应的 V 切片相乘,以累加输出一个 BLOCK_M × D_HEAD_v 的切片。通过将线程块平均分为两个子组(Warp)来实现并行,其中每个子组沿 N 维度水平计算切片的一部分。这种平铺最大限度地利用了寄存器,并且可以在偏好寄存器存储而非共享本地内存 (SLM) 的 Intel GPU 上高效运行。
基于寄存器执行的优势
查询 (Query)、得分 (Score) 和 输出 (Output) 矩阵直接存储在物理线程的通用寄存器文件 (GRF) 中,避免了对 SLM 的依赖。这提供了几个好处:
- XMX 引擎仅对寄存器操作数进行操作,因此避免使用 SLM 消除了额外的数据移动瓶颈。
- 水平归约 (Horizontal reductions)(例如 Softmax)可以完全在 Warp 内完成,消除了昂贵的同步。
- 第一次矩阵乘法的得分可以直接用作下一次矩阵乘法的输入 A,避免了通过 SLM 进行额外的加载或存储。
因此,FlexAttention 内核几乎完全从寄存器中执行,减少了同步,消除了 SLM 流量,并为 Intel GPU 上的 Transformer 工作负载提供了高性能。
基于寄存器执行的缺点
主要限制在于通用寄存器 (GRF) 对每个物理线程是私有的,导致 K 和 V 矩阵被多个线程冗余加载到寄存器中,使内核变为内存受限。这种限制通过 Triton XPU 基于块 I/O 的软件循环流水线优化得到了缓解。在软件流水线中,硬件特性被用于重叠内存到缓存的传输、缓存到寄存器的加载以及 MMA 执行。
- 块 I/O 预取:支持从 HBM 到缓存的异步数据传输。
- 块 I/O 加载操作:使用硬件计分板 (scoreboard) 异步将数据从缓存传输到寄存器。

如图所示,每次迭代将 XMX (MMA) 计算与 HBM 到 L1 缓存的预取流水线化,有效地隐藏了大部分内存延迟。流水线中 XMX (MMA) 计算与 HBM 到 L1 缓存预取的距离由软件循环流水线的 num_stages 控制。
这种重叠的有效性取决于所涉及的硬件资源及其相对延迟。对于具有足够大迭代次数的内核,性能遵循屋顶线模型 (roofline model),在内存延迟占主导时成为内存受限,而在 MMA 延迟占主导时成为计算受限。
在 Triton XPU 中,软件流水线通过使用 块 I/O 预取 明确地将内存访问与 MMA 计算重叠。增加 num_stages 会增加未完成的内存加载数量,从而更好地饱和内存子系统,但代价是更高的 L1 缓存使用率。
为了在软件流水线中明确地重叠 MMA 执行与缓存到寄存器的加载,会导致过度的寄存器压力和溢出,从而导致性能下降。因此,将重叠 MMA 和缓存到寄存器的传输通过同一循环迭代中的指令级和线程级并行委托给硬件来处理。
Triton XPU 上的 FlexDecoding
FlexDecoding 是针对解码阶段 (CTX_q=1) 优化的灵活注意力机制的专门变体。除了 FlexAttention 所采用的内存访问优化外,FlexDecoding 还引入了两个关键增强功能:
- 并行 K/V 处理:K 和 V 矩阵在多个内核实例中沿 CTX_kv 维度进行划分。这种转换将串行迭代转化为并行执行,同时减少了冗余的 K 和 V 矩阵加载。
- 优化的向量-矩阵运算:XMX 引擎对向量-矩阵乘法 (M=1) 的原生支持,消除了在查询上下文较小时对查询或得分矩阵在 CTX_q 维度上进行填充的需求,从而提高了 Intel GPU 上的解码效率。

在 Intel GPU 上最大化 FlexAttention 和 FlexDecoding 的性能
由于用户可以提供修改注意力得分的任意函数(例如通过 score_mod)或应用自定义掩码逻辑(例如通过 mask_mod),并且注意力形状可能各不相同,因此最佳的内核配置(如平铺大小、阶段数、Warp 计数和共享内存使用量)在不同用例中可能有很大差异。尽管 XPU 启发式模板为 FlexAttention 内核提供了默认配置,但“一刀切”的配置对于多样化的注意力形状和自定义逻辑来说很少是最佳的。
为了处理这种可变性,max-autotune 是用户可以在 torch.compile 中启用的模式。在此模式下,TorchInductor 会探索各种内核参数(包括块大小、num_stages、Warp 计数和平铺维度),并选择能提供最佳经验性能的配置。max-autotune 模式也受 Intel GPU 支持。
FlexAttention 在 LLM 生态系统中的应用
FlexAttention 已被广泛集成到主流的 LLM 生态系统库中。下图显示了基于 HF/transformers 和 TorchAO 的一些流行 LLM 的性能。以 torch.scaled_dot_product_attention 为基准,7 个模型中有 4 个在预填充阶段使用 FlexAttention 可以获得更好的性能,而在解码阶段则有 7 个模型中有 5 个获得更好性能。


结论和未来工作
本博客介绍了 FlexAttention 在 Intel GPU 上的优化进展,并展示了我们在内核和端到端层面取得的合理性能数据。FlexAttention 旨在涵盖各种注意力变体及其组合。未来,我们将扩大基准测试范围,涵盖更多模型和场景,例如 LLM 服务框架中的分块预填充 (chunked prefill) 和 PagedAttention。除了 Triton 内核模板外,我们还将跟随社区启用更多的计算后端。
产品和性能信息
Intel® Core™ i5-13400 (Arc B580, 12GB VRAM), Ubuntu 24.10, 驱动程序: 25.35.35096.9, PyTorch 2.9.1, TorchAO-v0.14.1