(Beta) Implementing High-Performance Transformers with Scaled Dot Product Attention (SDPA)
Created On: Mar 15, 2023 | Last Updated: Oct 09, 2024 | Last Verified: Nov 05, 2024
Author: Driss Guessous
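Summary
In this tutorial, we introduce a new torch.nn.functional function that helps when implementing transformer architectures. The function is named torch.nn.functional.scaled_dot_product_attention. For a detailed description of the function, see the PyTorch documentation. This function has already been incorporated into torch.nn.MultiheadAttention and torch.nn.TransformerEncoderLayer.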
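Overview
At a high level, this PyTorch function calculates the scaled dot product attention (SDPA) between query, key, and value according to the definition found in the paper Attention is all you need. While this function can be written in PyTorch using existing functions, a fused implementation can provide large performance benefits over a naive implementation.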
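As a point of reference for what the fused kernels compute, here is a minimal sketch of the naive formulation, softmax(QK^T / sqrt(d)) V, written with ordinary PyTorch operations. The helper name `naive_sdpa` is illustrative and not part of PyTorch; the sketch ignores masking and dropout.

```python
import math

import torch
import torch.nn.functional as F


def naive_sdpa(query, key, value):
    """Unfused scaled dot product attention (no mask, no dropout)."""
    # Scale by 1 / sqrt(head_dim), as in "Attention is all you need".
    scale = 1.0 / math.sqrt(query.size(-1))
    # (..., L, E) @ (..., E, S) -> (..., L, S) attention scores.
    scores = (query @ key.transpose(-2, -1)) * scale
    # Normalize the scores over the key dimension.
    weights = torch.softmax(scores, dim=-1)
    # Weighted sum of the values: (..., L, S) @ (..., S, E) -> (..., L, E).
    return weights @ value


torch.manual_seed(0)
q, k, v = (torch.randn(2, 3, 8) for _ in range(3))
reference = F.scaled_dot_product_attention(q, k, v)
naive = naive_sdpa(q, k, v)
print(torch.allclose(naive, reference, atol=1e-5))
```

The naive version materializes the full score matrix, which is the memory traffic the fused implementations are designed to avoid.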
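Fused implementations
For CUDA tensor inputs, the function will dispatch into one of the following implementations:
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
Memory-Efficient Attention
A PyTorch implementation defined in C++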
Note
This tutorial requires PyTorch 2.0.0 or later.
import torch
import torch.nn as nn
import torch.nn.functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"
# Example Usage:
query, key, value = torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device)
F.scaled_dot_product_attention(query, key, value)
tensor([[[-1.3321, -0.3489, 0.3015, -0.3912, 0.9867, 0.3137, -0.0691,
-1.2593],
[-1.0882, 0.2506, 0.6491, 0.1360, 0.5238, -0.2448, -0.0820,
-0.6171],
[-1.0012, 0.3990, 0.6441, -0.0277, 0.5325, -0.2564, -0.0607,
-0.6404]],
[[ 0.6091, 0.0708, 0.6188, 0.3252, -0.1598, 0.4197, -0.2335,
0.0630],
[ 0.5285, 0.3890, -0.2649, 0.3706, -0.3839, 0.1963, -0.6242,
0.2312],
[ 0.4048, 0.0762, 0.3777, 0.4689, -0.2978, 0.2754, -0.6429,
0.1037]]], device='cuda:0')
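Explicit Dispatcher Control
While the function will implicitly dispatch to one of the three implementations, the user can also explicitly control the dispatch with a context manager. This context manager lets users explicitly disable certain implementations. To make sure the function is really using the fastest implementation for their specific inputs, users can use the context manager to sweep through the implementations and measure the performance of each.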
# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    # ``mean`` is reported in seconds; convert to microseconds
    return t0.blocked_autorange().mean * 1e6
# Let's define the hyper-parameters of our input
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32
dtype = torch.float16
query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
# Let's explore the speed of each of the 3 implementations
from torch.nn.attention import SDPBackend, sdpa_kernel

with sdpa_kernel(SDPBackend.MATH):
    math_time = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
    print(f"The math implementation runs in {math_time:.3f} microseconds")

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    try:
        flash_time = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
        print(f"The flash attention implementation runs in {flash_time:.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")

with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    try:
        efficient_time = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
        print(f"The memory efficient implementation runs in {efficient_time:.3f} microseconds")
    except RuntimeError:
        print("EfficientAttention is not supported. See warnings for reasons.")
The default implementation runs in 2277.732 microseconds
The math implementation runs in 87525.708 microseconds
The flash attention implementation runs in 2280.833 microseconds
The memory efficient implementation runs in 4409.952 microseconds
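Hardware dependence
Depending on the machine you ran the above code on and the hardware available, your results might be different.
- If you don't have a GPU and are running on CPU, then with FP32 the context manager has no effect and all three runs should return similar timings.
- Depending on the compute capability your graphics card supports, FlashAttention or memory-efficient attention might have failed.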
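Before interpreting the timings, it can help to print what the current machine offers. The sketch below collects the CUDA compute capability (when a GPU is present) and the global SDPA backend switches from `torch.backends.cuda`. Note the assumption: these switches report whether a backend is enabled for dispatch, not whether the hardware can actually run it. The helper name `describe_sdpa_support` is illustrative.

```python
import torch


def describe_sdpa_support():
    """Collect basic facts that influence which SDPA backend can be used."""
    info = {
        "cuda_available": torch.cuda.is_available(),
        # Global switches: True means the backend may be selected by the dispatcher.
        "flash_sdp_enabled": torch.backends.cuda.flash_sdp_enabled(),
        "mem_efficient_sdp_enabled": torch.backends.cuda.mem_efficient_sdp_enabled(),
        "math_sdp_enabled": torch.backends.cuda.math_sdp_enabled(),
    }
    if info["cuda_available"]:
        # (major, minor) compute capability of the current device.
        info["compute_capability"] = torch.cuda.get_device_capability()
    return info


support = describe_sdpa_support()
for name, value in support.items():
    print(f"{name}: {value}")
```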
Causal Self Attention
Below is an example implementation of a multi-headed causal self attention block inspired by Andrej Karpathy's NanoGPT repository.
class CausalSelfAttention(nn.Module):

    def __init__(self, num_heads: int, embed_dimension: int, bias: bool=False, is_causal: bool=False, dropout:float=0.0):
        super().__init__()
        assert embed_dimension % num_heads == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(embed_dimension, 3 * embed_dimension, bias=bias)
        # output projection
        self.c_proj = nn.Linear(embed_dimension, embed_dimension, bias=bias)
        # regularization
        self.dropout = dropout
        self.resid_dropout = nn.Dropout(dropout)
        self.num_heads = num_heads
        self.embed_dimension = embed_dimension
        # Perform causal masking
        self.is_causal = is_causal

    def forward(self, x):
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        query_projected = self.c_attn(x)

        batch_size = query_projected.size(0)
        embed_dim = query_projected.size(2)
        # the projection holds q, k and v side by side, hence the factor of 3
        head_dim = embed_dim // (self.num_heads * 3)

        query, key, value = query_projected.chunk(3, -1)
        query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)

        # NOTE: as written, dropout and causal masking are applied only in training mode
        if self.training:
            dropout = self.dropout
            is_causal = self.is_causal
        else:
            dropout = 0.0
            is_causal = False

        y = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=dropout, is_causal=is_causal)
        y = y.transpose(1, 2).view(batch_size, -1, self.num_heads * head_dim)

        y = self.resid_dropout(self.c_proj(y))
        return y

num_heads = 8
heads_per_dim = 64
embed_dimension = num_heads * heads_per_dim
dtype = torch.float16
model = CausalSelfAttention(num_heads=num_heads, embed_dimension=embed_dimension, bias=False, is_causal=True, dropout=0.1).to("cuda").to(dtype).eval()
print(model)
CausalSelfAttention(
(c_attn): Linear(in_features=512, out_features=1536, bias=False)
(c_proj): Linear(in_features=512, out_features=512, bias=False)
(resid_dropout): Dropout(p=0.1, inplace=False)
)
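To make the effect of the `is_causal` flag used by the module above concrete, the following sketch compares it with an explicit lower-triangular boolean mask on small CPU tensors. The shapes and variable names are arbitrary choices for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq_len, head_dim = 2, 4, 6, 8
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# True marks positions that may be attended to: each query sees itself and earlier keys.
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)

# Both formulations should agree up to floating point error.
print(torch.allclose(out_flag, out_mask, atol=1e-5))
# The first query position can only attend to the first key, so its output is the first value.
print(torch.allclose(out_flag[:, :, 0], v[:, :, 0], atol=1e-5))
```

Passing the flag rather than a materialized mask leaves the choice of how to apply the masking to the selected kernel.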
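NestedTensor and Dense tensor support
SDPA supports both NestedTensor and Dense tensor inputs. NestedTensors handle the case where the input is a batch of variable-length sequences without needing to pad each sequence to the maximum length in the batch. For more information about NestedTensors, see torch.nested and the NestedTensors tutorial.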
import random
def generate_rand_batch(
    batch_size,
    max_sequence_len,
    embed_dimension,
    pad_percentage=None,
    dtype=torch.float16,
    device="cuda",
):
    # Without padding, return a dense batch and no sequence lengths
    if not pad_percentage:
        return (
            torch.randn(
                batch_size,
                max_sequence_len,
                embed_dimension,
                dtype=dtype,
                device=device,
            ),
            None,
        )
    # Random sequence lengths
    seq_len_list = [
        int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01)))
        for _ in range(batch_size)
    ]
    # Make random entry in the batch have max sequence length
    seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len
    return (
        torch.nested.nested_tensor(
            [
                torch.randn(seq_len, embed_dimension,
                            dtype=dtype, device=device)
                for seq_len in seq_len_list
            ]
        ),
        seq_len_list,
    )

random_nt, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=0.5, dtype=dtype, device=device)
random_dense, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=None, dtype=dtype, device=device)

# Currently the fused implementations don't support ``NestedTensor`` for training
model.eval()

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    try:
        print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds")
        print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")
/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nested/__init__.py:250: UserWarning:
The PyTorch API of nested tensors is in prototype stage and will change in the near future. We recommend specifying layout=torch.jagged when constructing a nested tensor, as this layout receives active development, has better operator coverage, and works with torch.compile. (Triggered internally at /pytorch/aten/src/ATen/NestedTensorImpl.cpp:178.)
Random NT runs in 568.572 microseconds
Random Dense runs in 950.171 microseconds
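To illustrate the padding that a NestedTensor avoids, here is a small CPU sketch that builds a nested tensor from two sequences of different lengths and then converts it to the padded dense layout. The sequence lengths and the padding value of 0.0 are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)
embed_dim = 4
# Two sequences of different lengths, each of shape (seq_len, embed_dim).
sequences = [torch.randn(3, embed_dim), torch.randn(5, embed_dim)]

# The nested tensor stores only the real tokens: 3 + 5 rows.
nt = torch.nested.nested_tensor(sequences)
print(nt.is_nested)

# The dense equivalent pads every sequence to the longest one: 2 x 5 rows.
padded = torch.nested.to_padded_tensor(nt, 0.0)
print(padded.shape)

real_tokens = sum(seq.size(0) for seq in sequences)
padded_tokens = padded.size(0) * padded.size(1)
print(f"{real_tokens} real tokens vs. {padded_tokens} token slots after padding")
```

With a padding percentage of 0.5, as in the benchmark above, roughly half of the dense token slots would be padding, which is consistent with the nested input running faster there.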
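Using SDPA with torch.compile
With the release of PyTorch 2.0, a new feature called torch.compile() was introduced, which can provide significant performance improvements over eager mode. Scaled dot product attention is fully composable with torch.compile(). To demonstrate this, let's compile the CausalSelfAttention module using torch.compile() and observe the resulting performance improvements.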
batch_size = 32
max_sequence_len = 256
x = torch.rand(batch_size, max_sequence_len,
embed_dimension, device=device, dtype=dtype)
print(
f"The non compiled module runs in {benchmark_torch_function_in_microseconds(model, x):.3f} microseconds")
compiled_model = torch.compile(model)
# Let's compile it
compiled_model(x)
print(
f"The compiled module runs in {benchmark_torch_function_in_microseconds(compiled_model, x):.3f} microseconds")
The non compiled module runs in 425.260 microseconds
The compiled module runs in 527.525 microseconds
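The exact execution time depends on the machine. The author reports 166.616 microseconds for the non-compiled module and 166.726 microseconds for the compiled module, while the run captured above shows 425.260 and 527.525 microseconds. In both cases compiling brings no speedup, which is not what we were expecting. Let's dig a little deeper. PyTorch comes with an amazing built-in profiler that you can use to inspect the performance characteristics of your code.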
from torch.profiler import profile, record_function, ProfilerActivity
activities = [ProfilerActivity.CPU]
if device == 'cuda':
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, record_shapes=False) as prof:
    with record_function(" Non-Compilied Causal Attention"):
        for _ in range(25):
            model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

with profile(activities=activities, record_shapes=False) as prof:
    with record_function("Compiled Causal Attention"):
        for _ in range(25):
            compiled_model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

# For even more insights, you can export the trace and use ``chrome://tracing`` to view the results:
# prof.export_chrome_trace("compiled_causal_attention_trace.json")
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Non-Compilied Causal Attention 19.14% 2.155ms 71.05% 8.000ms 8.000ms 0.000us 0.00% 10.823ms 10.823ms 1
Non-Compilied Causal Attention 0.00% 0.000us 0.00% 0.000us 0.000us 10.719ms 101.12% 10.719ms 10.719ms 1
aten::linear 1.03% 115.972us 25.67% 2.891ms 57.811us 0.000us 0.00% 7.999ms 159.982us 50
aten::matmul 2.11% 237.712us 21.85% 2.460ms 49.206us 0.000us 0.00% 7.999ms 159.982us 50
aten::mm 10.65% 1.200ms 17.54% 1.975ms 39.497us 7.777ms 73.36% 7.999ms 159.982us 50
ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_tn 0.00% 0.000us 0.00% 0.000us 0.000us 5.563ms 52.48% 5.563ms 222.518us 25
aten::scaled_dot_product_attention 1.85% 207.993us 17.46% 1.966ms 78.635us 0.000us 0.00% 2.824ms 112.950us 25
aten::_scaled_dot_product_flash_attention 2.83% 318.622us 15.61% 1.758ms 70.315us 0.000us 0.00% 2.824ms 112.950us 25
aten::_flash_attention_forward 2.75% 309.462us 11.14% 1.254ms 50.158us 2.824ms 26.64% 2.824ms 112.950us 25
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne... 0.00% 0.000us 0.00% 0.000us 0.000us 2.824ms 26.64% 2.824ms 112.950us 25
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 11.260ms
Self CUDA time total: 10.601ms
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Compiled Causal Attention 0.00% 0.000us 0.00% 0.000us 0.000us 10.757ms 101.41% 10.757ms 10.757ms 1
Compiled Causal Attention 8.28% 949.181us 76.54% 8.779ms 8.779ms 0.000us 0.00% 10.607ms 10.607ms 1
Torch-Compiled Region: 2/0 7.96% 912.511us 63.42% 7.275ms 290.983us 0.000us 0.00% 10.607ms 424.295us 25
CompiledFunction 24.78% 2.843ms 55.47% 6.362ms 254.482us 0.000us 0.00% 10.607ms 424.295us 25
aten::mm 8.35% 958.240us 13.33% 1.528ms 30.569us 7.775ms 73.30% 7.775ms 155.494us 50
ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_tn 0.00% 0.000us 0.00% 0.000us 0.000us 5.561ms 52.43% 5.561ms 222.459us 25
aten::_scaled_dot_product_flash_attention 2.16% 248.000us 14.85% 1.703ms 68.132us 0.000us 0.00% 2.833ms 113.307us 25
aten::_flash_attention_forward 2.85% 327.135us 10.80% 1.239ms 49.567us 2.833ms 26.70% 2.833ms 113.307us 25
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne... 0.00% 0.000us 0.00% 0.000us 0.000us 2.833ms 26.70% 2.833ms 113.307us 25
ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_stages_3... 0.00% 0.000us 0.00% 0.000us 0.000us 2.213ms 20.86% 2.213ms 88.529us 25
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 11.470ms
Self CUDA time total: 10.607ms
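The previous code snippet generates a report of the top 10 PyTorch functions that consumed the most GPU execution time, for both the compiled and non-compiled module. The analysis reveals that the majority of time spent on the GPU is concentrated on the same set of functions for both modules. The reason is that torch.compile is very good at removing the framework overhead associated with PyTorch. If your model is launching large, efficient CUDA kernels, which in this case CausalSelfAttention is, then the overhead of PyTorch can be hidden, which leaves torch.compile little to remove.
In reality, your module does not normally consist of a singular CausalSelfAttention block. When experimenting with Andrej Karpathy's NanoGPT repository, compiling the module took the time per train step from 6090.49ms to 3273.17ms! This was done on commit ae3a8d5 of NanoGPT training on the Shakespeare dataset.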
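The same profiling workflow can be tried without a GPU. The sketch below is a CPU-only variant that profiles a handful of SDPA calls and sorts by CPU time instead of CUDA time; the label `sdpa_cpu_sketch` and the tensor shapes are arbitrary choices for illustration.

```python
import torch
import torch.nn.functional as F
from torch.profiler import profile, record_function, ProfilerActivity

torch.manual_seed(0)
q, k, v = (torch.randn(4, 8, 64, 32) for _ in range(3))

# Record only CPU activity so the sketch runs on any machine.
with profile(activities=[ProfilerActivity.CPU], record_shapes=False) as prof:
    with record_function("sdpa_cpu_sketch"):
        for _ in range(5):
            F.scaled_dot_product_attention(q, k, v)

# Aggregate events by name and show the most expensive ones.
averages = prof.key_averages()
table = averages.table(sort_by="cpu_time_total", row_limit=5)
print(table)

event_names = {event.key for event in averages}
```

The user-defined label appears as its own row, which makes it easy to separate the region of interest from the rest of the trace.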
Using SDPA with attn_bias subclasses
# As of PyTorch 2.3, we have added a new submodule that contains tensor subclasses
# designed to be used with ``torch.nn.functional.scaled_dot_product_attention``.
# The module is named ``torch.nn.attention.bias`` and contains the following two
# utilities for generating causal attention variants:
#
# - ``torch.nn.attention.bias.causal_upper_left``
# - ``torch.nn.attention.bias.causal_lower_right``
#
# Note: the current argument ``is_causal`` in ``torch.nn.functional.scaled_dot_product_attention``
# is the same as using ``torch.nn.attention.bias.causal_upper_left``.
#
from torch.nn.attention.bias import causal_lower_right, causal_upper_left
batch_size = 32
sequence_length_q = 2
sequence_length_kv = 10
num_heads = 16
embed_dimension = 32
dtype = torch.float16
query = torch.rand(batch_size, num_heads, sequence_length_q, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)
upper_left_bias = causal_upper_left(sequence_length_q, sequence_length_kv)
lower_right_bias = causal_lower_right(sequence_length_q, sequence_length_kv)
print(type(upper_left_bias))
print(type(lower_right_bias))
assert type(upper_left_bias) == type(lower_right_bias)
assert issubclass(type(upper_left_bias), torch.Tensor)
# As you can see from the previous output, both are of the same type, ``torch.nn.attention.bias.CausalBias``,
# which subclasses ``torch.Tensor``.
# Let's see what these tensors look like
print(upper_left_bias)
print(lower_right_bias)
# Upper Left Bias aligns the causal attention mask to the upper left corner of the attention scores matrix.
# This only has an impact when the attention scores matrix is not square, which is common for decoding use cases.
# Another way of thinking about this concept is that when you use upper left bias,
# the 0th token in the query is aligned to the 0th token in the key.
# Assuming the attention score matrix is two dimensional, ``attn_score[0][0]`` is the attention score
# between the 0th token in the query and the 0th token in the key.
# For lower right bias, the sequence of q is aligned so that the last token in q is aligned to the last token in k
# (for example, ``attn_score[-1][-1]`` is always unmasked, since the last token in q is at the same position
# as the last token in k), even if the sequence lengths of q and k are different.
# These objects are intended to be used with SDPA
out_upper_left = F.scaled_dot_product_attention(query, key, value, upper_left_bias)
out_lower_right = F.scaled_dot_product_attention(query, key, value, lower_right_bias)
out_is_causal = F.scaled_dot_product_attention(query, key, value, is_causal=True)
assert torch.allclose(out_upper_left, out_is_causal)
assert not torch.allclose(out_upper_left, out_lower_right)
# These attention biases should also be compatible with torch.compile
compiled_sdpa = torch.compile(F.scaled_dot_product_attention, fullgraph=True)
out_upper_left = compiled_sdpa(query, key, value, upper_left_bias)
<class 'torch.nn.attention.bias.CausalBias'>
<class 'torch.nn.attention.bias.CausalBias'>
tensor([[ True, False, False, False, False, False, False, False, False, False],
[ True, True, False, False, False, False, False, False, False, False]])
tensor([[ True, True, True, True, True, True, True, True, True, False],
[ True, True, True, True, True, True, True, True, True, True]])
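The two printed masks can be reproduced with plain boolean tensors, which makes the alignment rule explicit: the upper-left variant keeps the main diagonal, while the lower-right variant shifts the diagonal by the difference between the key and query lengths. This is a sketch built with `torch.tril` for illustration, not the implementation used inside `torch.nn.attention.bias`.

```python
import torch
import torch.nn.functional as F

seq_len_q, seq_len_kv, head_dim = 2, 10, 8
ones = torch.ones(seq_len_q, seq_len_kv, dtype=torch.bool)

# Upper left: query i may attend to keys 0..i.
upper_left_mask = torch.tril(ones, diagonal=0)
# Lower right: the last query is aligned with the last key,
# so query i may attend to keys 0..i + (seq_len_kv - seq_len_q).
lower_right_mask = torch.tril(ones, diagonal=seq_len_kv - seq_len_q)
print(upper_left_mask.int())
print(lower_right_mask.int())

torch.manual_seed(0)
query = torch.randn(1, 1, seq_len_q, head_dim)
key = torch.randn(1, 1, seq_len_kv, head_dim)
value = torch.randn(1, 1, seq_len_kv, head_dim)

out_upper_left = F.scaled_dot_product_attention(query, key, value, attn_mask=upper_left_mask)
out_is_causal = F.scaled_dot_product_attention(query, key, value, is_causal=True)
# ``is_causal=True`` should match the upper-left alignment.
print(torch.allclose(out_upper_left, out_is_causal, atol=1e-5))
```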
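Conclusion
In this tutorial, we have demonstrated the basic usage of torch.nn.functional.scaled_dot_product_attention. We have shown how the sdpa_kernel context manager can be used to assert that a certain implementation is used on GPU. We also built a simple CausalSelfAttention module that works with NestedTensor and is torch compilable. In the process we have shown how to use the profiling tools to explore the performance characteristics of a user-defined module.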
Total running time of the script: (0 minutes 7.387 seconds)