通过 Pallas 实现自定义内核¶
随着 OpenAI Triton 的兴起,自定义内核在 GPU 社区中变得越来越流行,例如 FlashAttention 和 PagedAttention 的引入。为了在 TPU 世界中提供相同的功能,Google 推出了 Pallas。PyTorch/XLA 要想持续提升在 TPU 上的性能,就必须支持自定义内核,而最好的方式就是通过 Pallas。
假设您定义了一个 Pallas 内核如下
from torch_xla.experimental.custom_kernel import jax_import_guard
jax_import_guard()
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
def add_vectors_kernel(x_ref, y_ref, o_ref):
x, y = x_ref[...], y_ref[...]
o_ref[...] = x + y
@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
return pl.pallas_call(add_vectors_kernel,
out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
)(x, y)
需要注意的是,在导入任何 jax 模块之前运行 jax_import_guard()
非常重要。否则,程序可能会在 TPU 上挂起,因为 jax 会锁定 TPU,导致 torch-xla 无法访问。
使上述内核与 PyTorch/XLA 兼容¶
使用示例
q = torch.randn(3, 2, 128, 4).to("xla")
k = torch.randn(3, 2, 128, 4).to("xla")
v = torch.randn(3, 2, 128, 4).to("xla")
# Adopts any Pallas kernel
from torch_xla.experimental.custom_kernel import make_kernel_from_pallas
pt_kernel = make_kernel_from_pallas(add_vectors, lambda x, y: [(x.shape, x.dtype)])
output = pt_kernel(q, k)
对于简单的内核,适配就像一行代码一样简单。对于更复杂的内核,您可以参考我们的 Flash Attention 实现以获取详细信息。
使用内置内核¶
除了手动包装外部 Pallas 内核外,还有 PyTorch/XLA 已完成适配的内置内核。这些内置内核可以像任何其他 torch.ops 一样使用。当前支持的内置内核有: - FlashAttention - PagedAttention
FlashAttention¶
使用示例¶
# Use built-in kernels
import torch_xla.experimental.custom_kernel
output = flash_attention(q, k, v)
集成示例¶
在我们的训练测试脚本中,我们提供了一个 FlashAttention 集成示例。
PagedAttention¶
使用示例¶
# Use built-in kernels
import torch_xla.experimental.custom_kernel
output = torch.ops.xla.paged_attention(
query.squeeze(dim=1),
key_cache,
value_cache,
context_lens,
block_tables,
pages_per_compute_block,
megacore_mode=None,
)
集成示例¶
vLLM TPU 集成在此处利用了 PagedAttention,以实现 KV 缓存的有效内存管理。
依赖项¶
Pallas 集成依赖于 JAX 来运行。然而,并非所有 JAX 版本都与您安装的 PyTorch/XLA 兼容。要安装正确的 JAX
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html