torch.nn.attention.flex_attention¶
- torch.nn.attention.flex_attention.flex_attention(query, key, value, score_mod=None, block_mask=None, scale=None, enable_gqa=False, return_lse=False, kernel_options=None)[源][源]¶
此函式實現了帶任意注意力分數修改函式的縮放點積注意力。
此函式計算查詢、鍵和值張量之間的縮放點積注意力,並應用使用者定義的注意力分數修改函式。注意力分數修改函式將在計算查詢和鍵張量之間的注意力分數後應用。注意力分數計算方法如下:
score_mod函式應具有以下簽名:def score_mod( score: Tensor, batch: Tensor, head: Tensor, q_idx: Tensor, k_idx: Tensor ) -> Tensor:
- 其中
score: 一個標量張量,表示注意力分數,其資料型別和裝置應與查詢、鍵和值張量相同。batch,head,q_idx,k_idx: 標量張量,分別指示批次索引、查詢頭索引、查詢索引和鍵/值索引。它們應具有torch.int資料型別,並位於與 score 張量相同的裝置上。
- 引數
query (Tensor) – 查詢張量;形狀為 。
key (Tensor) – 鍵張量;形狀為 。
value (Tensor) – 值張量;形狀為 。
score_mod (Optional[Callable]) – 用於修改注意力分數的函式。預設不應用 score_mod。
block_mask (Optional[BlockMask]) – BlockMask 物件,控制注意力的塊稀疏模式。
scale (Optional[float]) – 在 softmax 前應用的縮放因子。如果為 None,預設值為
enable_gqa (bool) – 如果設定為 True,啟用分組查詢注意力 (Grouped Query Attention, GQA),並將鍵/值頭廣播到查詢頭。
return_lse (bool) – 是否返回注意力分數的 logsumexp。預設為 False。
kernel_options (Optional[Dict[str, Any]]) – 傳遞給 Triton kernels 的選項。
- 返回
注意力輸出;形狀為 。
- 返回型別
output (Tensor)
- 形狀圖例
警告
torch.nn.attention.flex_attention是 PyTorch 中的一個原型特性。請期待未來版本中更穩定的實現。瞭解有關特性分類的更多資訊:https://pytorch.com.tw/blog/pytorch-feature-classification-changes/#prototype
BlockMask 工具函式¶
- torch.nn.attention.flex_attention.create_block_mask(mask_mod, B, H, Q_LEN, KV_LEN, device='cuda', BLOCK_SIZE=128, _compile=False)[源][源]¶
此函式根據 mask_mod 函式建立塊掩碼元組。
- 引數
mask_mod (Callable) – mask_mod 函式。這是一個可呼叫物件,定義了注意力機制的掩碼模式。它接受四個引數:b(批次大小)、h(頭數)、q_idx(查詢索引)和 kv_idx(鍵/值索引)。它應返回一個布林張量,指示允許哪些注意力連線 (True) 或遮蔽哪些注意力連線 (False)。
B (int) – 批次大小。
H (int) – 查詢頭數。
Q_LEN (int) – 查詢的序列長度。
KV_LEN (int) – 鍵/值的序列長度。
device (str) – 用於建立掩碼的裝置。
BLOCK_SIZE (int or tuple[int, int]) – 塊掩碼的塊大小。如果提供單個整數,則同時用於查詢和鍵/值。
- 返回
一個包含塊掩碼資訊的 BlockMask 物件。
- 返回型別
- 使用示例
def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx block_mask = create_block_mask(causal_mask, 1, 1, 8192, 8192, device="cuda") query = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16) key = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16) value = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16) output = flex_attention(query, key, value, block_mask=block_mask)
- torch.nn.attention.flex_attention.create_mask(mod_fn, B, H, Q_LEN, KV_LEN, device='cuda')[源][源]¶
此函式根據 mod_fn 函式建立掩碼張量。
- torch.nn.attention.flex_attention.create_nested_block_mask(mask_mod, B, H, q_nt, kv_nt=None, BLOCK_SIZE=128, _compile=False)[源][源]¶
此函式根據 mask_mod 函式建立與巢狀張量相容的塊掩碼元組。返回的 BlockMask 將位於輸入巢狀張量指定的裝置上。
- 引數
mask_mod (Callable) – mask_mod 函式。這是一個可呼叫物件,定義了注意力機制的掩碼模式。它接受四個引數:b(批次大小)、h(頭數)、q_idx(查詢索引)和 kv_idx(鍵/值索引)。它應返回一個布林張量,指示允許哪些注意力連線 (True) 或遮蔽哪些注意力連線 (False)。
B (int) – 批次大小。
H (int) – 查詢頭數。
q_nt (torch.Tensor) – 定義查詢序列長度結構的鋸齒狀佈局巢狀張量 (Jagged layout nested tensor, NJT)。塊掩碼將被構建為在來自 NJT 的、長度為
sum(S)的“堆疊序列”(stacked sequence) 上操作。kv_nt (torch.Tensor) – 定義鍵/值序列長度結構的鋸齒狀佈局巢狀張量 (NJT),支援交叉注意力。塊掩碼將被構建為在來自 NJT 的、長度為
sum(S)的“堆疊序列”上操作。如果此引數為 None,則使用q_nt來定義鍵/值的結構。預設值:NoneBLOCK_SIZE (int or tuple[int, int]) – 塊掩碼的塊大小。如果提供單個整數,則同時用於查詢和鍵/值。
- 返回
一個包含塊掩碼資訊的 BlockMask 物件。
- 返回型別
- 使用示例
# shape (B, num_heads, seq_len*, D) where seq_len* varies across the batch query = torch.nested.nested_tensor(..., layout=torch.jagged) key = torch.nested.nested_tensor(..., layout=torch.jagged) value = torch.nested.nested_tensor(..., layout=torch.jagged) def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx block_mask = create_nested_block_mask(causal_mask, 1, 1, query, _compile=True) output = flex_attention(query, key, value, block_mask=block_mask)
# shape (B, num_heads, seq_len*, D) where seq_len* varies across the batch query = torch.nested.nested_tensor(..., layout=torch.jagged) key = torch.nested.nested_tensor(..., layout=torch.jagged) value = torch.nested.nested_tensor(..., layout=torch.jagged) def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx # cross attention case: pass both query and key/value NJTs block_mask = create_nested_block_mask(causal_mask, 1, 1, query, key, _compile=True) output = flex_attention(query, key, value, block_mask=block_mask)
BlockMask¶
- class torch.nn.attention.flex_attention.BlockMask(seq_lengths, kv_num_blocks, kv_indices, full_kv_num_blocks, full_kv_indices, q_num_blocks, q_indices, full_q_num_blocks, full_q_indices, BLOCK_SIZE, mask_mod)[源][源]¶
BlockMask 是我們表示塊稀疏注意力掩碼的格式。它介於 BCSR 和非稀疏格式之間。
基礎知識¶
塊稀疏掩碼意味著,塊大小為 KV_BLOCK_SIZE x Q_BLOCK_SIZE 的塊只有在該塊內的每個元素都稀疏時才被視為稀疏,而不是表示掩碼中單個元素的稀疏性。這與硬體很好地契合,硬體通常期望執行連續載入和計算。
這種格式主要針對 1. 簡潔性和 2. 核心效率進行最佳化。值得注意的是,它*並非*針對大小進行最佳化,因為此掩碼總是按 KV_BLOCK_SIZE * Q_BLOCK_SIZE 的因子進行縮減。如果大小是個問題,可以透過增加塊大小來減小張量的大小。
我們格式的基本要素包括:
num_blocks_in_row: Tensor[ROWS]: 描述每行中存在的塊數量。
col_indices: Tensor[ROWS, MAX_BLOCKS_IN_COL]: col_indices[i] 是行 i 的塊位置序列。在該行中 col_indices[i][num_blocks_in_row[i]] 之後的值是未定義的。
例如,要從此格式重構原始張量
dense_mask = torch.zeros(ROWS, COLS) for row in range(ROWS): for block_idx in range(num_blocks_in_row[row]): dense_mask[row, col_indices[row, block_idx]] = 1
值得注意的是,此格式使得沿著掩碼的行進行歸約(reduction)更容易實現。
詳情¶
我們格式的基礎只需要 kv_num_blocks 和 kv_indices。但是,此物件上有多達 8 個張量。這代表 4 對:
1. (kv_num_blocks, kv_indices): 用於 attention 的前向傳播,因為我們沿著 KV 維度進行歸約。
2. [OPTIONAL] (full_kv_num_blocks, full_kv_indices): 這是可選的,純粹是為了最佳化。事實證明,對每個塊應用掩碼非常耗時!如果我們明確知道哪些塊是“完整的”並且完全不需要掩碼,那麼我們可以跳過對這些塊應用 mask_mod。這要求使用者將單獨的 mask_mod 從 score_mod 中剝離出來。對於因果掩碼,這大約能帶來 15% 的加速。
3. [GENERATED] (q_num_blocks, q_indices): 後向傳播需要,因為計算 dKV 需要沿著 Q 維度迭代掩碼。這些是從 1 自動生成的。
4. [GENERATED] (full_q_num_blocks, full_q_indices): 同上,但用於後向傳播。這些是從 2 自動生成的。
- as_tuple(flatten=True)[source][source]¶
返回 BlockMask 屬性的元組。
- 引數
flatten (bool) – 如果為 True,它將展平 (KV_BLOCK_SIZE, Q_BLOCK_SIZE) 元組
- classmethod from_kv_blocks(kv_num_blocks, kv_indices, full_kv_num_blocks=None, full_kv_indices=None, BLOCK_SIZE=128, mask_mod=None, seq_lengths=None)[source][source]¶
從鍵值塊資訊建立 BlockMask 例項。
- 引數
kv_num_blocks (Tensor) – 每個 Q_BLOCK_SIZE 行瓦片中的 kv_塊數量。
kv_indices (Tensor) – 每個 Q_BLOCK_SIZE 行瓦片中的鍵值塊索引。
full_kv_num_blocks (Optional[Tensor]) – 每個 Q_BLOCK_SIZE 行瓦片中的完整 kv_塊數量。
full_kv_indices (Optional[Tensor]) – 每個 Q_BLOCK_SIZE 行瓦片中的完整鍵值塊索引。
BLOCK_SIZE (Union[int, tuple[int, int]]) – KV_BLOCK_SIZE x Q_BLOCK_SIZE 瓦片的大小。
mask_mod (Optional[Callable]) – 用於修改掩碼的函式。
- 返回
透過 _transposed_ordered 生成具有完整 Q 資訊的例項
- 返回型別
- 引發
RuntimeError – 如果 kv_indices 的維度小於 2。
AssertionError – 如果僅提供了 full_kv_* 引數中的一個。
- property shape¶
- to(device)[source][source]¶
將 BlockMask 移動到指定裝置。
- 引數
device (torch.device or str) – 目標裝置,BlockMask 將被移到該裝置。可以是 torch.device 物件或字串(例如,'cpu','cuda:0')。
- 返回
一個新的 BlockMask 例項,其所有張量元件已移動到指定裝置。
- 返回型別
注意
此方法不會原地修改原始 BlockMask。相反,它返回一個新的 BlockMask 例項,其中各個張量屬性可能會或可能不會被移動到指定裝置,具體取決於它們當前的裝置位置。