
Accelerating PyTorch Transformers with Nested Tensors and torch.compile()

Author: Mikayla Gawarecki

What you will learn
  • Learn about the low-level building blocks PyTorch provides for building custom Transformer layers (nested tensors, scaled_dot_product_attention, torch.compile(), and FlexAttention)

  • Discover how the above improve memory usage and performance, using MultiHeadAttention as an example

  • Explore advanced customizations using the aforementioned building blocks

Prerequisites
  • PyTorch v.2.6.0 or later

Over the past few years, the PyTorch team has developed various low-level features that, when composed, can be used to create a variety of Transformer variants. These include:

  • Nested Tensors with the torch.jagged layout (also known as NJTs)

  • scaled_dot_product_attention

  • torch.compile()

  • FlexAttention

This tutorial gives a brief overview of the above technologies and demonstrates how they can be composed to yield flexible and performant Transformer layers with improved user experience.

You may notice that the torch.nn module currently provides various Transformer-related layers. In particular, it includes TransformerEncoderLayer, TransformerEncoder, TransformerDecoderLayer, TransformerDecoder, Transformer, and MultiheadAttention. This family of layers was initially implemented following the Attention is All You Need paper. The components discussed in this tutorial provide improved user experience, flexibility, and performance over the existing nn layers.

Is this tutorial for me?

If you are wondering what building blocks the torch library provides for writing your own Transformer layers, and what the best practices are, you are in the right place. Please keep reading!

If you are looking for an out-of-the-box implementation of a popular Transformer architecture, note that there are many open-source libraries that provide one.

If you are only interested in performant attention-score modifications, check out the FlexAttention blog, which contains a gym of masks.

Introducing the Building Blocks

First, we will briefly introduce the four technologies mentioned in the introduction:

Nested tensors generalize the shape of regular dense tensors, allowing for representation of ragged-sized data with the same tensor UX. In the context of Transformers, we can think of nested tensors as a tool for representing variable sequence lengths. They eliminate the need for the error-prone practices of explicit padding and masking (think key_padding_mask in nn.MultiHeadAttention).
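
For example, here is a minimal sketch of constructing a nested tensor with the torch.jagged layout from sequences of differing lengths (the sizes below are arbitrary, chosen purely for illustration):

import torch

# A batch of three sequences with different lengths; no padding is materialized.
seqs = [torch.randn(n, 8) for n in (3, 5, 2)]
njt = torch.nested.nested_tensor(seqs, layout=torch.jagged)
print(njt.is_nested)  # True; dim 1 is the ragged (per-sequence length) dimension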

scaled_dot_product_attention is a primitive for computing \(\text{softmax}(\frac{QK^T}{\sqrt{E}} + B)V\) that dispatches into either fused implementations of the operator or a fallback implementation. It works out of the box in eager mode (i.e. the default mode of using PyTorch, where operations are executed on the fly as they are encountered) and also integrates seamlessly with torch.compile(). As of 2.6, it also offers grouped query attention natively.
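
As a minimal sketch (shapes are illustrative), SDPA can be called directly on (batch, heads, seq_len, head_dim) tensors and will dispatch to a fused kernel when one is available:

import torch
import torch.nn.functional as F

# query/key/value laid out as (batch, heads, seq_len, head_dim)
q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)
# Computes softmax(q @ k.T / sqrt(64)) @ v with an optional causal mask applied.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # (2, 8, 16, 64)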

torch.compile() is a compiler introduced in 2.0 that is able to capture a graph of PyTorch code and perform various optimizations on it, such as fusing sequences of ops. Nested tensors with the torch.jagged layout and scaled_dot_product_attention work seamlessly with compile. In the context of Transformers, the value add of using compile with nested tensors and SDPA is that compile can remove the framework overhead seen in eager mode and fuse sequences of ops in Transformers, such as projection and activation.
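
As a minimal sketch (the module and shapes are illustrative assumptions), wrapping a module in torch.compile() is a one-line change; the first call triggers compilation and subsequent calls reuse the optimized code:

import torch
import torch.nn as nn

# A small projection + activation stack that the compiler can fuse.
mlp = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
compiled_mlp = torch.compile(mlp)
x = torch.randn(8, 32, 64)
y = compiled_mlp(x)  # first call compiles; later calls run the compiled code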

FlexAttention is a primitive that allows users to modify attention scores prior to the softmax operation. It generalizes the additive B term of scaled_dot_product_attention above, allowing for arbitrary computation. It requires compile to achieve good performance.

The above building blocks are "All You Need" (as of October 2024)

The main premise in this section is that most Transformer variants are GPT-style, consisting of layers such as Embedding, Positional Encoding, Attention Blocks, and Feed Forward networks. If we were to try to classify the differences in this space, we might land on the following:

  1. Layer type (activation functions such as SwiGLU, normalization functions such as RMSNorm, positional encodings such as Sinusoidal, Rotary, etc.)

  2. Layer ordering, such as where to apply norms and positional encoding.

  3. Modifications to attention scores, such as ALiBi, Relative Positional Bias, and so on.

In a pre-compiler environment, you might write a custom Transformer, notice that it works correctly but is slow, and then develop a custom fused kernel for the particular sequence of ops to address this. In a compiler environment, you can simply do the first step and then compile, and benefit from improved performance.

MultiheadAttention

Remember that MultiheadAttention takes query, key, and value as input and consists of an input projection, a scaled_dot_product_attention operator, and an output projection. The main takeaway we want to demonstrate here is the improvement yielded when we replace padded/masked inputs with nested tensors. The improvements are threefold:

  • User experience: Remember that nn.MultiheadAttention requires query, key, and value to be dense torch.Tensors. It also provides a key_padding_mask that is used to mask out padding tokens in the key arising from varying sequence lengths within a batch. Since there is no query_padding_mask in nn.MHA, users have to take care to mask/slice the outputs appropriately to account for query sequence lengths. NestedTensor cleanly removes the need for this sort of error-prone padding mask.

  • Memory: Nested tensors allow you to cleanly represent a batch of varying sequence lengths without materializing a dense [B, S, D] tensor with a [B, S] padding mask (where B is batch size, S is the max sequence length in the batch, and D is embedding size). As a result, the input and intermediate activations use less memory.

  • Performance: Since padding is not materialized and unnecessary computation over padding is skipped, both performance and memory usage improve.

We will demonstrate the above by building upon the MultiheadAttention layer in the Nested Tensor tutorial and comparing it to the nn.MultiheadAttention layer.

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    """
    Computes multi-head attention. Supports nested or padded tensors.

    Args:
        E_q (int): Size of embedding dim for query
        E_k (int): Size of embedding dim for key
        E_v (int): Size of embedding dim for value
        E_total (int): Total embedding dim of combined heads post input projection. Each head
            has dim E_total // nheads
        nheads (int): Number of heads
        dropout (float, optional): Dropout probability. Default: 0.0
        bias (bool, optional): Whether to add bias to input projection. Default: True
    """

    def __init__(
        self,
        E_q: int,
        E_k: int,
        E_v: int,
        E_total: int,
        nheads: int,
        dropout: float = 0.0,
        bias=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.nheads = nheads
        self.dropout = dropout
        self._qkv_same_embed_dim = E_q == E_k and E_q == E_v
        if self._qkv_same_embed_dim:
            self.packed_proj = nn.Linear(E_q, E_total * 3, bias=bias, **factory_kwargs)
        else:
            self.q_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
            self.k_proj = nn.Linear(E_k, E_total, bias=bias, **factory_kwargs)
            self.v_proj = nn.Linear(E_v, E_total, bias=bias, **factory_kwargs)
        E_out = E_q
        self.out_proj = nn.Linear(E_total, E_out, bias=bias, **factory_kwargs)
        assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
        self.E_head = E_total // nheads
        self.bias = bias

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask=None,
        is_causal=False,
    ) -> torch.Tensor:
        """
        Forward pass; runs the following process:
            1. Apply input projection
            2. Split heads and prepare for SDPA
            3. Run SDPA
            4. Apply output projection

        Args:
            query (torch.Tensor): query of shape (``N``, ``L_q``, ``E_qk``)
            key (torch.Tensor): key of shape (``N``, ``L_kv``, ``E_qk``)
            value (torch.Tensor): value of shape (``N``, ``L_kv``, ``E_v``)
            attn_mask (torch.Tensor, optional): attention mask of shape (``N``, ``L_q``, ``L_kv``) to pass to SDPA. Default: None
            is_causal (bool, optional): Whether to apply causal mask. Default: False

        Returns:
            attn_output (torch.Tensor): output of shape (N, L_t, E_q)
        """
        # Step 1. Apply input projection
        if self._qkv_same_embed_dim:
            if query is key and key is value:
                result = self.packed_proj(query)
                query, key, value = torch.chunk(result, 3, dim=-1)
            else:
                q_weight, k_weight, v_weight = torch.chunk(
                    self.packed_proj.weight, 3, dim=0
                )
                if self.bias:
                    q_bias, k_bias, v_bias = torch.chunk(
                        self.packed_proj.bias, 3, dim=0
                    )
                else:
                    q_bias, k_bias, v_bias = None, None, None
                query, key, value = (
                    F.linear(query, q_weight, q_bias),
                    F.linear(key, k_weight, k_bias),
                    F.linear(value, v_weight, v_bias),
                )

        else:
            query = self.q_proj(query)
            key = self.k_proj(key)
            value = self.v_proj(value)

        # Step 2. Split heads and prepare for SDPA
        # reshape query, key, value to separate by head
        # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head)
        query = query.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
        key = key.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
        value = value.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)

        # Step 3. Run SDPA
        # (N, nheads, L_t, E_head)
        attn_output = F.scaled_dot_product_attention(
            query, key, value, dropout_p=self.dropout, is_causal=is_causal
        )
        # (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)
        attn_output = attn_output.transpose(1, 2).flatten(-2)

        # Step 4. Apply output projection
        # (N, L_t, E_total) -> (N, L_t, E_out)
        attn_output = self.out_proj(attn_output)

        return attn_output

Utilities

In this section, we include a utility to generate semi-realistic data using a Zipf distribution for sentence lengths. This is used to generate the nested query, key, and value tensors. We also include a benchmark utility.

import numpy as np


def zipf_sentence_lengths(alpha: float, batch_size: int) -> torch.Tensor:
    # generate fake corpus by unigram Zipf distribution
    # from wikitext-2 corpus, we get rank "." = 3, "!" = 386, "?" = 858
    sentence_lengths = np.empty(batch_size, dtype=int)
    for ibatch in range(batch_size):
        sentence_lengths[ibatch] = 1
        word = np.random.zipf(alpha)
        while word != 3 and word != 386 and word != 858:
            sentence_lengths[ibatch] += 1
            word = np.random.zipf(alpha)
    return torch.tensor(sentence_lengths)


# Generate a batch of semi-realistic data using Zipf distribution for sentence lengths
# in the form of nested tensors with the jagged layout.
def gen_batch(N, E_q, E_k, E_v, device, dtype=torch.float32, query_seq_len_1=False):
    # generate semi-realistic data using Zipf distribution for sentence lengths
    sentence_lengths = zipf_sentence_lengths(alpha=1.2, batch_size=N)

    # Note: the torch.jagged layout is a nested tensor layout that supports a single ragged
    # dimension and works with torch.compile. The resulting nested tensor has shape
    # (B, S*, D), where B = batch size, S* = ragged sequence length, and D = embedding dimension.
    if query_seq_len_1:
        query = torch.nested.nested_tensor(
            [torch.randn(1, E_q, dtype=dtype, device=device) for l in sentence_lengths],
            layout=torch.jagged,
        )
    else:
        query = torch.nested.nested_tensor(
            [
                torch.randn(l.item(), E_q, dtype=dtype, device=device)
                for l in sentence_lengths
            ],
            layout=torch.jagged,
        )

    key = torch.nested.nested_tensor(
        [
            torch.randn(s.item(), E_k, dtype=dtype, device=device)
            for s in sentence_lengths
        ],
        layout=torch.jagged,
    )

    value = torch.nested.nested_tensor(
        [
            torch.randn(s.item(), E_v, dtype=dtype, device=device)
            for s in sentence_lengths
        ],
        layout=torch.jagged,
    )

    return query, key, value, sentence_lengths


import math
import timeit


def benchmark(func, *args, **kwargs):
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    begin = timeit.default_timer()
    output = func(*args, **kwargs)
    torch.cuda.synchronize()
    end = timeit.default_timer()
    return output, (end - begin), torch.cuda.max_memory_allocated()

We will now demonstrate the performance improvements of using nested tensors + compile for self-attention in the MultiheadAttention layer, compared against the traditional nn.MultiheadAttention + compile with padding and masking.

N, E_q, E_k, E_v, E_total = 512, 512, 512, 512, 512
E_out = E_q
d_model = E_q
nheads = 8
dropout = 0.0
bias = True
device = "cuda"
torch.manual_seed(6)
query, key, value, sentence_lengths = gen_batch(N, E_q, E_k, E_v, device)
S = sentence_lengths.max().item()
print(
    f"Total sequence length in nested query {sentence_lengths.sum().item()}, max sequence length {S}"
)
padded_query, padded_key, padded_value = (
    t.to_padded_tensor(0.0) for t in (query, key, value)
)

torch.manual_seed(6)
mha_layer = MultiHeadAttention(
    E_q, E_k, E_v, E_total, nheads, dropout=dropout, bias=bias, device="cuda"
)
torch.manual_seed(6)
vanilla_mha_layer = nn.MultiheadAttention(
    E_q, nheads, dropout=dropout, batch_first=True, bias=bias, device="cuda"
)

# ``nn.MultiheadAttention`` uses a non-conventional initialization for its layers, so do this for exact parity :(
mha_layer.out_proj.weight = nn.Parameter(
    vanilla_mha_layer.out_proj.weight.clone().detach()
)
mha_layer.packed_proj.weight = nn.Parameter(
    vanilla_mha_layer.in_proj_weight.clone().detach()
)
mha_layer.out_proj.bias = nn.Parameter(vanilla_mha_layer.out_proj.bias.clone().detach())
mha_layer.packed_proj.bias = nn.Parameter(
    vanilla_mha_layer.in_proj_bias.clone().detach()
)

new_mha_layer = torch.compile(mha_layer)
# warmup compile
nested_result_warmup = new_mha_layer(query, query, query, is_causal=True)

# benchmark
nested_result, nested_time, nested_peak_memory = benchmark(
    new_mha_layer, query, query, query, is_causal=True
)
padded_nested_result = nested_result.to_padded_tensor(0.0)

# For the vanilla ``nn.MultiheadAttention``, we need to construct the ``key_padding_mask``
# Further, ``nn.MultiheadAttention`` forces one to materialize the ``attn_mask`` even if using ``is_causal``
src_key_padding_mask = torch.where(padded_query == 0.0, -math.inf, 0)[:, :, 0]
attn_mask = torch.empty((N, S, S), device=device).fill_(float("-inf"))
for i, s in enumerate(sentence_lengths):
    attn_mask[i, :s, :s] = nn.Transformer.generate_square_subsequent_mask(s)
attn_mask = attn_mask.unsqueeze(1).expand(N, nheads, S, S).reshape(N * nheads, S, S)

vanilla_mha_layer = torch.compile(vanilla_mha_layer)
# warmup compile
warmup_vanilla_result = vanilla_mha_layer(
    padded_query,
    padded_query,
    padded_query,
    attn_mask=attn_mask,
    key_padding_mask=src_key_padding_mask,
    need_weights=False,
    is_causal=True,
)

# benchmark
(padded_result, _), padded_time, padded_peak_memory = benchmark(
    vanilla_mha_layer,
    padded_query,
    padded_query,
    padded_query,
    key_padding_mask=src_key_padding_mask,
    need_weights=False,
    attn_mask=attn_mask,
    is_causal=True,
)

print(f"{padded_time=:.5f}, padded_peak_memory={padded_peak_memory/1e9:.2f} GB")
print(f"{nested_time=:.5f}, nested_peak_memory={nested_peak_memory/1e9:.2f} GB")
print(
    "Max difference between vanilla and nested result",
    (padded_result - padded_nested_result).abs().max().item(),
)
print(f"Nested speedup: {(padded_time/nested_time):.2f}")
print(
    f"Nested peak memory reduction {((padded_peak_memory - nested_peak_memory)/1e9):.2f} GB"
)
Total sequence length in nested query 10436, max sequence length 128
padded_time=0.01608, padded_peak_memory=3.87 GB
nested_time=0.00254, nested_peak_memory=0.92 GB
Max difference between vanilla and nested result 0.0
Nested speedup: 6.33
Nested peak memory reduction 2.96 GB

For reference, here is sample output on an A100:

padded_time=0.03454, padded_peak_memory=4.14 GB
nested_time=0.00612, nested_peak_memory=0.76 GB
Max difference between vanilla and nested result 0.0
Nested speedup: 5.65
Nested peak memory reduction 3.39 GB

We can also see the same for the backward pass:

for i, entry_length in enumerate(sentence_lengths):
    # padding-specific step: remove output projection bias from padded entries for fair comparison
    padded_result[i, entry_length:, :] = 0.0

_, padded_bw_time, padded_bw_peak_mem = benchmark(
    lambda: padded_result.sum().backward()
)
_, nested_bw_time, nested_bw_peak_mem = benchmark(
    lambda: padded_nested_result.sum().backward()
)

print(f"{padded_bw_time=:.5f}, padded_bw_peak_mem={padded_bw_peak_mem/1e9:.2f} GB")
print(f"{nested_bw_time=:.5f}, nested_bw_peak_mem={nested_bw_peak_mem/1e9:.2f} GB")
print(f"Nested backward speedup: {(padded_bw_time/nested_bw_time):.2f}")
print(
    f"Nested backward peak memory reduction {((padded_bw_peak_mem - nested_bw_peak_mem)/1e9):.2f} GB"
)

print(
    "Difference in out_proj.weight.grad",
    (mha_layer.out_proj.weight.grad - vanilla_mha_layer.out_proj.weight.grad)
    .abs()
    .max()
    .item(),
)
print(
    "Difference in packed_proj.weight.grad",
    (mha_layer.packed_proj.weight.grad - vanilla_mha_layer.in_proj_weight.grad)
    .abs()
    .max()
    .item(),
)
print(
    "Difference in out_proj.bias.grad",
    (mha_layer.out_proj.bias.grad - vanilla_mha_layer.out_proj.bias.grad)
    .abs()
    .max()
    .item(),
)
print(
    "Difference in packed_proj.bias.grad",
    (mha_layer.packed_proj.bias.grad - vanilla_mha_layer.in_proj_bias.grad)
    .abs()
    .max()
    .item(),
)
padded_bw_time=1.62963, padded_bw_peak_mem=4.68 GB
nested_bw_time=0.06652, nested_bw_peak_mem=3.04 GB
Nested backward speedup: 24.50
Nested backward peak memory reduction 1.64 GB
Difference in out_proj.weight.grad 0.000396728515625
Difference in packed_proj.weight.grad 0.00146484375
Difference in out_proj.bias.grad 0.0
Difference in packed_proj.bias.grad 0.0029296875

Sample output on an A100:

padded_bw_time=2.09337, padded_bw_peak_mem=5.10 GB
nested_bw_time=0.01452, nested_bw_peak_mem=3.24 GB
Nested backward speedup: 144.13
Nested backward peak memory reduction 1.86 GB
Difference in out_proj.weight.grad 0.000244140625
Difference in packed_proj.weight.grad 0.001556396484375
Difference in out_proj.bias.grad 0.0
Difference in packed_proj.bias.grad 0.001953125

GPT-style layer

A basic GPT-style Transformer layer consists of a causal self-attention layer followed by a feed-forward network (FFN) with skip connections. Implementing this is fairly straightforward using the MultiheadAttention layer above and gives results equivalent to an nn.TransformerEncoderLayer with is_causal=True.

For brevity, this tutorial omits examples of implementing the other nn layers; you can find them here.
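
As a rough illustration only (the pre-norm ordering, GELU activation, and hidden size below are assumptions for this sketch, not the tutorial's reference implementation), such a block built on the MultiHeadAttention class defined above might look like this:

class GPTBlock(nn.Module):
    # A minimal GPT-style block: pre-norm causal self-attention followed by a
    # feed-forward network, each wrapped in a skip connection.
    def __init__(self, d_model, nheads, dim_ff, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.attn_norm = nn.LayerNorm(d_model, **factory_kwargs)
        self.attn = MultiHeadAttention(
            d_model, d_model, d_model, d_model, nheads, **factory_kwargs
        )
        self.ffn_norm = nn.LayerNorm(d_model, **factory_kwargs)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_ff, **factory_kwargs),
            nn.GELU(),
            nn.Linear(dim_ff, d_model, **factory_kwargs),
        )

    def forward(self, x):
        # causal self-attention with residual connection
        h = self.attn_norm(x)
        x = x + self.attn(h, h, h, is_causal=True)
        # feed-forward network with residual connection
        return x + self.ffn(self.ffn_norm(x))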

Going one step further

So far, we have demonstrated how to implement a performant MultiheadAttention layer that follows the traditional nn.MultiheadAttention. Going back to our classification of modifications to the Transformer architecture, recall that we classified the modifications into layer type, layer ordering, and modifications to attention scores. We trust that changing layer type and layer ordering (such as swapping LayerNorm for RMSNorm) is fairly straightforward.
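
For instance, here is a minimal RMSNorm sketch (recent PyTorch releases also provide nn.RMSNorm, which could be used instead) that could be dropped in wherever a LayerNorm is used:

class RMSNorm(nn.Module):
    # A minimal RMSNorm sketch: scale by the reciprocal root-mean-square of the
    # last dimension and apply a learnable per-channel weight (no mean, no bias).
    def __init__(self, dim, eps=1e-6, device=None, dtype=None):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype))

    def forward(self, x):
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * rms * self.weight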

In this section, we will discuss various functionalities that use the aforementioned building blocks, including the following:

  • Cross Attention

  • Fully masked rows no longer cause NaNs

  • Modifying attention scores: ALiBi with FlexAttention and NJT

  • Packed Projection

Cross Attention

Cross attention is a form of attention where the query and key/value tensors come from different sequences.

One example of this is in nn.TransformerDecoderLayer, where the query comes from the decoder and the key/value come from the encoder.

The MultiheadAttention layer above nicely generalizes to this case, with nested tensors for both the query and the key/value.

query, _, _, q_len = gen_batch(N, E_q, E_k, E_v, device)
_, key, value, kv_len = gen_batch(N, E_q, E_k, E_v, device)

print(
    f"Total sequence length in nested query {q_len.sum().item()}, max sequence length {q_len.max().item()}"
)
print(
    f"Total sequence length in nested key/value {kv_len.sum().item()}, max sequence length {kv_len.max().item()}"
)
out = new_mha_layer(query, key, value, is_causal=False)
Total sequence length in nested query 10617, max sequence length 165
Total sequence length in nested key/value 10176, max sequence length 137

As above, we can compare this against the vanilla compiled nn.MultiheadAttention.

torch.manual_seed(6)
query, _, _, q_len = gen_batch(N, E_q, E_k, E_v, device)
_, key, value, kv_len = gen_batch(N, E_q, E_k, E_v, device)
padded_query, padded_key, padded_value = (
    t.to_padded_tensor(0.0) for t in (query, key, value)
)

key_padding_mask = torch.where(padded_key == 0.0, -math.inf, 0)[:, :, 0]

# warmup compile
warmup_nested_result = new_mha_layer(query, key, value, is_causal=False)
warmup_vanilla_result = vanilla_mha_layer(
    padded_query,
    padded_key,
    padded_value,
    key_padding_mask=key_padding_mask,
    need_weights=False,
    is_causal=False,
)

nested_result, nested_time, nested_peak_memory = benchmark(
    new_mha_layer, query, key, value, is_causal=False
)
(padded_result, _), padded_time, padded_peak_memory = benchmark(
    vanilla_mha_layer,
    padded_query,
    padded_key,
    padded_value,
    key_padding_mask=key_padding_mask,
    need_weights=False,
    is_causal=False,
)
padded_nested_result = nested_result.to_padded_tensor(0.0)
for i, entry_length in enumerate(q_len):
    # padding-specific step: remove output projection bias from padded entries for fair comparison
    padded_result[i, entry_length:, :] = 0.0

print(
    "Max difference between vanilla and nested result",
    (padded_result - padded_nested_result).abs().max().item(),
)
print(f"Nested speedup: {(padded_time/nested_time):.2f}")
print(
    f"Nested peak memory reduction {((padded_peak_memory - nested_peak_memory)/1e9):.2f} GB"
)
Max difference between vanilla and nested result 0.0
Nested speedup: 4.98
Nested peak memory reduction 1.20 GB

Sample output on an A100:

Max difference between vanilla and nested result 0.0
Nested speedup: 4.01
Nested peak memory reduction 1.40 GB

Fully masked rows no longer cause NaNs

There has been a long-standing issue with nn.MultiheadAttention and scaled_dot_product_attention where, if a row was fully masked out, the output of the attention layer would be NaN. See this issue. This is because the softmax over an empty set is undefined.

Thanks to this PR, this is no longer the case. Instead, the output corresponding to fully masked rows in scaled_dot_product_attention will be 0. This also applies to the cases where nn.MHA does not employ the "fast-path".

Using a custom MHA layer with NJTs is strongly recommended over the existing "fast-path" in nn.MultiheadAttention, as NJT's ability to properly model raggedness makes it possible to correctly express empty sequences.
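
As a quick sanity check (a sketch assuming a recent PyTorch build that includes the fix referenced above), we can pass scaled_dot_product_attention a boolean attn_mask whose second query row masks out every key and inspect that row of the output, which should now be zeros rather than NaN:

q = torch.randn(1, 1, 2, 8, device=device)
k = torch.randn(1, 1, 2, 8, device=device)
v = torch.randn(1, 1, 2, 8, device=device)
# True = attend, False = masked out; the second query position attends to nothing.
mask = torch.tensor([[True, True], [False, False]], device=device)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out[0, 0, 1])  # expected to be all zeros instead of NaN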

FlexAttention + NJT

NJTs also compose with the FlexAttention module. This is a generalization of the MultiheadAttention layer that allows for arbitrary modifications to the attention scores. The example below takes the alibi_mod that implements ALiBi from attention gym and uses it with nested input tensors.

from torch.nn.attention.flex_attention import flex_attention


def generate_alibi_bias(H: int):
    """Returns an alibi bias score_mod given the number of heads H
    Args:
        H: number of heads
    Returns:
        alibi_bias: alibi bias score_mod
    """

    def alibi_mod(score, b, h, q_idx, kv_idx):
        scale = torch.exp2(-((h + 1) * 8.0 / H))
        bias = (q_idx - kv_idx) * scale
        return score + bias

    return alibi_mod


query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
n_heads, D = 8, E_q // 8
alibi_score_mod = generate_alibi_bias(n_heads)
query = query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
value = value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
out_flex2 = flex_attention(query, key, value, score_mod=alibi_score_mod)

In addition, one can also use the block_mask utility of FlexAttention with NJTs via the create_nested_block_mask function. This is useful for taking advantage of the sparsity of the mask to speed up the attention computation. In particular, the function creates a sparse block mask for a "stacked sequence" of all the variable-length sequences in the NJT combined into one, while properly masking out inter-sequence attention. In the following example, we show how to use this utility to create a causal block mask.

from torch.nn.attention.flex_attention import create_nested_block_mask


def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
block_mask = create_nested_block_mask(causal_mask, 1, 1, query, _compile=True)
query = query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
value = value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
out_flex = flex_attention(query, key, value, block_mask=block_mask)

Packed Projection

Packed projection is a technique that exploits the fact that when the inputs to a projection (matrix multiplication) are the same (as in self-attention), the projection weights and biases can be packed into single tensors. It is especially useful when the individual projections are memory bound rather than compute bound. There are two examples we will demonstrate here:

  • Input projection for MultiheadAttention

  • SwiGLU activation in the feed-forward network of a Transformer layer

Input projection for MultiheadAttention

When doing self-attention, the query, key, and value are the same tensor. Each of these tensors is projected with a Linear(E_q, E_total) layer. Instead, we can pack this into one layer, which is exactly what we do in the MultiheadAttention layer above.

Let us compare the performance of the packed projection against the usual method:

class InputProjection(nn.Module):
    def __init__(self, E_q, E_total, bias=False, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.q_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
        self.k_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
        self.v_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)

    def forward(self, x):
        return self.q_proj(x), self.k_proj(x), self.v_proj(x)


class PackedInputProjection(nn.Module):
    def __init__(self, E_q, E_total, bias=False, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.packed_proj = nn.Linear(E_q, E_total * 3, bias=bias, **factory_kwargs)

    def forward(self, query):
        return torch.chunk(self.packed_proj(query), 3, dim=-1)


B, D, dtype = 256, 8192, torch.bfloat16

torch.set_float32_matmul_precision("high")
in_proj = torch.compile(InputProjection(D, D, device="cuda", dtype=torch.bfloat16))
packed_in_proj = torch.compile(
    PackedInputProjection(D, D, device="cuda", dtype=torch.bfloat16)
)

q, _, _, sequence_lengths = gen_batch(B, D, D, D, device="cuda", dtype=torch.bfloat16)

# warmup
in_proj(q)
packed_in_proj(q)

# benchmark
(q_out, k_out, v_out), time, _ = benchmark(in_proj, q)
(q_out, k_out, v_out), time_packed, _ = benchmark(packed_in_proj, q)
# On my A100 prints 1.05x speedup
print(
    f"InputProjection: {time:5f} s, PackedInputProjection: {time_packed:5f} s, speedup: {time/time_packed:.2f}x"
)
InputProjection: 0.034046 s, PackedInputProjection: 0.032757 s, speedup: 1.04x

SwiGLU in the feed-forward network of a Transformer layer

The Swish-Gated Linear Unit (SwiGLU) is a non-linear activation function that is increasingly popular in the feed-forward networks of Transformer layers (for example, in Llama). A feed-forward network with SwiGLU activation is defined as follows:

class SwiGLUFFN(nn.Module):
    def __init__(
        self,
        dim,
        hidden_dim,
        multiple_of,
        ffn_dim_multiplier=None,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False, **factory_kwargs)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False, **factory_kwargs)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False, **factory_kwargs)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

An alternative way of implementing this, using packed projection, is:

class PackedSwiGLUFFN(nn.Module):
    def __init__(
        self,
        dim,
        hidden_dim,
        multiple_of,
        ffn_dim_multiplier=None,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w13 = nn.Linear(dim, 2 * hidden_dim, bias=False, **factory_kwargs)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False, **factory_kwargs)

    def forward(self, x):
        x1, x3 = torch.chunk(self.w13(x), 2, dim=-1)
        return self.w2(F.silu(x1) * x3)

We can compare the performance of the two implementations as follows. Depending on your hardware, results may differ. On an A100, I saw a 1.12x speedup for D=128.

D = 128

swigluffn = torch.compile(SwiGLUFFN(D, D * 4, 256, device="cuda", dtype=torch.bfloat16))
packed_swigluffn = torch.compile(
    PackedSwiGLUFFN(D, D * 4, 256, device="cuda", dtype=torch.bfloat16)
)

q, _, _, sentence_lengths = gen_batch(D, D, D, D, device="cuda", dtype=torch.bfloat16)

# warmup
swigluffn(q)
packed_swigluffn(q)

# benchmark
_, time, _ = benchmark(swigluffn, q)
_, time_packed, _ = benchmark(packed_swigluffn, q)
# On my A100 prints 1.08x speedup
print(
    f"SwiGLUFFN: {time} s, PackedSwiGLUFFN: {time_packed} s, speedup: {time/time_packed:.2f}x"
)
SwiGLUFFN: 0.0010205730000052426 s, PackedSwiGLUFFN: 0.0010395129997959884 s, speedup: 0.98x

Extended examples

We intend to update this tutorial to demonstrate more examples of how to use the various performant building blocks, such as KV-Caching, Grouped Query Attention, and more. Further, there are several good examples of how various performant building blocks can be used to implement different Transformer architectures.

Conclusion

In this tutorial, we have introduced the low-level building blocks PyTorch provides for writing Transformer layers and demonstrated examples of how to compose them. We hope this tutorial has shown the reader how easily flexible and performant Transformer layers can be implemented by users of PyTorch.

Total running time of the script: (1 minutes 6.445 seconds)
