PyTorch/XLA 是一个利用 XLA 深度学习编译器使 PyTorch 能够在各种硬件后端(包括 Google Cloud TPU、GPU 以及 AWS Inferentia/Trainium)上运行深度学习工作负载的 Python 软件包。PyTorch/XLA 团队一直致力于为使用 TPU/GPU 和 XLA 后端的研究人员与开发人员提供新功能。在本次更新中,我们对框架进行了多项新增和改进。一些主要亮点包括:
- 可用性改进
- 与 JAX 操作的实验性桥接
- 一种基于 Pallas 的新型不规则分页注意力(ragged paged attention)内核,可在 vLLM TPU 上实现进一步优化
这些功能、错误修复及其他详细信息已在 发布说明 中列出。现在让我们详细深入了解一下这些亮点!
可用性改进
开发人员现在可以通过标记想要分析的具体代码区域,从而更精准地针对需要衡量性能的代码部分进行定位。示例如下:
server = xp.start_server(8001) xp.start_trace(profiling_dir) # Run some computation ... xp.stop_trace()
PyTorch/XLA 2.7 还引入了一个查询缓存编译图数量的 API,有助于在生产环境的推理或训练期间检测意外的编译情况。另一项增强功能通过避免不必要的张量复制来优化主机到设备的传输,从而提高了性能。
PyTorch/XLA 中的 JAX 桥接(原型)
我们正在尝试将 JAX 操作直接集成到 PyTorch/XLA 图中,作为在不同框架之间建立桥梁的一种方式——该方法允许用户在使用 XLA 运行的 PyTorch 模型中调用 JAX 函数。
作为用例,我们探索了从 PyTorch/XLA 调用 `jax.experimental.shard_alike`。该函数改善了 scan 等特定代码模式中的分片传播,我们已将其集成到编译器 GSPMD(通用 SPMD)工作流程中。该工具在 torchprime 中被使用,以支持 SplashAttention Pallas 内核。
import torch_xla.core.xla_builder as xb # Native function written in JAX def jax_function(...): import jax ... return ... res = xb.call_jax(...) </pre?
不规则分页注意力 Pallas 内核
针对变长序列的高效注意力机制对于扩展大语言模型至关重要,而新型 Pallas 不规则分页注意力内核 为 vLLM TPU 带来了重大的性能和可用性提升。
此更新引入了一个使用 Pallas 自定义内核语言实现的内核,并已降低至 TPU 的 Mosaic。它支持 不规则(变长) 输入序列并实现了 分页注意力 模式。以下是主要功能:
- 支持混合预填充(prefill)和解码(decode)操作,以提高推理吞吐量(例如,与 llama-3-8b 的填充式多查询分页注意力实现相比,速度提升高达 5 倍)。
- 无需 GMM(分组矩阵乘法)元数据!我们会在内核中实时计算元数据,这可将性能提升 10%。
- 提供了一个与 CUDA Flash Attention 等效、支持分页注意力且接口相似的实现。
我们正在与 vLLM 社区持续合作,以进一步优化性能、扩大内核覆盖范围,并简化大规模 TPU 推理。
GPU 构建回归
GPU 构建在 PyTorch/XLA 2.6 版本中曾被暂停,但我们在 2.7 版本中重新启用了 GPU 持续集成 (CI)。当前版本包含了基于 CUDA 12.6 的 GPU 构建,标志着 GPU 支持迈出了重要一步。
虽然 CUDA 支持在该版本中仍被视为 实验性 的,但我们计划在未来的版本中将覆盖范围扩展到更多的 CUDA 版本。
参与其中
请查看 GitHub 上的最新更改。一如既往,我们积极寻求社区的反馈和贡献。