博客生态系统

FlagGems 加入 PyTorch 生态系统:由 Triton 驱动的通用 AI 加速算子库

作者: 2025年6月25日2025年6月27日暂无评论
FlagGems

在加速多硬件平台大语言模型的竞争中,FlagGems 提供了一套高性能、灵活且可扩展的解决方案。FlagGems 基于 Triton 语言构建,是一个面向 PyTorch 的插件式算子及内核库,旨在实现 AI 计算的普惠化。其使命是提供“编写一次,即时编译(JIT)到任何地方”的体验,让开发者能够轻松地将优化后的内核部署到广泛的硬件后端。FlagGems 已正式获准加入 PyTorch 生态系统,并成为其成员。

目前 FlagGems 已经实现了超过 180 个算子,涵盖了 PyTorch 原生算子以及大模型中常用的自定义算子,并正快速演进以紧跟生成式 AI 的前沿发展。

要查看 PyTorch 生态系统,请参阅 PyTorch Landscape,并了解更多关于项目如何 加入 PyTorch 生态系统的信息。

核心特性

  • 丰富的算子库:包含 180 多个兼容 PyTorch 的算子,且数量持续增加
  • 性能优化:针对特定算子进行手动调优,以提升速度
  • 独立于 Torch.compile:在 Eager 模式下即可完全运行
  • 逐点(Pointwise)算子代码生成:可自动为任意输入类型和布局生成内核
  • 快速内核调度:基于函数的运行时调度逻辑
  • C++ Triton 调度器:正在开发中,旨在实现更快的执行速度
  • 多后端就绪:通过后端中立的运行时 API,支持 10 多种硬件平台

架构

FlagGems 通过一个基于 Triton 的多后端库,在 Tree 外(out-of-tree)扩展了 PyTorch 的调度系统。它拦截 ATen 算子调用并提供特定于后端的 Triton 实现,从而能够轻松支持各种替代 GPU 和领域专用加速器(DSA)。

即插即用

  • 注册到 PyTorch 的调度系统
  • 拦截 ATen 算子调用
  • 无缝替换 CUDA 算子实现

一次编写,随处编译

  • 统一的算子库代码
  • 可在任何支持 Triton 的后端上编译
  • 支持 GPU 和 DSA 等异构芯片

三步快速上手

1. 安装依赖

pip install torch>=2.2.0  # 2.6.0 preferred

pip install triton>=2.2.0 # 3.2.0 preferred

2. 安装 FlagGems

git clone https://github.com/FlagOpen/FlagGems.git

cd FlagGems

pip install --no-build-isolation .

or editable install:

pip install --no-build-isolation -e .

3. 在项目中启用 FlagGems

import flag_gems

flag_gems.enable()  # Replaces supported PyTorch ops globally

想要更精细的控制?使用托管上下文(Managed Context)

with flag_gems.use_gems():

    output = model.generate(**inputs)

需要明确指定的算子?

out = flag_gems.ops.slice_scatter(inp, dim=0, src=src, start=0, end=10, step=1)

逐点算子的自动代码生成

通过 @pointwise_dynamic 装饰器,FlagGems 可以自动生成支持广播、融合和内存布局的高效内核。以下是一个实现融合 GeLU 和逐元素乘法的示例:

@pointwise_dynamic(promotion_methods=[(0, 1, “DEFAULT”)])

@triton.jit
def gelu_tanh_and_mul_kernel(x, y):

    x_fp32 = x.to(tl.float32)

    x_gelu = 0.5 * x_fp32 * (1 + tanh(x_fp32 * 0.79788456 * (1 + 0.044715 * pow(x_fp32, 2))))

    return x_gelu * y

性能验证

FlagGems 内置了测试和基准测试功能

  • 精度测试
cd tests

pytest test_<op_name>_ops.py --ref cpu
  • 性能基准测试
cd benchmark

pytest test_<op_name>_perf.py -s  # CUDA microbenchmarks

pytest test_<op_name>_perf.py -s --mode cpu  # End-to-end comparison

基准测试结果

FlagGems Speedup

FlagGems 的初步基准测试结果展示了其与 PyTorch 原生算子实现相比的性能表现。结果代表了测得的平均加速比,大于 1 的值表示 FlagGems 比原生 PyTorch 算子更快。对于绝大多数算子,FlagGems 的性能要么与 PyTorch 原生实现相当,要么显著优于后者。

在 180 多个算子中的很大一部分,FlagGems 实现了接近 1.0 的加速比,表明其性能与原生 PyTorch 实现持平。

诸如 LAYERNORMCROSS_ENTROPY_LOSSADDMMSOFTMAX 等一些核心操作也表现出了令人印象深刻的加速效果。

多后端支持

FlagGems 具有供应商灵活性并具备后端感知能力

通过以下方式设置所需后端

export GEMS_VENDOR=<vendor>

在 Python 中检查当前激活的后端

import flag_gems

print(flag_gems.vendor_name)

meetups

总结

FlagGems 为大模型加速提供了一个统一的内核库,架起了软件可移植性和硬件性能之间的桥梁。凭借广泛的后端支持、不断增长的算子集合以及先进的代码生成功能,它将成为您突破 AI 计算极限的首选 Triton 平台。