PyTorch/XLA 是一个 Python 包,它利用 XLA 深度学习编译器在各种硬件后端(包括 Google Cloud TPU、GPU 以及 AWS Inferentia/Trainium)上运行 PyTorch 深度学习工作负载。PyTorch/XLA 团队一直致力于为使用 TPU/GPU 和 XLA 后端的研究人员和开发人员带来新的功能。在此次更新中,我们对框架进行了许多新增和改进。其中值得注意的亮点包括:
- 可用性改进
- 与 JAX 操作的实验性桥接
- 一个新的基于 Pallas 的稀疏分页注意力内核,实现了 vLLM TPU 的进一步优化
这些功能、错误修复和其他详细信息已在 发布说明中列出。现在让我们详细探讨这些亮点!
可用性改进
开发人员现在可以通过标记他们希望进行性能度量的精确代码区域,从而更好地定位代码区域。例如:
server = xp.start_server(8001) xp.start_trace(profiling_dir) # Run some computation ... xp.stop_trace()
PyTorch/XLA 2.7 还引入了一个 API,用于查询缓存编译图的数量,这有助于在生产推理或训练期间检测意外编译。另外一项增强功能通过避免不必要的张量复制来优化主机到设备的传输,从而提高了性能。
PyTorch/XLA 中的 JAX 桥接(原型)
我们正在尝试将 JAX 操作直接集成到 PyTorch/XLA 图中,作为实现框架之间桥接的一种方式——这种方法使用户能够在运行 XLA 的 PyTorch 模型中调用 JAX 函数。
作为一个用例,我们探索了从 PyTorch/XLA 调用“jax.experimental.shard_alike”。此函数改进了在某些代码模式(如扫描)中的分片传播,我们已将其集成到编译器中的 GSPMD(广义 SPMD)工作流中。此工具在 torchprime 中用于支持 SplashAttention Pallas 内核。
import torch_xla.core.xla_builder as xb # Native function written in JAX def jax_function(...): import jax ... return ... res = xb.call_jax(...) </pre?
稀疏分页注意力 Pallas 内核
高效处理变长序列的注意力对于扩展大型语言模型至关重要,而新的 稀疏分页注意力 Pallas 内核 为 vLLM TPU 带来了重大的性能和可用性升级。
此更新引入了一个使用 Pallas 自定义内核语言实现的自定义内核,并将其降低到 Mosaic 以用于 TPU。它支持 稀疏(变长) 输入序列并实现了 分页注意力 模式。以下是主要功能:
- 支持混合预填充和解码操作以提高推理吞吐量(例如,对于 llama-3-8b,与填充多查询分页注意力实现相比,速度提升高达 5 倍)。
- 无需 GMM(分组矩阵乘法)元数据!我们会在内核中动态计算元数据。这可以将性能提高 10%。
- 提供与 CUDA Flash Attention 等效的功能,支持分页注意力并具有相似的接口。
我们正在持续与 vLLM 社区 协作,以进一步优化性能,扩展内核覆盖范围,并简化大规模 TPU 推理。
GPU 构建回归
PyTorch/XLA 2.6 版本中暂停了 GPU 构建,但在 2.7 版本中,我们已重新启用 GPU 持续集成 (CI)。当前版本包含使用 CUDA 12.6 的 GPU 构建,标志着 GPU 支持向前迈出了重要一步。
虽然在此版本中 CUDA 支持仍被视为 实验性,但我们计划在即将发布的版本中将覆盖范围扩展到其他 CUDA 版本。
参与其中
请查看 GitHub 上的最新更改。一如既往,我们积极寻求社区的反馈和贡献。