跳轉到主要內容
部落格

FlexAttention:PyTorch 的靈活性與 FlashAttention 的效能

a cartoon chart flexing his muscles

理論上,“注意力”是你所需要的一切。然而在實踐中,我們還需要像 FlashAttention 這樣經過最佳化的注意力實現。

儘管這些融合的注意力實現大大提高了效能並支援了長上下文,但這種效率的提高也伴隨著靈活性的損失。你不能再透過編寫幾個 PyTorch 運算子來嘗試新的注意力變體了——你通常需要編寫一個新的自定義核心!這對於機器學習研究人員來說,就像是一種“軟體彩票”——如果你的注意力變體不適合現有的最佳化核心之一,你就註定會遭遇慢速執行時和 CUDA 記憶體溢位。

注意力變體的一些例子包括:因果注意力、相對位置編碼Alibi滑動視窗注意力PrefixLM文件掩碼/樣本打包/鋸齒形張量Tanh 軟限幅分頁注意力等。更糟糕的是,人們通常希望將這些變體組合起來!滑動視窗注意力 + 文件掩碼 + 因果注意力 + 上下文並行?或者分頁注意力 + 滑動視窗 + Tanh 軟限幅又如何?

下面左圖展示了當今世界的現狀——掩碼、偏置和設定的某些組合具有現有的核心實現。但各種選項導致了指數級的設定,因此總體而言,我們最終獲得了相當零星的支援。更糟糕的是,研究人員提出的新注意力變體將獲得*零*支援。

Attention variant support diagram

為了徹底解決這個超立方體問題,我們引入了**FlexAttention**,一個全新的 PyTorch API。

  1. 我們提供了一個靈活的 API,允許用幾行地道的 PyTorch 程式碼實現許多注意力變體(包括到目前為止部落格文章中提到的所有變體)。
  2. 我們透過 `torch.compile` 將其轉換為融合的 FlashAttention 核心,生成一個不例項化任何額外記憶體且效能與手寫核心相當的 FlashAttention 核心。
  3. 我們還利用 PyTorch 的自動求導機制自動生成反向傳播。
  4. 最後,我們還可以利用注意力掩碼中的稀疏性,從而相對於標準注意力實現獲得顯著改進。

有了 FlexAttention,我們希望嘗試新的注意力變體將只受限於你的想象力。

你可以在 Attention Gym 找到許多 FlexAttention 示例:https://github.com/pytorch-labs/attention-gym。如果你有任何很酷的應用,歡迎提交示例!

附註:我們還發現這個 API 非常令人興奮,因為它以一種有趣的方式利用了許多現有 PyTorch 基礎設施——更多內容將在最後介紹。

FlexAttention

這是經典的注意力方程

以程式碼形式

Q, K, V: Tensor[batch_size, num_heads, sequence_length, head_dim]
score: Tensor[batch_size, num_heads, sequence_length, sequence_length] = (Q @ K) / sqrt(head_dim)
probabilities = softmax(score, dim=-1)
output: Tensor[batch_size, num_heads, sequence_length, head_dim] = probabilities @ V

FlexAttention 允許使用者定義函式 `score_mod:`

math equation

以程式碼形式

Q, K, V: Tensor[batch_size, num_heads, sequence_length, head_dim]
score: Tensor[batch_size, num_heads, sequence_length, sequence_length] = (Q @ K) / sqrt(head_dim)
modified_scores: Tensor[batch_size, num_heads, sequence_length, sequence_length] = score_mod(score)
probabilities = softmax(modified_scores, dim=-1)
output: Tensor[batch_size, num_heads, sequence_length, head_dim] = probabilities @ V

這個函式允許你在 softmax 之前*修改*注意力分數。令人驚訝的是,這對於絕大多數注意力變體來說已經足夠了(示例如下)!

具體來說,`score_mod` 的預期簽名有些獨特。

def score_mod(score: f32[], b: i32[], h: i32[], q_idx: i32[], kv_idx: i32[])
    return score # noop - standard attention

換句話說,`score` 是一個標量 PyTorch 張量,表示查詢 token 和鍵 token 的點積。其餘引數告訴您當前正在計算的是*哪個*點積——`b`(批次中的當前元素)、`h`(當前頭)、`q_idx`(查詢中的位置)、`kv_idx`(鍵/值張量中的位置)。

要應用此函式,我們可以將其實現為

for b in range(batch_size):
    for h in range(num_heads):
        for q_idx in range(sequence_length):
            for kv_idx in range(sequence_length):
                modified_scores[b, h, q_idx, kv_idx] = score_mod(scores[b, h, q_idx, kv_idx], b, h, q_idx, kv_idx)

當然,這不是 FlexAttention 在底層實現的。透過利用 `torch.compile`,我們自動將您的函式轉換為一個*融合的* FlexAttention 核心——保證有效,否則退款!

這個 API 最終表現出驚人的表現力。讓我們看一些例子。

分數修改示例

完全注意力

首先,我們來實現“完全注意力”,即標準的雙向注意力。在這種情況下,`score_mod` 是一個無操作——它接收分數作為輸入,然後原樣返回。

def noop(score, b, h, q_idx, kv_idx):
    return score

並端到端使用它(包括前向*和*後向)

from torch.nn.attention.flex_attention import flex_attention

flex_attention(query, key, value, score_mod=noop).sum().backward()

相對位置編碼

一種常見的注意力變體是“相對位置編碼”。相對位置編碼不是在查詢和鍵中編碼絕對距離,而是根據查詢和鍵之間的“距離”調整分數。

def relative_positional(score, b, h, q_idx, kv_idx):
    return score + (q_idx - kv_idx)

請注意,與典型的實現不同,這*不需要*例項化 SxS 張量。相反,FlexAttention 在核心內部“即時”計算偏差值,從而顯著改善記憶體和效能。

ALiBi 偏差

alibi bias

來源:Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation

ALiBi 在 Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation 中被引入,並聲稱在推理時具有對長度外推有益的特性。值得注意的是,MosaicML 指出 “缺乏核心支援” 是他們最終從 ALiBi 轉向旋轉嵌入的主要原因。

Alibi 與相對位置編碼相似,只有一個例外——它有一個通常預先計算的每個頭因子。

alibi_bias = generate_alibi_bias() # [num_heads]

def alibi(score, b, h, q_idx, kv_idx):
    bias = alibi_bias[h] * (kv_idx - q_idx)
    return score + bias

這展示了 `torch.compile` 提供的一個有趣的靈活性——我們可以從 `alibi_bias` 載入,即使它*沒有作為輸入明確傳遞*!生成的 Triton 核心將計算從 `alibi_bias` 張量中進行的正確載入並進行融合。請注意,即使您重新生成 `alibi_bias`,我們也不需要重新編譯。

軟限幅

軟限幅是 Gemma2 和 Grok-1 中使用的一種技術,可防止 logits 過度增長。在 FlexAttention 中,它看起來像這樣

softcap = 20
def soft_cap(score, b, h, q_idx, kv_idx):
    score = score / softcap
    score = torch.tanh(score)
    score = score * softcap
    return score

請注意,我們還會自動從前向傳播生成反向傳播。此外,儘管此實現從語義上是正確的,但出於效能原因,在這種情況下我們可能希望使用 tanh 近似。有關更多詳細資訊,請參閱 attention-gym

因果掩碼

儘管雙向注意力是最簡單的,但最初的*《Attention is All You Need》*論文和絕大多數 LLM 都採用解碼器專用設定中的注意力,其中每個 token 只能關注其之前的 token。人們通常將其視為下三角掩碼,但使用 `score_mod` API,它可以表示為

def causal_mask(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

基本上,如果查詢 token 在鍵 token “之後”,我們就保留分數。否則,我們透過將其設定為 -inf 來將其遮蔽掉,從而確保它不會參與 softmax 計算。

然而,掩碼與其他修改相比是特殊的——如果某個東西被遮蔽掉了,我們就可以完全跳過它的計算!在這種情況下,因果掩碼大約有 50% 的稀疏性,因此不利用稀疏性會導致 2 倍的減速。儘管這個 `score_mod` 足以*正確地*實現因果掩碼,但要獲得稀疏性帶來的效能優勢需要另一個概念——`mask_mod`。

掩碼模組

為了利用掩碼的稀疏性,我們需要做更多的工作。具體來說,透過將 `mask_mod` 傳遞給 `create_block_mask`,我們可以建立一個 `BlockMask`。然後 FlexAttention 可以使用 `BlockMask` 來利用稀疏性!

mask_mod 的簽名與 score_mod 非常相似——只是沒有 score。特別是

# returns True if this position should participate in the computation
mask_mod(b, h, q_idx, kv_idx) => bool

請注意,`score_mod` 比 `mask_mod` 嚴格來說*更具*表現力。然而,對於掩碼,建議使用 `mask_mod` 和 `create_block_mask`,因為它效能更高。請參閱 FAQ 瞭解為什麼 `score_mod` 和 `mask_mod` 是分開的。

現在,讓我們看看如何使用 `mask_mod` 實現因果掩碼。

因果掩碼

from torch.nn.attention.flex_attention import create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# Because the sparsity pattern is independent of batch and heads, we'll set them to None (which broadcasts them) 
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=1024, KV_LEN=1024)
# In this case, we don't need a score_mod, so we won't pass any in.
# However, score_mod can still be combined with block_mask if you need the additional flexibility.
flex_attention(query, key, value, block_mask=block_mask)

請注意,`create_block_mask` 是一個**相對昂貴的操作!** 儘管 FlexAttention 在它改變時不需要重新編譯,但如果您不小心快取它,它可能會導致顯著的減速(請檢視 FAQ 以獲取最佳實踐建議)。

雖然 TFlops 大致相同,但 mask_mod 版本的執行時間快了 2 倍!這表明我們可以利用 BlockMask 提供的稀疏性,而*不*損失硬體效率。

滑動視窗 + 因果

Sliding Window Causal diagrams

來源:Mistral 7B

Mistral 推廣,滑動視窗注意力(也稱為區域性注意力)利用了最近的 token 最有用的直覺。特別是,它允許查詢 token 只關注例如最近的 1024 個 token。這通常與因果注意力一起使用。

SLIDING_WINDOW = 1024

def sliding_window_causal(b, h, q_idx, kv_idx):
    causal_mask = q_idx >= kv_idx
    window_mask = q_idx - kv_idx <= SLIDING_WINDOW 
    return causal_mask & window_mask

# If you want to be cute...
from torch.nn.attention import and_masks

def sliding_window(b, h, q_idx, kv_idx)
    return q_idx - kv_idx <= SLIDING_WINDOW

sliding_window_causal = and_masks(causal_mask, sliding_window)

我們將其與帶有滑動視窗掩碼的 `F.scaled_dot_product_attention` 以及帶有因果掩碼的 FA2(作為效能參考點)進行基準測試。我們不僅比 `F.scaled_dot_product_attention` 快得多,我們*也*比帶有因果掩碼的 FA2 快得多,因為這個掩碼具有顯著更多的稀疏性。

execution time charts

PrefixLM

PrefixLM diagram

來源:PaliGemma: A versatile 3B VLM for transfer

《探索統一文字到文字轉換器的遷移學習極限》中提出的 T5 架構描述了一種注意力變體,它在“字首”上執行完全雙向注意力,在其餘部分執行因果注意力。我們再次組合兩個掩碼函式來完成此操作,一個用於因果掩碼,另一個基於字首長度。

prefix_length: [B]
def prefix_mask(b, h, q_idx, kv_idx):
    return kv_idx <= prefix_length[b]

prefix_lm_causal = or_masks(prefix_mask, causal_mask)
# In this case, our mask is different per sequence so we set B equal to our batch size
block_mask = create_block_mask(prefix_lm_causal, B=B, H=None, S, S)

就像 `score_mod` 一樣,`mask_mod` 允許我們引用未明確作為函式輸入的額外張量!然而,對於 prefixLM,稀疏模式是*每個輸入*都會改變。這意味著對於每個新的輸入批次,我們都需要重新計算 `BlockMask`。一種常見的模式是在模型開頭呼叫 `create_block_mask` 並將該 `block_mask` 重用於模型中的所有注意力呼叫。請參閱*重新計算 BlockMasks 與重新編譯*。

然而,作為回報,我們不僅能夠為 prefixLM 提供高效的注意力核心,我們還能夠利用輸入中存在的任意稀疏性!FlexAttention 將根據 BlockMask 資料動態調整其效能,而*無需*重新編譯核心。

文件掩碼/鋸齒序列

另一種常見的注意力變體是文件掩碼/鋸齒序列。想象一下您有許多長度不等的序列。您想將它們一起訓練,但不幸的是,大多數運算子只接受矩形張量。

透過 `BlockMask`,我們也可以在 FlexAttention 中高效地支援它!

  1. 首先,我們將所有序列展平為單個序列,包含 sum(序列長度) 個 token。
  2. 然後,我們計算每個 token 所屬的 document_id。
  3. 最後,在我們的 `mask_mod` 中,我們簡單地判斷查詢和鍵值 token 是否屬於同一個文件!
# The document that each token belongs to.
# e.g. [0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 2] corresponds to sequence lengths 3, 2, and 6.
document_id: [SEQ_LEN]

def document_masking(b, h, q_idx, kv_idx):
    return document_id[q_idx] == document_id[kv_idx]

就這樣!在這種情況下,我們看到我們最終得到一個塊對角掩碼。

blockdiagonal mask

文件掩碼的一個有趣之處在於,很容易看出它如何與任意其他掩碼組合。例如,我們已經在上一節中定義了 `prefixlm_mask`。現在我們還需要定義一個 `prefixlm_document_mask` 函式嗎?

在這些情況下,我們發現一種非常有用的模式是我們稱之為“更高級別修改”的方法。在這種情況下,我們可以採用現有的 `mask_mod` 並自動將其轉換為一個適用於鋸齒序列的 `mask_mod`!

def generate_doc_mask_mod(mask_mod, document_id):
    # Get unique document IDs and their counts
    _, counts = torch.unique_consecutive(document_id, return_counts=True)
    # Create cumulative counts (offsets)
    offsets = torch.cat([torch.tensor([0], device=document_id.device), counts.cumsum(0)[:-1]])
    def doc_mask_wrapper(b, h, q_idx, kv_idx):
        same_doc = document_id[q_idx] == document_id[kv_idx]
        q_logical = q_idx - offsets[document_id[q_idx]]
        kv_logical = kv_idx - offsets[document_id[kv_idx]]
        inner_mask = mask_mod(b, h, q_logical, kv_logical)
        return same_doc & inner_mask
    return doc_mask_wrapper

例如,給定上面的 `prefix_lm_causal` 掩碼,我們可以將其轉換為一個適用於打包文件的掩碼,如下所示

prefix_length = torch.tensor(2, dtype=torch.int32, device="cuda")
def prefix_mask(b, h, q_idx, kv_idx):
    return kv_idx < prefix_length
prefix_lm_causal = or_masks(prefix_mask, causal_mask)
doc_prefix_lm_causal_mask = generate_doc_mask_mod(prefix_lm_causal, document_id)

現在,這個掩碼是“塊-字首LM-對角線”形狀的。🙂

這就是我們所有的示例!注意力變體遠不止我們能列出的,所以請檢視Attention Gym以獲取更多示例。我們希望社群也能貢獻他們最喜歡的 FlexAttention 應用。

常見問題

問:FlexAttention 何時需要重新編譯?

由於 FlexAttention 利用 `torch.compile` 進行圖捕獲,它實際上可以在廣泛的場景中避免重新編譯。值得注意的是,即使捕獲的張量值發生變化,它也*不需要*重新編譯!

flex_attention = torch.compile(flex_attention)
def create_bias_mod(bias)
    def bias_mod(score, b, h, q_idx, kv_idx):
        return score + bias
    return bias_mod
bias_mod1 = create_bias_mod(torch.tensor(0))
flex_attention(..., score_mod=bias_mod1) # Compiles the kernel here 

bias_mod2 = create_bias_mod(torch.tensor(2))
flex_attention(..., score_mod=bias_mod2) # Doesn't need to recompile! 

即使更改塊稀疏性也不需要重新編譯。但是,如果塊稀疏性發生變化,我們確實需要*重新計算* BlockMask。

問:我們應該何時重新計算 BlockMask?

每當塊稀疏性改變時,我們都需要重新計算 BlockMask。儘管計算 BlockMask 比重新編譯便宜得多(大約幾百微秒而不是幾秒),但您仍應注意不要過度重新計算 BlockMask。

以下是一些常見模式以及關於如何處理它們的建議。

掩碼從不改變(例如因果掩碼)
在這種情況下,您可以簡單地預先計算塊掩碼並全域性快取它,將其重用於所有注意力呼叫。

block_mask = create_block_mask(causal_mask, 1, 1, S,S)
causal_attention = functools.partial(flex_attention, block_mask=block_mask)

掩碼每批次都會變化(例如文件掩碼)
在這種情況下,我們建議在模型開始時計算 BlockMask,並將其貫穿模型——將 BlockMask 重用於所有層。

def forward(self, x, doc_mask):
    # Compute block mask at beginning of forwards
    block_mask = create_block_mask(doc_mask, None, None, S, S)    
    x = self.layer1(x, block_mask)
    x = self.layer2(x, block_mask)
    ...
    # amortize block mask construction cost across all layers
    x = self.layer3(x, block_mask) 
    return x

掩碼每層變化(例如資料相關的稀疏性)
這是最困難的設定,因為我們無法在多次 FlexAttention 呼叫中攤銷塊掩碼計算。儘管 FlexAttention 肯定仍然可以在這種情況下受益,但 BlockMask 的實際好處取決於您的注意力掩碼的稀疏程度以及我們構建 BlockMask 的速度。這引出了…

問:我們如何更快地計算 BlockMask?

`create_block_mask` 不幸的是在記憶體和計算方面都相當昂貴,因為確定一個塊是否完全稀疏需要評估該塊中每個點的 `mask_mod`。有幾種方法可以解決這個問題

  1. 如果您的掩碼在批處理大小或頭部之間相同,請確保您在這些維度上進行廣播(即在 `create_block_mask` 中將它們設定為 `None`)。
  2. 編譯 `create_block_mask`。不幸的是,目前 `torch.compile` 由於一些不幸的限制,無法直接作用於 `create_block_mask`。但是,您可以設定 `_compile=True`,這將顯著降低峰值記憶體和執行時(在我們的測試中通常是一個數量級)。
  3. 為 BlockMask 編寫自定義建構函式。BlockMask 的元資料非常簡單(請參閱文件)。它本質上是兩個張量。a. `num_blocks`:為每個查詢塊計算的 KV 塊數。
    b. `indices`:為每個查詢塊計算的 KV 塊的位置。例如,這是 `causal_mask` 的自定義 BlockMask 建構函式。
def create_causal_mask(S):
    BLOCK_SIZE = 128
    # The first query block computes one block, the second query block computes 2 blocks, etc.
    num_blocks = torch.arange(S // BLOCK_SIZE, device="cuda") + 1
    # Since we're always computing from the left to the right,
    # we can use the indices [0, 1, 2, ...] for every query block.
    indices = torch.arange(S // BLOCK_SIZE, device="cuda").expand(
        S // BLOCK_SIZE, S // BLOCK_SIZE
    )
    num_blocks = num_blocks[None, None, :]
    indices = indices[None, None, :]
    return BlockMask(num_blocks, indices, BLOCK_SIZE=BLOCK_SIZE, mask_mod=causal_mask)
問:為什麼 `score_mod` 和 `mask_mod` 不同?難道 `mask_mod` 不僅僅是 `score_mod` 的一個特例嗎?

非常敏銳的問題,假想的觀眾!事實上,任何 `mask_mod` 都可以很容易地轉換為 `score_mod`(我們不建議在實踐中使用此函式!)

def mask_mod_as_score_mod(b, h, q_idx, kv_idx):
    return torch.where(mask_mod(b, h, q_idx, kv_idx), score, -float("inf"))

那麼,如果 `score_mod` 可以實現 `mask_mod` 可以實現的所有功能,那麼 `mask_mod` 的意義何在?

一個直接的挑戰是:`score_mod` 需要實際的 `score` 值作為輸入,但在預計算 BlockMask 時,我們沒有實際的 `score` 值。我們或許可以透過傳入全零來偽造這些值,如果 `score_mod` 返回 `-inf`,那麼我們認為它被掩蓋了(事實上,我們最初就是這樣做的!)。

然而,存在兩個問題。首先,這是不嚴謹的——如果使用者的 `score_mod` 在輸入為 0 時返回 `-inf` 怎麼辦?或者如果使用者的 `score_mod` 用一個大的負值而不是 `-inf` 進行掩碼處理怎麼辦?我們似乎在試圖把圓釘子敲進方孔裡。然而,將 `mask_mod` 與 `score_mod` 分開還有一個更重要的原因——它從根本上更高效!

事實證明,對每個計算出的元素應用掩碼實際上非常昂貴——我們的基準測試顯示效能下降約 15-20%!因此,儘管我們可以透過跳過一半計算來顯著加速,但由於需要掩蓋每個元素,我們損失了相當一部分加速!

幸運的是,如果我們視覺化因果掩碼,我們會注意到絕大多數塊根本不需要“因果掩碼”——它們是完全計算的!只有對角線上的塊,部分計算和部分掩碼的塊,才需要應用掩碼。

blockdiagonal mask

BlockMask 之前告訴我們哪些塊需要計算,哪些塊可以跳過。現在,我們進一步增強了這個資料結構,它還可以告訴我們哪些塊是“完全計算的”(即可以跳過掩碼)與“部分計算的”(即需要應用掩碼)。但請注意,儘管在“完全計算的”塊上可以跳過掩碼,但其他 `score_mod`(如相對位置嵌入)仍然需要應用。

給定一個 `score_mod`,我們無法可靠地判斷它的哪些部分是“掩碼”。因此,使用者必須將這些部分自己分成 `mask_mod`。

問:BlockMask 需要多少額外記憶體?

BlockMask 元資料的大小為 `[BATCH_SIZE, NUM_HEADS, QUERY_LEN//BLOCK_SIZE, KV_LEN//BLOCK_SIZE]`。如果掩碼在批次或頭維度上相同,則可以在該維度上進行廣播以節省記憶體。

在預設的 `BLOCK_SIZE` 為 128 時,我們預計大多數用例的記憶體使用將可以忽略不計。例如,對於 100 萬的序列長度,BlockMask 只會使用 60MB 的額外記憶體。如果這是一個問題,您可以增加塊大小:`create_block_mask(..., BLOCK_SIZE=1024)`。例如,將 `BLOCK_SIZE` 增加到 1024 將導致此元資料下降到一兆位元組以下。

問:數值比較如何?

儘管結果並非位相同,但我們相信 FlexAttention 在數值上與 FlashAttention 一樣準確。我們在一系列因果和非因果注意力變體的輸入上比較了 FlashAttention 和 FlexAttention 的差異,生成了以下差異分佈。誤差幾乎相同。

distribution chart

效能

通常來說,FlexAttention 的效能幾乎與手寫的 Triton 核心相當,這並不奇怪,因為我們大量利用了手寫的 Triton 核心。然而,由於其通用性,我們確實會產生一小部分效能損失。例如,我們必須承擔一些額外的延遲來確定接下來要計算哪個塊。在某些情況下,我們提供了一些核心選項,它們可以在改變其行為的同時影響核心的效能。它們可以在這裡找到:效能旋鈕

作為案例研究,讓我們探討這些旋鈕如何影響因果注意力的效能。我們將比較 A100 上 Triton 核心與 FlashAttentionv2 的效能。指令碼可以在此處找到。

FlexAttention 在前向傳播中實現了 FlashAttention2 90% 的效能,在反向傳播中實現了 85% 的效能。FlexAttention 目前正在使用一種確定性演算法,它比 FAv2 重新計算更多的中間結果,但我們計劃改進 FlexAttention 的反向傳播演算法,並希望彌合這一差距!

flexattention speed chart
flexattention speed chart

結論

我們希望您在使用 FlexAttention 時能像我們開發它時一樣快樂!在開發過程中,我們發現了比預期更多的此 API 應用。我們已經看到它將 torchtune 的樣本打包吞吐量提高了 71%,消除了研究人員需要花費一週多時間編寫自己的自定義 Triton 核心的需求,並提供了與自定義手寫注意力變體具有競爭力的效能。

最後,讓 FlexAttention 的實現變得非常有趣的一點是,我們能夠以一種有趣的方式利用許多現有的 PyTorch 基礎設施。例如,TorchDynamo(torch.compile 的前端)的一個獨特之處在於,它*不*要求編譯函式中使用的張量作為輸入顯式傳入。這使我們能夠編譯文件掩碼等模組,這些模組需要訪問*全域性*變數,並且全域性變數需要更改!

bias = torch.randn(1024, 1024)
def score_mod(score, b, h, q_idx, kv_idx):
    return score + bias[q_idx][kv_idx] # The bias tensor can change!

此外,`torch.compile` 是一種通用圖捕獲機制,這一事實也使其能夠支援更“高階”的轉換,例如將任何 `mask_mod` 轉換為適用於鋸齒張量的高階轉換。

我們還利用 TorchInductor(torch.compile 的後端)基礎設施來實現 Triton 模板。這不僅使得支援程式碼生成 FlexAttention 變得容易,而且還自動為我們提供了對動態形狀和 epilogue 融合(即在注意力末尾融合一個運算子)的支援!將來,我們計劃擴充套件此支援,以允許量化版本的注意力或像 RadixAttention 這樣的功能。

此外,我們還利用了高階操作、PyTorch 的自動求導來自動生成反向傳播,以及 vmap 來自動應用 `score_mod` 以建立 BlockMask。

當然,沒有 Triton 和 TorchInductor 生成 Triton 程式碼的能力,這個專案是不可能實現的。

我們期待在未來將我們在這裡使用的方法應用於更多的應用程式!

侷限性和未來工作

  • FlexAttention 目前在 PyTorch 每夜釋出版中可用,我們計劃在 2.5.0 中將其作為原型功能釋出
  • 我們在這裡沒有介紹如何將 FlexAttention 用於推理(或如何實現 PagedAttention)——我們將在以後的文章中介紹。
  • 我們正在努力提高 FlexAttention 的效能,使其與 H100 GPU 上的 FlashAttention3 匹配。
  • FlexAttention 要求所有序列長度都是 128 的倍數——這個問題很快就會解決。
  • 我們計劃很快新增 GQA 支援——目前,您只需複製 kv 頭即可。

致謝

我們要強調一些啟發 FlexAttention 的前期工作(和人物)。

  • Tri Dao 在 FlashAttention 方面的工作
  • Francisco Massa 和 Xformers 團隊在 Triton 中實現的 BlockSparseAttention
  • Jax 團隊在 SplashAttention 方面的工作
  • Philippe Tillet 和 Keren Zhou 協助我們處理 Triton
  • Ali Hassani 關於鄰域注意力的討論
  • 所有抱怨注意力核心不支援他們最喜歡的注意力變體的人 🙂