实验性算子¶
注意力算子¶
-
std::tuple<at::Tensor, at::Tensor, at::Tensor> gqa_attn_splitk(const at::Tensor &XQ, const at::Tensor &cache_K, const at::Tensor &cache_V, const at::Tensor &seq_positions, const double qk_scale, const int64_t num_split_ks, const int64_t kv_cache_quant_num_groups, const bool use_tensor_cores, const int64_t cache_logical_dtype_int)¶
解码分组查询注意力 Split-K,支持 BF16/INT4 KV。
此为解码分组查询注意力 (GQA) 的 CUDA 实现,支持 BF16 和 INT4 KV 缓存以及 BF16 输入查询。目前仅支持最大上下文长度 16384、固定头维度 128 以及单个 KV 缓存头。支持任意数量的查询头。
- 参数:
XQ – 输入查询;形状 = (B, 1, H_Q, D),其中 B = 批大小,H_Q = 查询头数量,D = 头维度(固定为 128)
cache_K – K 缓存;形状 = (B, MAX_T, H_KV, D),其中 MAX_T = 最大上下文长度(固定为 16384),H_KV = KV 缓存头数量(固定为 1)
cache_V – V 缓存;形状 = (B, MAX_T, H_KV, D)
seq_positions – 序列位置(包含每个 token 的实际长度);形状 = (B)
qk_scale – 在 QK^T 后应用的缩放因子
num_split_ks – Split K 的数量(控制上下文长度维度 (MAX_T) 的并行度)
kv_cache_quant_num_groups – 每个 KV token 进行组量化(INT4 和 FP8)的分组数量(每组使用相同的缩放因子和偏置进行量化)。目前 FP8 仅支持单个分组。
use_tensor_cores – 是否使用 Tensor Core wmma 指令进行快速实现
cache_logical_dtype_int – 指定 kv_cache 的量化数据类型:{BF16:0, FP8:1, INT4:2}
- 返回值:
包含合并后的 split-K 输出、未合并的 split-K 输出以及 split-K 元数据(包含最大 QK^T 和 softmax(QK^T) 的头部和)的元组