實驗性運算元¶
注意力運算元¶
-
std::tuple<at::Tensor, at::Tensor, at::Tensor> gqa_attn_splitk(const at::Tensor &XQ, const at::Tensor &cache_K, const at::Tensor &cache_V, const at::Tensor &seq_positions, const double qk_scale, const int64_t num_split_ks, const int64_t kv_cache_quant_num_groups, const bool use_tensor_cores, const int64_t cache_logical_dtype_int)¶
解碼分組查詢注意力 Split-K,支援 BF16/INT4 KV。
此為解碼分組查詢注意力 (GQA) 的 CUDA 實現,支援 BF16 和 INT4 KV 快取以及 BF16 輸入查詢。目前僅支援最大上下文長度 16384、固定頭維度 128 以及單個 KV 快取頭。支援任意數量的查詢頭。
- 引數:
XQ – 輸入查詢;形狀 = (B, 1, H_Q, D),其中 B = 批大小,H_Q = 查詢頭數量,D = 頭維度(固定為 128)
cache_K – K 快取;形狀 = (B, MAX_T, H_KV, D),其中 MAX_T = 最大上下文長度(固定為 16384),H_KV = KV 快取頭數量(固定為 1)
cache_V – V 快取;形狀 = (B, MAX_T, H_KV, D)
seq_positions – 序列位置(包含每個 token 的實際長度);形狀 = (B)
qk_scale – 在 QK^T 後應用的縮放因子
num_split_ks – Split K 的數量(控制上下文長度維度 (MAX_T) 的並行度)
kv_cache_quant_num_groups – 每個 KV token 進行組量化(INT4 和 FP8)的分組數量(每組使用相同的縮放因子和偏置進行量化)。目前 FP8 僅支援單個分組。
use_tensor_cores – 是否使用 Tensor Core wmma 指令進行快速實現
cache_logical_dtype_int – 指定 kv_cache 的量化資料型別:{BF16:0, FP8:1, INT4:2}
- 返回值:
包含合併後的 split-K 輸出、未合併的 split-K 輸出以及 split-K 元資料(包含最大 QK^T 和 softmax(QK^T) 的頭部和)的元組