
在跨多种 AI 硬件加速大语言模型的竞赛中,FlagGems 提供了一个高性能、灵活且可扩展的解决方案。FlagGems 基于 Triton 语言构建,是一个基于插件的 PyTorch 算子和内核库,旨在普及 AI 计算。其使命是实现“一次编写,处处 JIT(即时编译)”的体验,使开发者能够轻松地在各种硬件后端上部署优化后的内核。FlagGems 最近获得 PyTorch 生态系统工作组的认可,正式加入了 PyTorch 生态系统。
FlagGems 已实现超过 180 个算子——涵盖原生 PyTorch 算子和广泛用于大模型的自定义算子——并正快速发展,以跟上生成式 AI 的前沿步伐。
要查看 PyTorch 生态系统,请访问 PyTorch Landscape,并了解更多关于项目如何加入 PyTorch 生态系统的信息。
主要特性
- 丰富的算子库:180 多个与 PyTorch 兼容的算子,且在不断增加
- 性能优化:精选算子经过手工调优以提升速度
- 独立于 Torch.compile:在 Eager 模式下功能完整
- Pointwise 算子代码生成:为任意输入类型和布局自动生成内核
- 快速内核分发:每个函数都有独立的运行时分发逻辑
- C++ Triton 分发器:正在开发中,以实现更快的执行速度
- 支持多后端:通过后端中立的运行时 API,可在 10 多个硬件平台上运行
架构
FlagGems 通过一个由 Triton 驱动的多后端库,在树外扩展了 PyTorch 的分发系统。它拦截 ATen 算子调用,并提供针对特定后端的 Triton 实现,从而可以轻松支持替代 GPU 和领域特定加速器(DSA)。
即插即用
- 注册到 PyTorch 的分发系统
- 拦截 ATen 算子调用
- 无缝替换 CUDA 算子实现
一次编写,随处编译
- 统一的算子库代码
- 可在任何支持 Triton 的后端上编译
- 支持 GPU 和 DSA 等异构芯片
三步快速上手
1. 安装依赖
pip install torch>=2.2.0 # 2.6.0 preferred pip install triton>=2.2.0 # 3.2.0 preferred
2. 安装 FlagGems
git clone https://github.com/FlagOpen/FlagGems.git cd FlagGems pip install --no-build-isolation . or editable install: pip install --no-build-isolation -e .
3. 在你的项目中启用 FlagGems
import flag_gems flag_gems.enable() # Replaces supported PyTorch ops globally
希望更精细的控制?使用托管上下文
with flag_gems.use_gems(): output = model.generate(**inputs)
需要显式调用算子?
out = flag_gems.ops.slice_scatter(inp, dim=0, src=src, start=0, end=10, step=1)
Pointwise 算子的自动代码生成
通过 @pointwise_dynamic 装饰器,FlagGems 可以自动生成支持广播、融合和内存布局的高效内核。以下是一个实现 GeLU 和逐元素乘法融合的示例:
@pointwise_dynamic(promotion_methods=[(0, 1, “DEFAULT”)]) @triton.jit
def gelu_tanh_and_mul_kernel(x, y): x_fp32 = x.to(tl.float32) x_gelu = 0.5 * x_fp32 * (1 + tanh(x_fp32 * 0.79788456 * (1 + 0.044715 * pow(x_fp32, 2)))) return x_gelu * y
性能验证
FlagGems 包含内置的测试和基准测试
- 精度测试
cd tests pytest test_<op_name>_ops.py --ref cpu
- 性能基准测试
cd benchmark pytest test_<op_name>_perf.py -s # CUDA microbenchmarks pytest test_<op_name>_perf.py -s --mode cpu # End-to-end comparison
基准测试结果
FlagGems 的初步基准测试结果展示了其相对于 PyTorch 原生算子实现的性能。结果表示测得的平均加速比,大于 1 的值表示 FlagGems 比原生 PyTorch 算子更快。对于绝大多数算子,FlagGems 的性能与 PyTorch 的原生实现相当或显著超越。
对于 180 多个算子中的很大一部分,FlagGems 实现了接近 1.0 的加速比,表明其性能与原生 PyTorch 实现相当。
一些核心操作,如 LAYERNORM
、 CROSS_ENTROPY_
、 ADDMM
和 SOFTMAX
也显示出令人印象深刻的加速效果。
多后端支持
FlagGems 灵活支持不同供应商,并能感知后端
使用以下命令设置所需的后端
export GEMS_VENDOR=<vendor>
在 Python 中检查当前活动的后端
import flag_gems print(flag_gems.vendor_name)
总结
FlagGems 提供了一个统一的内核库,用于加速大模型,它在软件可移植性和硬件性能之间架起了桥梁。凭借广泛的后端支持、不断增长的算子集以及先进的代码生成功能,它已成为您突破 AI 计算极限的首选 Triton 实验场。