torch.nn.attention.flex_attention¶
- torch.nn.attention.flex_attention.flex_attention(query, key, value, score_mod=None, block_mask=None, scale=None, enable_gqa=False, return_lse=False, kernel_options=None)[源][源]¶
此函数实现了带任意注意力分数修改函数的缩放点积注意力。
此函数计算查询、键和值张量之间的缩放点积注意力,并应用用户定义的注意力分数修改函数。注意力分数修改函数将在计算查询和键张量之间的注意力分数后应用。注意力分数计算方法如下:
score_mod
函数应具有以下签名:def score_mod( score: Tensor, batch: Tensor, head: Tensor, q_idx: Tensor, k_idx: Tensor ) -> Tensor:
- 其中
score
: 一个标量张量,表示注意力分数,其数据类型和设备应与查询、键和值张量相同。batch
,head
,q_idx
,k_idx
: 标量张量,分别指示批量索引、查询头索引、查询索引和键/值索引。它们应具有torch.int
数据类型,并位于与 score 张量相同的设备上。
- 参数
query (Tensor) – 查询张量;形状为 。
key (Tensor) – 键张量;形状为 。
value (Tensor) – 值张量;形状为 。
score_mod (Optional[Callable]) – 用于修改注意力分数的函数。默认不应用 score_mod。
block_mask (Optional[BlockMask]) – BlockMask 对象,控制注意力的块稀疏模式。
scale (Optional[float]) – 在 softmax 前应用的缩放因子。如果为 None,默认值为
enable_gqa (bool) – 如果设置为 True,启用分组查询注意力 (Grouped Query Attention, GQA),并将键/值头广播到查询头。
return_lse (bool) – 是否返回注意力分数的 logsumexp。默认为 False。
kernel_options (Optional[Dict[str, Any]]) – 传递给 Triton kernels 的选项。
- 返回
注意力输出;形状为 。
- 返回类型
output (Tensor)
- 形状图例
警告
torch.nn.attention.flex_attention
是 PyTorch 中的一个原型特性。请期待未来版本中更稳定的实现。了解有关特性分类的更多信息:https://pytorch.ac.cn/blog/pytorch-feature-classification-changes/#prototype
BlockMask 工具函数¶
- torch.nn.attention.flex_attention.create_block_mask(mask_mod, B, H, Q_LEN, KV_LEN, device='cuda', BLOCK_SIZE=128, _compile=False)[源][源]¶
此函数根据 mask_mod 函数创建块掩码元组。
- 参数
mask_mod (Callable) – mask_mod 函数。这是一个可调用对象,定义了注意力机制的掩码模式。它接受四个参数:b(批量大小)、h(头数)、q_idx(查询索引)和 kv_idx(键/值索引)。它应返回一个布尔张量,指示允许哪些注意力连接 (True) 或屏蔽哪些注意力连接 (False)。
B (int) – 批量大小。
H (int) – 查询头数。
Q_LEN (int) – 查询的序列长度。
KV_LEN (int) – 键/值的序列长度。
device (str) – 用于创建掩码的设备。
BLOCK_SIZE (int or tuple[int, int]) – 块掩码的块大小。如果提供单个整数,则同时用于查询和键/值。
- 返回
一个包含块掩码信息的 BlockMask 对象。
- 返回类型
- 使用示例
def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx block_mask = create_block_mask(causal_mask, 1, 1, 8192, 8192, device="cuda") query = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16) key = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16) value = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16) output = flex_attention(query, key, value, block_mask=block_mask)
- torch.nn.attention.flex_attention.create_mask(mod_fn, B, H, Q_LEN, KV_LEN, device='cuda')[源][源]¶
此函数根据 mod_fn 函数创建掩码张量。
- torch.nn.attention.flex_attention.create_nested_block_mask(mask_mod, B, H, q_nt, kv_nt=None, BLOCK_SIZE=128, _compile=False)[源][源]¶
此函数根据 mask_mod 函数创建与嵌套张量兼容的块掩码元组。返回的 BlockMask 将位于输入嵌套张量指定的设备上。
- 参数
mask_mod (Callable) – mask_mod 函数。这是一个可调用对象,定义了注意力机制的掩码模式。它接受四个参数:b(批量大小)、h(头数)、q_idx(查询索引)和 kv_idx(键/值索引)。它应返回一个布尔张量,指示允许哪些注意力连接 (True) 或屏蔽哪些注意力连接 (False)。
B (int) – 批量大小。
H (int) – 查询头数。
q_nt (torch.Tensor) – 定义查询序列长度结构的锯齿状布局嵌套张量 (Jagged layout nested tensor, NJT)。块掩码将被构建为在来自 NJT 的、长度为
sum(S)
的“堆叠序列”(stacked sequence) 上操作。kv_nt (torch.Tensor) – 定义键/值序列长度结构的锯齿状布局嵌套张量 (NJT),支持交叉注意力。块掩码将被构建为在来自 NJT 的、长度为
sum(S)
的“堆叠序列”上操作。如果此参数为 None,则使用q_nt
来定义键/值的结构。默认值:NoneBLOCK_SIZE (int or tuple[int, int]) – 块掩码的块大小。如果提供单个整数,则同时用于查询和键/值。
- 返回
一个包含块掩码信息的 BlockMask 对象。
- 返回类型
- 使用示例
# shape (B, num_heads, seq_len*, D) where seq_len* varies across the batch query = torch.nested.nested_tensor(..., layout=torch.jagged) key = torch.nested.nested_tensor(..., layout=torch.jagged) value = torch.nested.nested_tensor(..., layout=torch.jagged) def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx block_mask = create_nested_block_mask(causal_mask, 1, 1, query, _compile=True) output = flex_attention(query, key, value, block_mask=block_mask)
# shape (B, num_heads, seq_len*, D) where seq_len* varies across the batch query = torch.nested.nested_tensor(..., layout=torch.jagged) key = torch.nested.nested_tensor(..., layout=torch.jagged) value = torch.nested.nested_tensor(..., layout=torch.jagged) def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx # cross attention case: pass both query and key/value NJTs block_mask = create_nested_block_mask(causal_mask, 1, 1, query, key, _compile=True) output = flex_attention(query, key, value, block_mask=block_mask)
BlockMask¶
- class torch.nn.attention.flex_attention.BlockMask(seq_lengths, kv_num_blocks, kv_indices, full_kv_num_blocks, full_kv_indices, q_num_blocks, q_indices, full_q_num_blocks, full_q_indices, BLOCK_SIZE, mask_mod)[源][源]¶
BlockMask 是我们表示块稀疏注意力掩码的格式。它介于 BCSR 和非稀疏格式之间。
基础知识¶
块稀疏掩码意味着,块大小为 KV_BLOCK_SIZE x Q_BLOCK_SIZE 的块只有在该块内的每个元素都稀疏时才被视为稀疏,而不是表示掩码中单个元素的稀疏性。这与硬件很好地契合,硬件通常期望执行连续加载和计算。
这种格式主要针对 1. 简洁性和 2. 内核效率进行优化。值得注意的是,它*并非*针对大小进行优化,因为此掩码总是按 KV_BLOCK_SIZE * Q_BLOCK_SIZE 的因子进行缩减。如果大小是个问题,可以通过增加块大小来减小张量的大小。
我们格式的基本要素包括:
num_blocks_in_row: Tensor[ROWS]: 描述每行中存在的块数量。
col_indices: Tensor[ROWS, MAX_BLOCKS_IN_COL]: col_indices[i] 是行 i 的块位置序列。在该行中 col_indices[i][num_blocks_in_row[i]] 之后的值是未定义的。
例如,要从此格式重构原始张量
dense_mask = torch.zeros(ROWS, COLS) for row in range(ROWS): for block_idx in range(num_blocks_in_row[row]): dense_mask[row, col_indices[row, block_idx]] = 1
值得注意的是,此格式使得沿着掩码的行进行归约(reduction)更容易实现。
详情¶
我们格式的基础只需要 kv_num_blocks 和 kv_indices。但是,此对象上有多达 8 个张量。这代表 4 对:
1. (kv_num_blocks, kv_indices): 用于 attention 的前向传播,因为我们沿着 KV 维度进行归约。
2. [OPTIONAL] (full_kv_num_blocks, full_kv_indices): 这是可选的,纯粹是为了优化。事实证明,对每个块应用掩码非常耗时!如果我们明确知道哪些块是“完整的”并且完全不需要掩码,那么我们可以跳过对这些块应用 mask_mod。这要求用户将单独的 mask_mod 从 score_mod 中剥离出来。对于因果掩码,这大约能带来 15% 的加速。
3. [GENERATED] (q_num_blocks, q_indices): 后向传播需要,因为计算 dKV 需要沿着 Q 维度迭代掩码。这些是从 1 自动生成的。
4. [GENERATED] (full_q_num_blocks, full_q_indices): 同上,但用于后向传播。这些是从 2 自动生成的。
- as_tuple(flatten=True)[source][source]¶
返回 BlockMask 属性的元组。
- 参数
flatten (bool) – 如果为 True,它将展平 (KV_BLOCK_SIZE, Q_BLOCK_SIZE) 元组
- classmethod from_kv_blocks(kv_num_blocks, kv_indices, full_kv_num_blocks=None, full_kv_indices=None, BLOCK_SIZE=128, mask_mod=None, seq_lengths=None)[source][source]¶
从键值块信息创建 BlockMask 实例。
- 参数
kv_num_blocks (Tensor) – 每个 Q_BLOCK_SIZE 行瓦片中的 kv_块数量。
kv_indices (Tensor) – 每个 Q_BLOCK_SIZE 行瓦片中的键值块索引。
full_kv_num_blocks (Optional[Tensor]) – 每个 Q_BLOCK_SIZE 行瓦片中的完整 kv_块数量。
full_kv_indices (Optional[Tensor]) – 每个 Q_BLOCK_SIZE 行瓦片中的完整键值块索引。
BLOCK_SIZE (Union[int, tuple[int, int]]) – KV_BLOCK_SIZE x Q_BLOCK_SIZE 瓦片的大小。
mask_mod (Optional[Callable]) – 用于修改掩码的函数。
- 返回
通过 _transposed_ordered 生成具有完整 Q 信息的实例
- 返回类型
- 引发
RuntimeError – 如果 kv_indices 的维度小于 2。
AssertionError – 如果仅提供了 full_kv_* 参数中的一个。
- property shape¶
- to(device)[source][source]¶
将 BlockMask 移动到指定设备。
- 参数
device (torch.device or str) – 目标设备,BlockMask 将被移到该设备。可以是 torch.device 对象或字符串(例如,'cpu','cuda:0')。
- 返回
一个新的 BlockMask 实例,其所有张量组件已移动到指定设备。
- 返回类型
注意
此方法不会原地修改原始 BlockMask。相反,它返回一个新的 BlockMask 实例,其中各个张量属性可能会或可能不会被移动到指定设备,具体取决于它们当前的设备位置。