快捷方式

实验运算符

注意力运算符

std::tuple<at::Tensor, at::Tensor, at::Tensor> gqa_attn_splitk(const at::Tensor &XQ, const at::Tensor &cache_K, const at::Tensor &cache_V, const at::Tensor &seq_positions, const double qk_scale, const int64_t num_split_ks, const int64_t kv_cache_quant_num_groups, const bool use_tensor_cores, const int64_t cache_logical_dtype_int)

解码分组查询注意力 Split-K 带 BF16/INT4 KV。

解码分组查询注意力 (GQA) 的 CUDA 实现,支持 BF16 和 INT4 KV 缓存以及 BF16 输入查询。它目前仅支持 16384 的最大上下文长度、128 的固定头部维度,以及一个 KV 缓存头部。它支持任意数量的查询头部。

参数:
  • XQ – 输入查询;形状 = (B, 1, H_Q, D),其中 B = 批次大小,H_Q = 查询头部数,D = 头部维度(固定为 128)

  • cache_K – K 缓存;形状 = (B, MAX_T, H_KV, D),其中 MAX_T = 最大上下文长度(固定为 16384),H_KV = KV 缓存头部数(固定为 1)

  • cache_V – V 缓存;形状 = (B, MAX_T, H_KV, D)

  • seq_positions – 序列位置(包含每个标记的实际长度);形状 = (B)

  • qk_scale – 在 QK^T 后应用的比例

  • num_split_ks – Split K 的数量(控制上下文长度维度 (MAX_T) 中的并行度)

  • kv_cache_quant_num_groups – 每个 KV 标记的组式 INT4 和 FP8 量化组数(每组对量化使用相同的比例和偏差)。FP8 目前仅支持单组。

  • use_tensor_cores – 是否使用张量核心 wmma 指令来快速实现

  • cache_logical_dtype_int – 指定 kv_cache 的量化数据类型:{BF16:0 , FP8:1, INT4:2}

返回:

组合的 Split-K 输出、未组合的 Split-K 输出和 Split-K 元数据(包含最大 QK^T 和 softmax(QK^T) 头部总和)的元组

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源