PyTorch/XLA 是一个 Python 包,它利用 XLA 深度学习编译器在各种硬件后端(包括 Google Cloud TPU、GPU 和 AWS Inferentia/Trainium)上启用 PyTorch 深度学习工作负载。PyTorch/XLA 团队一直致力于为使用 TPU/GPU 和 XLA 后端的研究人员和开发人员带来新功能。在此次更新中,我们对框架进行了许多新增和改进。一些显著亮点包括:
- 可用性改进
- 与 JAX 操作的实验性桥接
- 一个基于 Pallas 的新型内核,用于不规则分页注意力(ragged paged attention),可在 vLLM TPU 上实现进一步优化。
这些功能、错误修复和其他详细信息在 发布说明中均有概述。现在让我们详细探讨这些亮点!
可用性改进
开发人员现在可以通过标记他们希望分析的精确代码区域,从而更好地针对他们希望测量性能的代码区域。一个示例如下:
server = xp.start_server(8001) xp.start_trace(profiling_dir) # Run some computation ... xp.stop_trace()
PyTorch/XLA 2.7 还引入了一个 API,用于查询缓存编译图的数量,有助于检测生产推理或训练期间的意外编译。此外,通过避免不必要的张量复制,优化了主机到设备的传输,从而提高了性能。
PyTorch/XLA 中的 JAX 桥接(原型)
我们正在尝试将 JAX 操作直接集成到 PyTorch/XLA 图中,作为实现框架之间桥接的一种方式 — 此方法允许用户在运行 XLA 的 PyTorch 模型中调用 JAX 函数。
作为一个用例,我们探索了从 PyTorch/XLA 调用 `jax.experimental.shard_alike`。此函数改进了 scan 等特定代码模式的分片传播,我们已将其作为编译器中 GSPMD(广义 SPMD)工作流的一部分集成。此工具在 torchprime 中使用,以支持 SplashAttention Pallas 内核。
import torch_xla.core.xla_builder as xb # Native function written in JAX def jax_function(...): import jax ... return ... res = xb.call_jax(...) </pre?
不规则分页注意力 Pallas 内核
高效处理变长序列的注意力对于扩展大型语言模型至关重要,而新型 不规则分页注意力 Pallas 内核为 vLLM TPU 带来了重大的性能和可用性提升。
本次更新引入了一个使用 Pallas 自定义内核语言实现的新型内核,并将其编译为 Mosaic 以用于 TPU。它支持 不规则(变长)输入序列并实现了 分页注意力模式。以下是主要功能:
- 支持混合预填充和解码操作,以提高推理吞吐量(例如,对于 llama-3-8b,与填充后的多查询分页注意力实现相比,可将速度提高至 5 倍)。
- 无需 GMM (Grouped Matmul) 元数据!我们在内核中即时计算元数据。这可以将性能提高 10%。
- 提供与 CUDA Flash Attention 等效的功能,支持分页注意力并具有相似的接口。
我们正在持续与 vLLM 社区合作,进一步优化性能,扩展内核覆盖范围,并简化大规模 TPU 推理。
GPU 构建回归
GPU 构建在 PyTorch/XLA 2.6 版本中曾暂停,但我们现在已在 2.7 版本中重新启用 GPU 持续集成 (CI)。当前版本包含使用 CUDA 12.6 的 GPU 构建,标志着 GPU 支持向前迈出了重要一步。
虽然此版本中的 CUDA 支持仍被视为 实验性,但我们计划在即将发布的版本中将覆盖范围扩展到其他 CUDA 版本。
参与进来
请在 GitHub 上查看最新更改。一如既往,我们积极寻求社区的反馈和贡献。