• 教程 >
  • Context Parallel 介绍
快捷方式

Context Parallel 介绍

作者: Xilun Wu, Chien-Chin Huang

注意

editGitHub 上查看和编辑此教程。

先决条件
  • PyTorch 2.7 或更高版本

介绍

Context Parallel 是一种在大规模语言模型训练中使用的技术,通过在多个设备上对长输入序列进行分片来减小峰值激活大小。它打破了 Transformer 块中存储激活时由于峰值内存使用导致的输入序列长度限制。

Ring Attention 是一种新颖的 Attention 层并行实现,对于高性能的 Context Parallel 至关重要。Ring Attention 会打乱 KV 分片并计算部分注意力分数,重复此过程直到每个设备上都使用过所有 KV 分片。目前已实现了两种 Ring Attention 变体:基于 all-gather 的 pass-KV基于 all-to-all 的 pass-KV

  1. 基于 all-gather 的 pass-KV 算法用于 Llama3 训练,该算法首先对 key 和 value 张量执行 all-gather,然后计算本地 query 张量块的注意力输出。我们修改后的基于 all-gather 的 pass-KV 算法同时对 KV 分片进行 all-gather,并使用本地 key 和 value 张量块计算本地 query 张量块的注意力输出,最后计算本地 query 张量和剩余 KV 分片的注意力输出。这使得注意力计算和 all-gather 集合操作可以在一定程度上重叠。例如,在 Llama3 训练中,我们还在序列维度上对 freq_cis 进行分片。

  2. 基于 all-to-all 的方法使用交错的 all-to-all 集合操作来环状打乱 KV 分片,以使 SDPA (Scaled Dot Product Attention) 计算与下一个 SDPA 所需的 all-to-all 通信重叠。

Context Parallel API 包含两个部分

  1. context_parallel() 允许用户创建一个 Python 上下文,在该上下文中,SDPA 函数 (torch.nn.functional.scaled_dot_product_attention) 将被自动替换为 Ring Attention。要在某个维度上对张量进行分片,只需将张量及其分片维度分别传递给参数 buffersbuffer_seq_dims。我们建议用户将沿序列维度进行计算的张量添加到 buffers 中,并沿该维度对其进行分片。以 Llama3 训练为例,如果 buffers 中缺少 freq_cis,将导致旋转嵌入计算错误。

  2. set_rotate_method() 允许用户选择基于 all-gather 的 pass-KV 方法或基于 all-to-all 的 pass-KV 方法。

设置

使用 torch.distributed.tensor.experimental.context_parallel(),用户可以轻松地对张量输入进行分片并并行化 SDPA 函数的执行。为了更好地演示此 API 的用法,我们先从一个执行 SDPA 的简单代码片段开始,然后使用该 API 对其进行并行化

import torch
import torch.nn.functional as F

from torch.nn.attention import sdpa_kernel, SDPBackend


def sdpa_example():
    assert torch.cuda.is_available()
    torch.cuda.set_device("cuda:0")
    torch.cuda.manual_seed(0)

    batch = 8
    nheads = 8
    qkv_len = 8192
    dim = 32
    backend = SDPBackend.FLASH_ATTENTION
    dtype = (
        torch.bfloat16
        if backend == SDPBackend.FLASH_ATTENTION
        or backend == SDPBackend.CUDNN_ATTENTION
        else torch.float32
    )

    qkv = [
        torch.rand(
            (batch, nheads, qkv_len, dim),
            dtype=dtype,
            requires_grad=True,
            device='cuda',
        )
        for _ in range(3)
    ]
    # specify the SDPBackend to use
    with sdpa_kernel(backend):
        out = F.scaled_dot_product_attention(*qkv, is_causal=True)


if __name__ == "__main__":
    sdpa_example()

启用 Context Parallel

现在,让我们先将其调整为分布式程序,其中每个 rank 具有相同的张量输入。然后我们应用 context parallel API 对输入进行分片,并将计算分布到各个 rank 上

# file: cp_sdpa_example.py
import os

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import context_parallel_unshard
from torch.nn.attention import sdpa_kernel, SDPBackend


def context_parallel_sdpa_example(world_size: int, rank: int):
    assert torch.cuda.is_available()
    assert dist.is_nccl_available()
    torch.cuda.set_device(f"cuda:{rank}")
    torch.cuda.manual_seed(0)

    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=world_size,
        rank=rank,
    )
    device_mesh = init_device_mesh(
        device_type="cuda", mesh_shape=(world_size,), mesh_dim_names=("cp",)
    )

    batch = 8
    nheads = 8
    qkv_len = 64
    dim = 32
    backend = SDPBackend.FLASH_ATTENTION
    dtype = (
        torch.bfloat16
        if backend == SDPBackend.FLASH_ATTENTION
        or backend == SDPBackend.CUDNN_ATTENTION
        else torch.float32
    )

    qkv = [
        torch.rand(
            (batch, nheads, qkv_len, dim),
            dtype=dtype,
            requires_grad=True,
            device='cuda',
        )
        for _ in range(3)
    ]
    # specify the SDPBackend to use
    with sdpa_kernel(backend):
        out = F.scaled_dot_product_attention(*qkv, is_causal=True)

    # make a clean copy of QKV for output comparison
    cp_qkv = [t.detach().clone() for t in qkv]

    with sdpa_kernel(backend):
        # This `context_parallel()` performs two actions:
        # 1. Shard the tensor objects in `buffers` in-place along the dimension
        #    specified in `buffer_seq_dims`, the tensors in `buffers` and their
        #    sharding dims in `buffer_seq_dims` are organized in the same order.
        # 2. Replace the execution of `F.scaled_dot_product_attention` with a
        #    context-paralleled-enabled Ring Attention.
        with context_parallel(
            device_mesh, buffers=tuple(cp_qkv), buffer_seq_dims=(2, 2, 2)
        ):
            cp_out = F.scaled_dot_product_attention(*cp_qkv, is_causal=True)

        # The output `cp_out` is still sharded in the same way as QKV
        # the `context_parallel_unshard` API allows users to easily
        # unshard to gain the full tensor.
        (cp_out,) = context_parallel_unshard(device_mesh, [cp_out], [2])

    assert torch.allclose(
        cp_out,
        out,
        atol=(1e-08 if dtype == torch.float32 else 1e-03 * world_size),
    )


if __name__ == "__main__":
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    try:
        context_parallel_sdpa_example(world_size, rank)
    finally:
        dist.barrier()
        dist.destroy_process_group()

您可以使用命令 torchrun --standalone --nnodes=1 --nproc-per-node=4 cp_sdpa_example.py 在 4 个 GPU 上启动上述 context parallel SDPA。我们通过比较 Ring Attention 的输出与单个 GPU 上 SDPA 的输出,来证明其数值正确性。

选择旋转方法

您可以通过使用 torch.distributed.tensor.experimental._attention.set_rotate_method() 来选择 Ring Attention 中期望的分片旋转方法。

# file: cp_sdpa_example.py
from torch.distributed.tensor.experimental._attention import set_rotate_method

set_rotate_method("alltoall")  # rotate shards using all-to-all

with sdpa_kernel(backend):
    with context_parallel(
        device_mesh, buffers=tuple(cp_qkv), buffer_seq_dims=(2, 2, 2)
    ):
        cp_out = F.scaled_dot_product_attention(*cp_qkv, is_causal=True)

默认的旋转方法是基于 all-gather 的 pass-KV。

结论

在本教程中,我们学习了如何使用 Context Parallel API 轻松地沿序列维度并行化 SDPA 计算。有关设计和实现细节、性能分析以及 TorchTitan 中的端到端训练示例,请参阅我们关于 PyTorch 原生长上下文训练 的文章。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源