快捷方式

torch.nn.attention.flex_attention

torch.nn.attention.flex_attention.flex_attention(query, key, value, score_mod=None, block_mask=None, scale=None, enable_gqa=False, return_lse=False, kernel_options=None)[source][source]

此函数使用任意注意力评分修改函数实现缩放点积注意力。

此函数计算查询、键和值张量之间,以及用户定义的注意力评分修改函数的缩放点积注意力。注意力评分修改函数将在计算查询和键张量之间的注意力评分后应用。注意力评分的计算方式如下

score_mod 函数应具有以下签名

def score_mod(
    score: Tensor,
    batch: Tensor,
    head: Tensor,
    q_idx: Tensor,
    k_idx: Tensor
) -> Tensor:
其中
  • score:标量张量,表示注意力评分,其数据类型和设备与查询、键和值张量相同。

  • batch, head, q_idx, k_idx:标量张量,分别表示批次索引、查询头索引、查询索引和键/值索引。 这些应具有 torch.int 数据类型,并位于与评分张量相同的设备上。

参数
  • query (Tensor) – 查询张量;形状 (B,Hq,L,E)(B, Hq, L, E)

  • key (Tensor) – 键张量;形状 (B,Hkv,S,E)(B, Hkv, S, E)

  • value (Tensor) – 值张量;形状 (B,Hkv,S,Ev)(B, Hkv, S, Ev)

  • score_mod (Optional[Callable]) – 修改注意力评分的函数。默认情况下,不应用 score_mod。

  • block_mask (Optional[BlockMask]) – BlockMask 对象,用于控制注意力的块稀疏模式。

  • scale (Optional[float]) – 应用于 softmax 之前的缩放因子。如果为 None,则默认值设置为 1E\frac{1}{\sqrt{E}}

  • enable_gqa (bool) – 如果设置为 True,则启用分组查询注意力 (GQA),并将键/值头广播到查询头。

  • return_lse (bool) – 是否返回注意力评分的 logsumexp。 默认为 False。

  • kernel_options (Optional[Dict[str, Any]]) – 传递到 Triton 内核的选项。

返回

注意力输出;形状 (B,Hq,L,Ev)(B, Hq, L, Ev)

返回类型

output (Tensor)

形状图例
  • N:Batch size...:Any number of other batch dimensions (optional)N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}

  • S:Source sequence lengthS: \text{Source sequence length}

  • L:Target sequence lengthL: \text{Target sequence length}

  • E:Embedding dimension of the query and keyE: \text{Embedding dimension of the query and key}

  • Ev:Embedding dimension of the valueEv: \text{Embedding dimension of the value}

警告

torch.nn.attention.flex_attention 是 PyTorch 中的原型功能。请期待 PyTorch 未来版本中更稳定的实现。阅读有关功能分类的更多信息: https://pytorch.ac.cn/blog/pytorch-feature-classification-changes/#prototype

BlockMask 实用工具

torch.nn.attention.flex_attention.create_block_mask(mask_mod, B, H, Q_LEN, KV_LEN, device='cuda', BLOCK_SIZE=128, _compile=False)[source][source]

此函数从 mask_mod 函数创建块掩码元组。

参数
  • mask_mod (Callable) – mask_mod 函数。这是一个可调用对象,用于定义注意力机制的掩码模式。它接受四个参数:b(批次大小)、h(头数)、q_idx(查询索引)和 kv_idx(键/值索引)。 它应返回一个布尔张量,指示哪些注意力连接是允许的 (True) 或掩码的 (False)。

  • B (int) – 批次大小。

  • H (int) – 查询头数。

  • Q_LEN (int) – 查询的序列长度。

  • KV_LEN (int) – 键/值的序列长度。

  • device (str) – 运行掩码创建的设备。

  • BLOCK_SIZE (intTuple[int, int]) – 块掩码的块大小。如果提供单个 int,则用于查询和键/值。

返回

包含块掩码信息的 BlockMask 对象。

返回类型

BlockMask

使用示例
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_block_mask(causal_mask, 1, 1, 8192, 8192, device="cuda")
query = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16)
key = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16)
value = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16)
output = flex_attention(query, key, value, block_mask=block_mask)
torch.nn.attention.flex_attention.create_mask(mod_fn, B, H, Q_LEN, KV_LEN, device='cuda')[source][source]

此函数从 mod_fn 函数创建掩码张量。

参数
  • mod_fn (Union[_score_mod_signature, _mask_mod_signature]) – 修改注意力评分的函数。

  • B (int) – 批次大小。

  • H (int) – 查询头数。

  • Q_LEN (int) – 查询的序列长度。

  • KV_LEN (int) – 键/值的序列长度。

  • device (str) – 运行掩码创建的设备。

返回

形状为 (B, H, M, N) 的掩码张量。

返回类型

mask (Tensor)

torch.nn.attention.flex_attention.create_nested_block_mask(mask_mod, B, H, q_nt, kv_nt=None, BLOCK_SIZE=128, _compile=False)[source][source]

此函数从 mask_mod 函数创建嵌套张量兼容的块掩码元组。返回的 BlockMask 将位于输入嵌套张量指定的设备上。

参数
  • mask_mod (Callable) – mask_mod 函数。这是一个可调用对象,用于定义注意力机制的掩码模式。它接受四个参数:b(批次大小)、h(头数)、q_idx(查询索引)和 kv_idx(键/值索引)。 它应返回一个布尔张量,指示哪些注意力连接是允许的 (True) 或掩码的 (False)。

  • B (int) – 批次大小。

  • H (int) – 查询头数。

  • q_nt (torch.Tensor) – 锯齿布局嵌套张量 (NJT),用于定义查询的序列长度结构。块掩码将构建为对“堆叠序列”进行操作,长度为 sum(S),其中序列长度 S 来自 NJT。

  • kv_nt (torch.Tensor) – 锯齿布局嵌套张量 (NJT),用于定义键/值的序列长度结构,允许交叉注意力。块掩码将构建为对“堆叠序列”进行操作,长度为 sum(S),其中序列长度 S 来自 NJT。如果为 None,则 q_nt 也用于定义键/值的结构。默认值:None

  • BLOCK_SIZE (intTuple[int, int]) – 块掩码的块大小。如果提供单个 int,则用于查询和键/值。

返回

包含块掩码信息的 BlockMask 对象。

返回类型

BlockMask

使用示例
# shape (B, num_heads, seq_len*, D) where seq_len* varies across the batch
query = torch.nested.nested_tensor(..., layout=torch.jagged)
key = torch.nested.nested_tensor(..., layout=torch.jagged)
value = torch.nested.nested_tensor(..., layout=torch.jagged)

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_nested_block_mask(causal_mask, 1, 1, query, _compile=True)
output = flex_attention(query, key, value, block_mask=block_mask)
# shape (B, num_heads, seq_len*, D) where seq_len* varies across the batch
query = torch.nested.nested_tensor(..., layout=torch.jagged)
key = torch.nested.nested_tensor(..., layout=torch.jagged)
value = torch.nested.nested_tensor(..., layout=torch.jagged)

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# cross attention case: pass both query and key/value NJTs
block_mask = create_nested_block_mask(causal_mask, 1, 1, query, key, _compile=True)
output = flex_attention(query, key, value, block_mask=block_mask)
torch.nn.attention.flex_attention.and_masks(*mask_mods)[source][source]

返回 mask_mod,它是提供的 mask_mods 的交集

返回类型

Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]

torch.nn.attention.flex_attention.or_masks(*mask_mods)[source][source]

返回 mask_mod,它是提供的 mask_mods 的并集

返回类型

Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]

torch.nn.attention.flex_attention.noop_mask(batch, head, token_q, token_kv)[source][source]

返回 noop mask_mod

返回类型

Tensor

BlockMask

class torch.nn.attention.flex_attention.BlockMask(seq_lengths, kv_num_blocks, kv_indices, full_kv_num_blocks, full_kv_indices, q_num_blocks, q_indices, full_q_num_blocks, full_q_indices, BLOCK_SIZE, mask_mod)[source][source]

BlockMask 是我们用于表示块稀疏注意力掩码的格式。它有点像是 BCSR 和非稀疏格式之间的交叉。

基础知识

块稀疏掩码意味着,不是表示掩码中各个元素的稀疏性,而是仅当 KV_BLOCK_SIZE x Q_BLOCK_SIZE 块内的每个元素都是稀疏的时,才认为该块是稀疏的。这与硬件非常吻合,硬件通常希望执行连续加载和计算。

此格式主要针对 1. 简易性和 2. 内核效率进行了优化。值得注意的是,它针对大小进行优化,因为此掩码始终会减少 KV_BLOCK_SIZE * Q_BLOCK_SIZE 倍。如果大小是一个问题,则可以通过增加块大小来减小张量的大小。

我们格式的要点是

num_blocks_in_row: Tensor[ROWS]:描述每行中存在的块数。

col_indices: Tensor[ROWS, MAX_BLOCKS_IN_COL]:col_indices[i] 是第 i 行的块位置序列。 col_indices[i][num_blocks_in_row[i]] 之后此行的值未定义。

例如,要从此格式重建原始张量

dense_mask = torch.zeros(ROWS, COLS)
for row in range(ROWS):
    for block_idx in range(num_blocks_in_row[row]):
        dense_mask[row, col_indices[row, block_idx]] = 1

值得注意的是,这种格式使得沿着 mask 的进行归约更加容易。

详情

我们格式的基础仅需要 kv_num_blocks 和 kv_indices。但是,在此对象上我们最多有 8 个张量。这代表 4 对:

1. (kv_num_blocks, kv_indices):用于 attention 的前向传播,因为我们沿着 KV 维度进行归约。

2. [可选] (full_kv_num_blocks, full_kv_indices):这是可选的,纯粹是为了优化。事实证明,对每个块应用 masking 非常昂贵!如果我们明确知道哪些块是“完整的”并且根本不需要 masking,那么我们可以跳过对这些块应用 mask_mod。这需要用户从 score_mod 中分离出一个单独的 mask_mod。对于因果 mask,这大约可以加速 15%。

3. [已生成] (q_num_blocks, q_indices):反向传播需要此项,因为计算 dKV 需要沿着 Q 维度迭代 mask。这些是从 1 自动生成的。

4. [已生成] (full_q_num_blocks, full_q_indices):与上面相同,但用于反向传播。这些是从 2 自动生成的。

BLOCK_SIZE: Tuple[int, int]
as_tuple(flatten=True)[源代码][源代码]

返回 BlockMask 属性的元组。

参数

flatten (bool) – 如果为 True,它将展平 (KV_BLOCK_SIZE, Q_BLOCK_SIZE) 的元组

classmethod from_kv_blocks(kv_num_blocks, kv_indices, full_kv_num_blocks=None, full_kv_indices=None, BLOCK_SIZE=128, mask_mod=None, seq_lengths=None)[源代码][源代码]

从键值块信息创建 BlockMask 实例。

参数
  • kv_num_blocks (Tensor) – 每个 Q_BLOCK_SIZE 行平铺中的 kv_blocks 的数量。

  • kv_indices (Tensor) – 每个 Q_BLOCK_SIZE 行平铺中键值块的索引。

  • full_kv_num_blocks (Optional[Tensor]) – 每个 Q_BLOCK_SIZE 行平铺中完整 kv_blocks 的数量。

  • full_kv_indices (Optional[Tensor]) – 每个 Q_BLOCK_SIZE 行平铺中完整键值块的索引。

  • BLOCK_SIZE (Union[int, Tuple[int, int]]) – KV_BLOCK_SIZE x Q_BLOCK_SIZE 平铺的大小。

  • mask_mod (Optional[Callable]) – 用于修改 mask 的函数。

返回

通过 _transposed_ordered 生成的具有完整 Q 信息的实例

返回类型

BlockMask

引发
full_kv_indices: Optional[Tensor]
full_kv_num_blocks: Optional[Tensor]
full_q_indices: Optional[Tensor]
full_q_num_blocks: Optional[Tensor]
kv_indices: Tensor
kv_num_blocks: Tensor
mask_mod: Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]
numel()[源代码][源代码]

返回 mask 中元素的数量(不考虑稀疏性)。

q_indices: Optional[Tensor]
q_num_blocks: Optional[Tensor]
seq_lengths: Tuple[int, int]
property shape
sparsity()[源代码][源代码]

计算稀疏块(即未计算的块)的百分比

返回类型

float

to(device)[源代码][源代码]

将 BlockMask 移动到指定的设备。

参数

device (torch.devicestr) – 要将 BlockMask 移动到的目标设备。可以是 torch.device 对象或字符串(例如,“cpu”,“cuda:0”)。

返回

一个新的 BlockMask 实例,其中所有张量组件都已移动到指定的设备。

返回类型

BlockMask

注意

此方法不会就地修改原始 BlockMask。相反,它返回一个新的 BlockMask 实例,其中各个张量属性可能已移动到指定的设备,也可能未移动,具体取决于它们当前的设备放置位置。

to_dense()[源代码][源代码]

返回等效于块 mask 的密集块。

返回类型

Tensor

to_string(grid_size=(20, 20), limit=4)[源代码][源代码]

返回块 mask 的字符串表示形式。非常巧妙。

如果 grid_size 为 None,则打印出未压缩的版本。警告,它可能非常大!

文档

访问 PyTorch 的综合开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的问题解答

查看资源