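(Beta) Implementing High-Performance Transformers with Scaled Dot Product Attention (SDPA)

Author: Driss Guessous

Summary

In this tutorial, we want to highlight a new torch.nn.functional function that can be helpful for implementing transformer architectures. The function is named torch.nn.functional.scaled_dot_product_attention. For a detailed description of the function, see the PyTorch documentation. This function has already been incorporated into torch.nn.MultiheadAttention and torch.nn.TransformerEncoderLayer.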

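Overview

At a high level, this PyTorch function calculates the scaled dot product attention (SDPA) between query, key, and value according to the definition found in the paper Attention is all you need. While this function can be written in PyTorch using existing functions, a fused implementation can provide large performance benefits over a naive implementation.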
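To make the definition concrete, here is a minimal sketch of the naive, unfused computation written with existing PyTorch functions. The helper name naive_sdpa is hypothetical (it is not part of PyTorch), and the sketch ignores masking and dropout; it only illustrates that the fused function computes softmax(Q K^T / sqrt(d)) V.

```python
import math

import torch
import torch.nn.functional as F


def naive_sdpa(query, key, value):
    """Hypothetical reference helper: unfused scaled dot product attention."""
    # Scale by 1 / sqrt(head dimension), as in "Attention is all you need".
    scale = 1.0 / math.sqrt(query.size(-1))
    scores = query @ key.transpose(-2, -1) * scale
    weights = torch.softmax(scores, dim=-1)
    return weights @ value


torch.manual_seed(0)
q, k, v = (torch.randn(2, 3, 8) for _ in range(3))

reference = naive_sdpa(q, k, v)
fused = F.scaled_dot_product_attention(q, k, v)

# The two results should agree up to floating-point error.
print((reference - fused).abs().max().item())
```

The naive version materializes the full attention-score matrix and launches several separate operations, which is the overhead a fused kernel is meant to avoid.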

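Fused implementations

For CUDA tensor inputs, the function will dispatch into one of the following implementations:

- FlashAttention (SDPBackend.FLASH_ATTENTION)
- Memory-efficient attention (SDPBackend.EFFICIENT_ATTENTION)
- A math implementation defined in PyTorch C++ (SDPBackend.MATH)

Note

This tutorial requires PyTorch 2.0.0 or later.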

import torch
import torch.nn as nn
import torch.nn.functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"

# Example Usage:
query, key, value = torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device)
F.scaled_dot_product_attention(query, key, value)
tensor([[[-1.3321, -0.3489,  0.3015, -0.3912,  0.9867,  0.3137, -0.0691,
          -1.2593],
         [-1.0882,  0.2506,  0.6491,  0.1360,  0.5238, -0.2448, -0.0820,
          -0.6171],
         [-1.0012,  0.3990,  0.6441, -0.0277,  0.5325, -0.2564, -0.0607,
          -0.6404]],

        [[ 0.6091,  0.0708,  0.6188,  0.3252, -0.1598,  0.4197, -0.2335,
           0.0630],
         [ 0.5285,  0.3890, -0.2649,  0.3706, -0.3839,  0.1963, -0.6242,
           0.2312],
         [ 0.4048,  0.0762,  0.3777,  0.4689, -0.2978,  0.2754, -0.6429,
           0.1037]]], device='cuda:0')

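Explicit dispatcher control

While the function will implicitly dispatch to one of the three implementations, the user can also explicitly control the dispatch via the use of a context manager. This context manager allows users to explicitly disable certain implementations. If a user wants to ensure the function is indeed using the fastest implementation for their specific inputs, the context manager can be used to sweep through the implementations and measure performance.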

# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6

# Let's define the hyper-parameters of our input
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32

dtype = torch.float16

query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)

print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")

# Let's explore the speed of each of the 3 implementations
from torch.nn.attention import SDPBackend, sdpa_kernel


with sdpa_kernel(SDPBackend.MATH):
    math_time=benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
    print(f"The math implementation runs in {math_time:.3f} microseconds")

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    try:
        flash_time=benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
        print(f"The flash attention implementation runs in {flash_time:.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")

with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    try:
        efficient_time=benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
        print(f"The memory efficient implementation runs in {efficient_time:.3f} microseconds")
    except RuntimeError:
        print("EfficientAttention is not supported. See warnings for reasons.")
The default implementation runs in 2329.345 microseconds
The math implementation runs in 87495.058 microseconds
The flash attention implementation runs in 2330.364 microseconds
The memory efficient implementation runs in 4432.501 microseconds

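Hardware dependence

Depending on what machine you ran the above cell on and what hardware is available, your results might be different.

- If you don't have a GPU and are running on CPU, then with FP32 the context manager will have no effect and all three runs should return similar timings.
- Depending on what compute capability your graphics card supports, flash attention or memory efficient attention might have failed.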

Causal self-attention

Below is an example implementation of a multi-headed causal self-attention block inspired by Andrej Karpathy's NanoGPT repository.

class CausalSelfAttention(nn.Module):

    def __init__(self, num_heads: int, embed_dimension: int, bias: bool=False, is_causal: bool=False, dropout:float=0.0):
        super().__init__()
        assert embed_dimension % num_heads == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(embed_dimension, 3 * embed_dimension, bias=bias)
        # output projection
        self.c_proj = nn.Linear(embed_dimension, embed_dimension, bias=bias)
        # regularization
        self.dropout = dropout
        self.resid_dropout = nn.Dropout(dropout)
        self.num_heads = num_heads
        self.embed_dimension = embed_dimension
        # Perform causal masking
        self.is_causal = is_causal

    def forward(self, x):
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        query_projected = self.c_attn(x)

        batch_size = query_projected.size(0)
        embed_dim = query_projected.size(2)
        head_dim = embed_dim // (self.num_heads * 3)

        query, key, value = query_projected.chunk(3, -1)
        query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)

        if self.training:
            dropout = self.dropout
            is_causal = self.is_causal
        else:
            dropout = 0.0
            is_causal = False

        y = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=dropout, is_causal=is_causal)
        y = y.transpose(1, 2).view(batch_size, -1, self.num_heads * head_dim)

        y = self.resid_dropout(self.c_proj(y))
        return y


num_heads = 8
heads_per_dim = 64
embed_dimension = num_heads * heads_per_dim
dtype = torch.float16
model = CausalSelfAttention(num_heads=num_heads, embed_dimension=embed_dimension, bias=False, is_causal=True, dropout=0.1).to("cuda").to(dtype).eval()
print(model)
CausalSelfAttention(
  (c_attn): Linear(in_features=512, out_features=1536, bias=False)
  (c_proj): Linear(in_features=512, out_features=512, bias=False)
  (resid_dropout): Dropout(p=0.1, inplace=False)
)
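To illustrate what the is_causal flag passed to scaled_dot_product_attention means, the short sketch below compares it with an explicit lower-triangular boolean mask on square inputs. The tensor sizes are arbitrary and chosen only for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len = 4
# Shape: (batch, num_heads, seq_len, head_dim)
q, k, v = (torch.randn(1, 2, seq_len, 8) for _ in range(3))

# True marks positions that may be attended to: each token sees itself and earlier tokens.
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)

print(torch.allclose(out_flag, out_mask, atol=1e-5))
```

Passing is_causal=True instead of an explicit mask avoids building the mask tensor and lets the fused kernels apply causal masking internally.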

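NestedTensor and dense tensor support

SDPA supports both NestedTensor and dense tensor inputs. NestedTensors handle the case where the input is a batch of variable-length sequences without needing to pad each sequence to the maximum length in the batch. For more information about NestedTensors, see torch.nested and the NestedTensors Tutorial.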
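As a small illustration of the padding that NestedTensors avoid, the sketch below builds a nested tensor from three sequences of different lengths and compares the number of tokens it stores with the number a padded dense batch would hold. The sequence lengths and embedding size are arbitrary values chosen for the example.

```python
import torch

torch.manual_seed(0)
embed_dimension = 4
seq_lens = [5, 2, 3]

# One (seq_len, embed_dimension) tensor per sequence; no padding is stored.
nested = torch.nested.nested_tensor(
    [torch.randn(seq_len, embed_dimension) for seq_len in seq_lens]
)

# Converting to a dense batch pads every sequence to the longest one.
padded = torch.nested.to_padded_tensor(nested, 0.0)

stored_tokens = sum(seq_lens)
padded_tokens = padded.size(0) * padded.size(1)
print(nested.is_nested, tuple(padded.shape))
print(f"{stored_tokens} real tokens vs {padded_tokens} tokens after padding")
```

The gap between the two counts is work that a dense batch spends on padding, which is why the nested input can run faster in the benchmark that follows.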

import random
def generate_rand_batch(
    batch_size,
    max_sequence_len,
    embed_dimension,
    pad_percentage=None,
    dtype=torch.float16,
    device="cuda",
):
    if not pad_percentage:
        return (
            torch.randn(
                batch_size,
                max_sequence_len,
                embed_dimension,
                dtype=dtype,
                device=device,
            ),
            None,
        )
    # Random sequence lengths
    seq_len_list = [
        int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01)))
        for _ in range(batch_size)
    ]
    # Make random entry in the batch have max sequence length
    seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len
    return (
        torch.nested.nested_tensor(
            [
                torch.randn(seq_len, embed_dimension,
                            dtype=dtype, device=device)
                for seq_len in seq_len_list
            ]
        ),
        seq_len_list,
    )

random_nt, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=0.5, dtype=dtype, device=device)
random_dense, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=None, dtype=dtype, device=device)

# Currently the fused implementations don't support ``NestedTensor`` for training
model.eval()

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    try:
        print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds")
        print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nested/__init__.py:226: UserWarning:

The PyTorch API of nested tensors is in prototype stage and will change in the near future. (Triggered internally at ../aten/src/ATen/NestedTensorImpl.cpp:178.)

Random NT runs in 563.879 microseconds
Random Dense runs in 947.754 microseconds

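Using SDPA with torch.compile

With the release of PyTorch 2.0, a new feature called torch.compile() has been introduced, which can provide significant performance improvements over eager mode. Scaled dot product attention is fully composable with torch.compile(). To demonstrate this, let's compile the CausalSelfAttention module using torch.compile() and observe the resulting performance improvements.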

batch_size = 32
max_sequence_len = 256
x = torch.rand(batch_size, max_sequence_len,
               embed_dimension, device=device, dtype=dtype)
print(
    f"The non compiled module runs in  {benchmark_torch_function_in_microseconds(model, x):.3f} microseconds")


compiled_model = torch.compile(model)
# Let's compile it
compiled_model(x)
print(
    f"The compiled module runs in  {benchmark_torch_function_in_microseconds(compiled_model, x):.3f} microseconds")
The non compiled module runs in  415.594 microseconds
The compiled module runs in  515.778 microseconds

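The exact execution time depends on the machine; the author's results were that the non-compiled module runs in 166.616 microseconds and the compiled module runs in 166.726 microseconds. That is not what we were expecting. Let's dig a little deeper. PyTorch comes with an excellent built-in profiler that you can use to inspect the performance characteristics of your code.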

from torch.profiler import profile, record_function, ProfilerActivity
activities = [ProfilerActivity.CPU]
if device == 'cuda':
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, record_shapes=False) as prof:
    with record_function(" Non-Compilied Causal Attention"):
        for _ in range(25):
            model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))


with profile(activities=activities, record_shapes=False) as prof:
    with record_function("Compiled Causal Attention"):
        for _ in range(25):
            compiled_model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

# For even more insights, you can export the trace and use ``chrome://tracing`` to view the results:
#
#    prof.export_chrome_trace("compiled_causal_attention_trace.json")
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                         Non-Compilied Causal Attention         0.00%       0.000us         0.00%       0.000us       0.000us      10.507ms       101.40%      10.507ms      10.507ms             1
                         Non-Compilied Causal Attention        20.35%       2.244ms        77.40%       8.535ms       8.535ms       0.000us         0.00%      10.362ms      10.362ms             1
                                           aten::linear         1.14%     126.191us        28.92%       3.189ms      63.774us       0.000us         0.00%       7.766ms     155.310us            50
                                           aten::matmul         2.45%     270.234us        24.84%       2.739ms      54.780us       0.000us         0.00%       7.766ms     155.310us            50
                                               aten::mm        15.63%       1.724ms        19.92%       2.197ms      43.937us       7.766ms        74.94%       7.766ms     155.310us            50
         ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_tn         0.00%       0.000us         0.00%       0.000us       0.000us       5.558ms        53.63%       5.558ms     222.303us            25
                     aten::scaled_dot_product_attention         2.02%     222.243us        18.60%       2.051ms      82.045us       0.000us         0.00%       2.596ms     103.858us            25
              aten::_scaled_dot_product_flash_attention         3.07%     338.835us        16.58%       1.829ms      73.155us       0.000us         0.00%       2.596ms     103.858us            25
                         aten::_flash_attention_forward         3.71%     409.218us        11.26%       1.242ms      49.671us       2.596ms        25.06%       2.596ms     103.858us            25
void pytorch_flash::flash_fwd_kernel<pytorch_flash::...         0.00%       0.000us         0.00%       0.000us       0.000us       2.596ms        25.06%       2.596ms     103.858us            25
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 11.028ms
Self CUDA time total: 10.362ms

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                              Compiled Causal Attention         0.00%       0.000us         0.00%       0.000us       0.000us      10.487ms       100.96%      10.487ms      10.487ms             1
                              Compiled Causal Attention         8.24%     922.698us        73.56%       8.236ms       8.236ms       0.000us         0.00%      10.387ms      10.387ms             1
                                  Torch-Compiled Region         8.02%     898.065us        63.20%       7.076ms     283.030us       0.000us         0.00%      10.387ms     415.474us            25
                                       CompiledFunction        26.25%       2.938ms        55.18%       6.178ms     247.107us       0.000us         0.00%      10.387ms     415.474us            25
                                               aten::mm         9.84%       1.102ms        14.31%       1.602ms      32.038us       7.782ms        74.92%       7.782ms     155.647us            50
         ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_tn         0.00%       0.000us         0.00%       0.000us       0.000us       5.573ms        53.65%       5.573ms     222.907us            25
              aten::_scaled_dot_product_flash_attention         2.28%     255.782us        14.62%       1.637ms      65.491us       0.000us         0.00%       2.605ms     104.180us            25
                         aten::_flash_attention_forward         3.58%     401.324us        10.54%       1.180ms      47.216us       2.605ms        25.08%       2.605ms     104.180us            25
void pytorch_flash::flash_fwd_kernel<pytorch_flash::...         0.00%       0.000us         0.00%       0.000us       0.000us       2.605ms        25.08%       2.605ms     104.180us            25
ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_stages_3...         0.00%       0.000us         0.00%       0.000us       0.000us       2.210ms        21.27%       2.210ms      88.388us            25
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 11.196ms
Self CUDA time total: 10.387ms

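The previous code snippet generates a report of the top 10 PyTorch functions that consumed the most GPU execution time, for both the compiled and non-compiled module. The analysis reveals that the majority of time spent on the GPU is concentrated on the same set of functions for both modules. The reason is that torch.compile is very good at removing the framework overhead associated with PyTorch. If your model is launching large, efficient CUDA kernels, which in this case CausalSelfAttention is, then the overhead of PyTorch can be hidden, leaving little for compilation to remove.

In reality, your module does not normally consist of a single CausalSelfAttention block. When experimenting with Andrej Karpathy's NanoGPT repository, compiling the module took the time per training step from 6090.49ms to 3273.17ms! This was done on commit ae3a8d5 of NanoGPT, training on the Shakespeare dataset.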

Using SDPA with attn_bias subclasses

# As of PyTorch 2.3, we have added a new submodule that contains tensor subclasses.
# Designed to be used with ``torch.nn.functional.scaled_dot_product_attention``.
# The module is named ``torch.nn.attention.bias`` and contains the following two
# utilities for generating causal attention variants:
#
# - ``torch.nn.attention.bias.causal_upper_left``
# - ``torch.nn.attention.bias.causal_lower_right``
#
# Note:
#    The current argument ``is_causal`` in ``torch.nn.functional.scaled_dot_product_attention``
#    is the same as using ``torch.nn.attention.bias.causal_upper_left``.
#

from torch.nn.attention.bias import causal_lower_right, causal_upper_left

batch_size = 32
sequence_length_q = 2
sequence_length_kv = 10
num_heads = 16
embed_dimension = 32

dtype = torch.float16

query = torch.rand(batch_size, num_heads, sequence_length_q, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)

upper_left_bias = causal_upper_left(sequence_length_q, sequence_length_kv)
lower_right_bias = causal_lower_right(sequence_length_q, sequence_length_kv)

print(type(upper_left_bias))
print(type(lower_right_bias))

assert type(upper_left_bias) == type(lower_right_bias)
assert issubclass(type(upper_left_bias), torch.Tensor)

# As you can see from the previous output, both are the same type, ``torch.nn.attention.bias.CausalBias``,
# which subclasses ``torch.Tensor``

# Let's see what these tensors look like
print(upper_left_bias)
print(lower_right_bias)

# Upper Left Bias aligns the causal attention mask to the upper left corner of the attention scores matrix.
# This only has an impact when the attention scores matrix is not square, which is common for decoding use cases.
# Another way of thinking about this concept is that when you use upper left bias,
# the 0th token in the query is aligned to the 0th token in the key.
# Assuming the attention score matrix is two dimensional, ``attn_score[0][0]`` is the attention score
# between the 0th token in the query and the 0th token in the key.
# For lower right bias, the sequence of q is aligned so that the last token in q is aligned to the last token in k
# (for example, ``attn_score[-1][-1]`` is True, since the last token in q is at the same position as the last token in k)
# even if the sequence lengths of q and k are different.

# These objects are intended to be used with sdpa
out_upper_left = F.scaled_dot_product_attention(query, key, value, upper_left_bias)
out_lower_right = F.scaled_dot_product_attention(query, key, value, lower_right_bias)
out_is_causal = F.scaled_dot_product_attention(query, key, value, is_causal=True)

assert torch.allclose(out_upper_left, out_is_causal)
assert not torch.allclose(out_upper_left, out_lower_right)

# These attention biases should also be compatible with torch.compile
compiled_sdpa = torch.compile(F.scaled_dot_product_attention, fullgraph=True)
out_upper_left = compiled_sdpa(query, key, value, upper_left_bias)
<class 'torch.nn.attention.bias.CausalBias'>
<class 'torch.nn.attention.bias.CausalBias'>
tensor([[ True, False, False, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False, False, False]])
tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]])
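The two alignments printed above can also be reproduced with plain boolean masks, which may help clarify what each variant keeps. This is only a sketch using torch.tril, not how the CausalBias subclass is implemented: the lower-right variant shifts the diagonal by the difference between the key/value and query sequence lengths.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len_q, seq_len_kv = 2, 10
ones = torch.ones(seq_len_q, seq_len_kv, dtype=torch.bool)

# Upper left: query token i may attend to key tokens 0..i.
upper_left_mask = torch.tril(ones)
# Lower right: the last query token is aligned with the last key token.
lower_right_mask = torch.tril(ones, diagonal=seq_len_kv - seq_len_q)

print(upper_left_mask)
print(lower_right_mask)

# The masks can be passed to SDPA as ordinary boolean attention masks.
q = torch.rand(1, 1, seq_len_q, 8)
k = torch.rand(1, 1, seq_len_kv, 8)
v = torch.rand(1, 1, seq_len_kv, 8)
out_upper_left = F.scaled_dot_product_attention(q, k, v, attn_mask=upper_left_mask)
out_lower_right = F.scaled_dot_product_attention(q, k, v, attn_mask=lower_right_mask)
```

In a decoding setting, where a few new query tokens attend to a longer cached key/value sequence, the lower-right alignment lets the new tokens see the whole cache, while the upper-left alignment would hide most of it.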

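Conclusion

In this tutorial, we have demonstrated the basic usage of torch.nn.functional.scaled_dot_product_attention. We have shown how the sdpa_kernel context manager can be used to assert that a certain implementation is used on GPU. We also built a simple CausalSelfAttention module that works with NestedTensor and can be compiled with torch.compile. Along the way, we have shown how to use the profiling tools to explore the performance characteristics of a user-defined module.

Total running time of the script: (0 minutes 7.713 seconds)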
