• 教程 >
  • 透過使用 Nested Tensors 和 torch.compile() 加速 PyTorch Transformer
快捷方式

透過使用 Nested Tensors 和 torch.compile() 加速 PyTorch Transformer

作者: Mikayla Gawarecki

你將學到什麼
  • 瞭解 PyTorch 提供的用於構建自定義 Transformer 層的底層構建塊(nested tensors, scaled_dot_product_attention, torch.compile()FlexAttention

  • 瞭解上述技術如何透過以 MultiHeadAttention 為例來改進記憶體使用和效能

  • 探索使用上述構建塊進行高階定製

前提條件
  • PyTorch v.2.6.0 或更高版本

在過去幾年裡,PyTorch 團隊開發了各種底層功能,這些功能組合在一起可以建立各種 Transformer 變體。這些功能包括

  • 使用 torch.jagged 佈局的 Nested Tensors(也稱為 NJTs)

  • scaled_dot_product_attention

  • torch.compile()

  • FlexAttention

本教程將簡要概述上述技術,並演示如何組合它們以獲得靈活且高效能的 Transformer 層,同時改善使用者體驗。

你可能會注意到 torch.nn 模組目前提供了各種與 Transformer 相關的層。特別是,它包括 TransformerEncoderLayer, TransformerEncoder, TransformerDecoderLayer, TransformerDecoder, TransformerMultiheadAttention。這系列層最初是按照 Attention is All You Need 論文實現的。本教程中討論的元件在現有 nn 層之上提供了更好的使用者體驗、靈活性和效能。

本教程適合我嗎?

如果你想了解 torch 庫為編寫自己的 Transformer 層提供了哪些構建塊以及最佳實踐,那麼你來對地方了。請繼續閱讀!

如果你正在尋找一個流行的 Transformer 架構的開箱即用實現,請注意有許多開源庫提供了它們,包括

如果你只對高效能的注意力分數修改感興趣,請檢視 FlexAttention 部落格,其中包含一個 mask 的 gym

介紹構建塊

首先,我們將簡要介紹引言中提到的四項技術

Nested tensors 泛化了常規稠密張量的形狀,允許使用相同的張量使用者體驗表示大小不規則的資料。在 Transformer 的上下文中,我們可以將 nested tensors 視為一種表示可變序列長度的工具。它們消除了顯式 padding 和 masking(想想 nn.MultiHeadAttention 中的 key_padding_mask)這種易出錯實踐的必要性。

scaled_dot_product_attention 是一個用於計算 \(\text{softmax}(\frac{QK^T}{\sqrt{E}} + B)V\) 的原語,它可以分派到該操作的融合實現或回退實現。它在 eager 模式(即 PyTorch 的預設模式,操作會即時執行)下開箱即用,並且與 torch.compile() 無縫整合。截至 2.6 版本,它還將原生提供分組查詢注意力。

torch.compile() 是一個在 2.0 版本中引入的編譯器,能夠捕獲 PyTorch 程式碼圖並對其執行各種最佳化,例如融合一系列操作。使用 torch.jagged 佈局的 Nested tensors 和 scaled_dot_product_attention 可以與 compile 無縫協作。在 Transformer 的上下文中,將 compile 與 nested tensor 和 SDPA 結合使用的好處是 compile 可以消除 eager 模式下的框架開銷,並將 Transformer 中的一系列操作(例如 projection 和 activation)融合在一起。

FlexAttention 是一個原語,允許使用者在 softmax 操作之前修改注意力分數。它泛化了上述 scaled_dot_product_attention 的加性 B 項,允許進行任意計算。它需要 compile 才能獲得良好效能。

上述構建塊是“你所需要的一切”(截至 2024 年 10 月)

本節的主要前提是,大多數 Transformer 變體都是 GPT 風格的,由 Embedding、Positional Encoding、Attention Blocks 和 Feed Forward networks 等層組成。如果我們試圖對此領域的差異進行分類,可能會得出以下幾點

  1. 層型別(啟用函式,如 SwiGLU 等,歸一化函式,如 RMSNorm 等,位置編碼,如 Sinusoidal, Rotary 等)

  2. 層順序,例如在哪應用歸一化和位置編碼。

  3. 注意力分數修改,例如 ALiBi, Relative Positional Bias 等等。

在非編譯環境(pre-compiler environment)中,你可能編寫一個自定義 Transformer,注意到它可以正常工作但速度很慢。為了解決這個問題,你可能需要為特定的操作序列開發一個自定義的融合核心。在編譯環境(compiler environment)中,你只需執行第一步,然後進行編譯即可從改進的效能中受益。

MultiheadAttention

請記住,MultiheadAttention 接受 query、key 和 value 作為輸入,並由一個輸入 projection、一個 scaled_dot_product_attention 運算子和一個輸出 projection 組成。這裡我們想要展示的主要亮點是,用 nested tensors 替換 padded/masked 輸入所帶來的改進。改進有三個方面

  • 使用者體驗 請記住,nn.MultiheadAttention 要求 querykeyvalue 是稠密的 torch.Tensors。它還提供了一個 key_padding_mask,用於遮蔽由於批處理中不同序列長度而產生的 key 中的 padding token。由於 nn.MHA 中沒有 query_padding_mask,使用者必須小心地對輸出進行 mask/slice 以考慮 query 序列長度。NestedTensor 清晰地消除了這種易出錯的 padding mask 的需求。

  • 記憶體 Nested tensors 允許你清晰地表示一批不同序列長度的資料,而無需例項化一個帶有 [B, S] padding mask(其中 B 是批大小,S 是批處理中的最大序列長度,D 是嵌入大小)的稠密 [B, S, D] 張量。因此,輸入和中間啟用將使用更少的記憶體。

  • 效能 由於未例項化 padding 且跳過了對 padding 的不必要計算,效能和記憶體使用都得到了改善。

我們將透過在 Nested Tensor 教程 中的 MultiheadAttention 層基礎上進行構建,並將其與 nn.MultiheadAttention 層進行比較,來演示上述優勢。

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    """
    Computes multi-head attention. Supports nested or padded tensors.

    Args:
        E_q (int): Size of embedding dim for query
        E_k (int): Size of embedding dim for key
        E_v (int): Size of embedding dim for value
        E_total (int): Total embedding dim of combined heads post input projection. Each head
            has dim E_total // nheads
        nheads (int): Number of heads
        dropout (float, optional): Dropout probability. Default: 0.0
        bias (bool, optional): Whether to add bias to input projection. Default: True
    """

    def __init__(
        self,
        E_q: int,
        E_k: int,
        E_v: int,
        E_total: int,
        nheads: int,
        dropout: float = 0.0,
        bias=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.nheads = nheads
        self.dropout = dropout
        self._qkv_same_embed_dim = E_q == E_k and E_q == E_v
        if self._qkv_same_embed_dim:
            self.packed_proj = nn.Linear(E_q, E_total * 3, bias=bias, **factory_kwargs)
        else:
            self.q_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
            self.k_proj = nn.Linear(E_k, E_total, bias=bias, **factory_kwargs)
            self.v_proj = nn.Linear(E_v, E_total, bias=bias, **factory_kwargs)
        E_out = E_q
        self.out_proj = nn.Linear(E_total, E_out, bias=bias, **factory_kwargs)
        assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
        self.E_head = E_total // nheads
        self.bias = bias

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask=None,
        is_causal=False,
    ) -> torch.Tensor:
        """
        Forward pass; runs the following process:
            1. Apply input projection
            2. Split heads and prepare for SDPA
            3. Run SDPA
            4. Apply output projection

        Args:
            query (torch.Tensor): query of shape (``N``, ``L_q``, ``E_qk``)
            key (torch.Tensor): key of shape (``N``, ``L_kv``, ``E_qk``)
            value (torch.Tensor): value of shape (``N``, ``L_kv``, ``E_v``)
            attn_mask (torch.Tensor, optional): attention mask of shape (``N``, ``L_q``, ``L_kv``) to pass to SDPA. Default: None
            is_causal (bool, optional): Whether to apply causal mask. Default: False

        Returns:
            attn_output (torch.Tensor): output of shape (N, L_t, E_q)
        """
        # Step 1. Apply input projection
        if self._qkv_same_embed_dim:
            if query is key and key is value:
                result = self.packed_proj(query)
                query, key, value = torch.chunk(result, 3, dim=-1)
            else:
                q_weight, k_weight, v_weight = torch.chunk(
                    self.packed_proj.weight, 3, dim=0
                )
                if self.bias:
                    q_bias, k_bias, v_bias = torch.chunk(
                        self.packed_proj.bias, 3, dim=0
                    )
                else:
                    q_bias, k_bias, v_bias = None, None, None
                query, key, value = (
                    F.linear(query, q_weight, q_bias),
                    F.linear(key, k_weight, k_bias),
                    F.linear(value, v_weight, v_bias),
                )

        else:
            query = self.q_proj(query)
            key = self.k_proj(key)
            value = self.v_proj(value)

        # Step 2. Split heads and prepare for SDPA
        # reshape query, key, value to separate by head
        # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head)
        query = query.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
        key = key.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
        value = value.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)

        # Step 3. Run SDPA
        # (N, nheads, L_t, E_head)
        attn_output = F.scaled_dot_product_attention(
            query, key, value, dropout_p=self.dropout, is_causal=is_causal
        )
        # (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)
        attn_output = attn_output.transpose(1, 2).flatten(-2)

        # Step 4. Apply output projection
        # (N, L_t, E_total) -> (N, L_t, E_out)
        attn_output = self.out_proj(attn_output)

        return attn_output

實用工具

在本節中,我們包含了一個實用工具,用於使用 Zipf 分佈生成半真實資料以獲取句子長度。這用於生成巢狀的 query、key 和 value 張量。我們還包含了一個基準測試實用工具。

import numpy as np


def zipf_sentence_lengths(alpha: float, batch_size: int) -> torch.Tensor:
    # generate fake corpus by unigram Zipf distribution
    # from wikitext-2 corpus, we get rank "." = 3, "!" = 386, "?" = 858
    sentence_lengths = np.empty(batch_size, dtype=int)
    for ibatch in range(batch_size):
        sentence_lengths[ibatch] = 1
        word = np.random.zipf(alpha)
        while word != 3 and word != 386 and word != 858:
            sentence_lengths[ibatch] += 1
            word = np.random.zipf(alpha)
    return torch.tensor(sentence_lengths)


# Generate a batch of semi-realistic data using Zipf distribution for sentence lengths
# in the form of nested tensors with the jagged layout.
def gen_batch(N, E_q, E_k, E_v, device, dtype=torch.float32, query_seq_len_1=False):
    # generate semi-realistic data using Zipf distribution for sentence lengths
    sentence_lengths = zipf_sentence_lengths(alpha=1.2, batch_size=N)

    # Note: the torch.jagged layout is a nested tensor layout that supports a single ragged
    # dimension and works with torch.compile. The batch items each have shape (B, S*, D)
    # where B = batch size, S* = ragged sequence length, and D = embedding dimension.
    if query_seq_len_1:
        query = torch.nested.nested_tensor(
            [torch.randn(1, E_q, dtype=dtype, device=device) for l in sentence_lengths],
            layout=torch.jagged,
        )
    else:
        query = torch.nested.nested_tensor(
            [
                torch.randn(l.item(), E_q, dtype=dtype, device=device)
                for l in sentence_lengths
            ],
            layout=torch.jagged,
        )

    key = torch.nested.nested_tensor(
        [
            torch.randn(s.item(), E_k, dtype=dtype, device=device)
            for s in sentence_lengths
        ],
        layout=torch.jagged,
    )

    value = torch.nested.nested_tensor(
        [
            torch.randn(s.item(), E_v, dtype=dtype, device=device)
            for s in sentence_lengths
        ],
        layout=torch.jagged,
    )

    return query, key, value, sentence_lengths


import math
import timeit


def benchmark(func, *args, **kwargs):
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    begin = timeit.default_timer()
    output = func(*args, **kwargs)
    torch.cuda.synchronize()
    end = timeit.default_timer()
    return output, (end - begin), torch.cuda.max_memory_allocated()

現在我們將演示在 MultiheadAttention 層中使用 nested tensors + compile 進行自注意力計算時的效能改進。我們將其與傳統的 nn.MultiheadAttention + compile(帶 padding 和 masking)進行比較。

N, E_q, E_k, E_v, E_total = 512, 512, 512, 512, 512
E_out = E_q
d_model = E_q
nheads = 8
dropout = 0.0
bias = True
device = "cuda"
torch.manual_seed(6)
query, key, value, sentence_lengths = gen_batch(N, E_q, E_k, E_v, device)
S = sentence_lengths.max().item()
print(
    f"Total sequence length in nested query {sentence_lengths.sum().item()}, max sequence length {S}"
)
padded_query, padded_key, padded_value = (
    t.to_padded_tensor(0.0) for t in (query, key, value)
)

torch.manual_seed(6)
mha_layer = MultiHeadAttention(
    E_q, E_k, E_v, E_total, nheads, dropout=dropout, bias=bias, device="cuda"
)
torch.manual_seed(6)
vanilla_mha_layer = nn.MultiheadAttention(
    E_q, nheads, dropout=dropout, batch_first=True, bias=bias, device="cuda"
)

# ``nn.MultiheadAttention`` uses a non conventional initialization for layers, so do this for exact parity :(
mha_layer.out_proj.weight = nn.Parameter(
    vanilla_mha_layer.out_proj.weight.clone().detach()
)
mha_layer.packed_proj.weight = nn.Parameter(
    vanilla_mha_layer.in_proj_weight.clone().detach()
)
mha_layer.out_proj.bias = nn.Parameter(vanilla_mha_layer.out_proj.bias.clone().detach())
mha_layer.packed_proj.bias = nn.Parameter(
    vanilla_mha_layer.in_proj_bias.clone().detach()
)

new_mha_layer = torch.compile(mha_layer)
# warmup compile
nested_result_warmup = new_mha_layer(query, query, query, is_causal=True)

# benchmark
nested_result, nested_time, nested_peak_memory = benchmark(
    new_mha_layer, query, query, query, is_causal=True
)
padded_nested_result = nested_result.to_padded_tensor(0.0)

# For the vanilla ``nn.MultiheadAttention``, we need to construct the ``key_padding_mask``
# Further, ``nn.MultiheadAttention`` forces one to materialize the ``attn_mask`` even if using ``is_causal``
src_key_padding_mask = torch.where(padded_query == 0.0, -math.inf, 0)[:, :, 0]
attn_mask = torch.empty((N, S, S), device=device).fill_(float("-inf"))
for i, s in enumerate(sentence_lengths):
    attn_mask[i, :s, :s] = nn.Transformer.generate_square_subsequent_mask(s)
attn_mask = attn_mask.unsqueeze(1).expand(N, nheads, S, S).reshape(N * nheads, S, S)

vanilla_mha_layer = torch.compile(vanilla_mha_layer)
# warmup compile
warmup_vanilla_result = vanilla_mha_layer(
    padded_query,
    padded_query,
    padded_query,
    attn_mask=attn_mask,
    key_padding_mask=src_key_padding_mask,
    need_weights=False,
    is_causal=True,
)

# benchmark
(padded_result, _), padded_time, padded_peak_memory = benchmark(
    vanilla_mha_layer,
    padded_query,
    padded_query,
    padded_query,
    key_padding_mask=src_key_padding_mask,
    need_weights=False,
    attn_mask=attn_mask,
    is_causal=True,
)

print(f"{padded_time=:.5f}, padded_peak_memory={padded_peak_memory/1e9:.2f} GB")
print(f"{nested_time=:.5f}, nested_peak_memory={nested_peak_memory/1e9:.2f} GB")
print(
    "Max difference between vanilla and nested result",
    (padded_result - padded_nested_result).abs().max().item(),
)
print(f"Nested speedup: {(padded_time/nested_time):.2f}")
print(
    f"Nested peak memory reduction {((padded_peak_memory - nested_peak_memory)/1e9):.2f} GB"
)
Total sequence length in nested query 10436, max sequence length 128
padded_time=0.01608, padded_peak_memory=3.87 GB
nested_time=0.00254, nested_peak_memory=0.92 GB
Max difference between vanilla and nested result 0.0
Nested speedup: 6.33
Nested peak memory reduction 2.96 GB

作為參考,以下是在 A100 上的樣本輸出

padded_time=0.03454, padded_peak_memory=4.14 GB
nested_time=0.00612, nested_peak_memory=0.76 GB
Max difference between vanilla and nested result 0.0
Nested speedup: 5.65
Nested peak memory reduction 3.39 GB

我們也可以看到反向傳播的相同情況

for i, entry_length in enumerate(sentence_lengths):
    # padding-specific step: remove output projection bias from padded entries for fair comparison
    padded_result[i, entry_length:, :] = 0.0

_, padded_bw_time, padded_bw_peak_mem = benchmark(
    lambda: padded_result.sum().backward()
)
_, nested_bw_time, nested_bw_peak_mem = benchmark(
    lambda: padded_nested_result.sum().backward()
)

print(f"{padded_bw_time=:.5f}, padded_bw_peak_mem={padded_bw_peak_mem/1e9:.2f} GB")
print(f"{nested_bw_time=:.5f}, nested_bw_peak_mem={nested_bw_peak_mem/1e9:.2f} GB")
print(f"Nested backward speedup: {(padded_bw_time/nested_bw_time):.2f}")
print(
    f"Nested backward peak memory reduction {((padded_bw_peak_mem - nested_bw_peak_mem)/1e9):.2f} GB"
)

print(
    "Difference in out_proj.weight.grad",
    (mha_layer.out_proj.weight.grad - vanilla_mha_layer.out_proj.weight.grad)
    .abs()
    .max()
    .item(),
)
print(
    "Difference in packed_proj.weight.grad",
    (mha_layer.packed_proj.weight.grad - vanilla_mha_layer.in_proj_weight.grad)
    .abs()
    .max()
    .item(),
)
print(
    "Difference in out_proj.bias.grad",
    (mha_layer.out_proj.bias.grad - vanilla_mha_layer.out_proj.bias.grad)
    .abs()
    .max()
    .item(),
)
print(
    "Difference in packed_proj.bias.grad",
    (mha_layer.packed_proj.bias.grad - vanilla_mha_layer.in_proj_bias.grad)
    .abs()
    .max()
    .item(),
)
padded_bw_time=1.62963, padded_bw_peak_mem=4.68 GB
nested_bw_time=0.06652, nested_bw_peak_mem=3.04 GB
Nested backward speedup: 24.50
Nested backward peak memory reduction 1.64 GB
Difference in out_proj.weight.grad 0.000396728515625
Difference in packed_proj.weight.grad 0.00146484375
Difference in out_proj.bias.grad 0.0
Difference in packed_proj.bias.grad 0.0029296875

A100 上的樣本輸出

padded_bw_time=2.09337, padded_bw_peak_mem=5.10 GB
nested_bw_time=0.01452, nested_bw_peak_mem=3.24 GB
Nested backward speedup: 144.13
Nested backward peak memory reduction 1.86 GB
Difference in out_proj.weight.grad 0.000244140625
Difference in packed_proj.weight.grad 0.001556396484375
Difference in out_proj.bias.grad 0.0
Difference in packed_proj.bias.grad 0.001953125

GPT 風格的層

基本的 GPT 風格 Transformer 層包括一個因果自注意力層,後接一個帶有 skip connections 的前饋網路 (FFN)。使用上面的 MultiheadAttention 層實現這一點相當簡單,並且與使用 is_causal=Truenn.TransformerEncoderLayer 的結果等效。

為簡潔起見,本教程省略了實現其他 nn 層的示例,你可以在此處找到它們。

更進一步

到目前為止,我們演示瞭如何實現遵循傳統 nn.MultiheadAttention 的高效能 MultiheadAttention 層。回到我們對 Transformer 架構修改的分類,請記住我們將修改分為層型別、層順序和注意力分數修改。我們相信改變層型別和層順序(例如將 LayerNorm 替換為 RMSNorm)是相當簡單的。

在本節中,我們將討論使用上述構建塊的各種功能,包括以下內容

  • 交叉注意力

  • 完全遮蔽的行不再導致 NaNs

  • 修改注意力分數:使用 FlexAttention 和 NJT 的 ALiBi

  • Packed Projection

交叉注意力

交叉注意力是一種注意力形式,其中 query 和 key/value 張量來自不同的序列。

一個例子是在 nn.TransformerDecoderLayer 中,其中 query 來自 decoder,而 key/value 來自 encoder。

上述 MultiheadAttention 層使用 nested tensors 對 query 和 key/value 都能很好地推廣到這種情況。

query, _, _, q_len = gen_batch(N, E_q, E_k, E_v, device)
_, key, value, kv_len = gen_batch(N, E_q, E_k, E_v, device)

print(
    f"Total sequence length in nested query {q_len.sum().item()}, max sequence length {q_len.max().item()}"
)
print(
    f"Total sequence length in nested key/value {kv_len.sum().item()}, max sequence length {kv_len.max().item()}"
)
out = new_mha_layer(query, key, value, is_causal=False)
Total sequence length in nested query 10617, max sequence length 165
Total sequence length in nested key/value 10176, max sequence length 137

如上所述,我們可以將其與 vanilla 編譯的 nn.MultiheadAttention 進行比較。

torch.manual_seed(6)
query, _, _, q_len = gen_batch(N, E_q, E_k, E_v, device)
_, key, value, kv_len = gen_batch(N, E_q, E_k, E_v, device)
padded_query, padded_key, padded_value = (
    t.to_padded_tensor(0.0) for t in (query, key, value)
)

key_padding_mask = torch.where(padded_key == 0.0, -math.inf, 0)[:, :, 0]

# warmup compile
warmup_nested_result = new_mha_layer(query, key, value, is_causal=False)
warmup_vanilla_result = vanilla_mha_layer(
    padded_query,
    padded_key,
    padded_value,
    key_padding_mask=key_padding_mask,
    need_weights=False,
    is_causal=False,
)

nested_result, nested_time, nested_peak_memory = benchmark(
    new_mha_layer, query, key, value, is_causal=False
)
(padded_result, _), padded_time, padded_peak_memory = benchmark(
    vanilla_mha_layer,
    padded_query,
    padded_key,
    padded_value,
    key_padding_mask=key_padding_mask,
    need_weights=False,
    is_causal=False,
)
padded_nested_result = nested_result.to_padded_tensor(0.0)
for i, entry_length in enumerate(q_len):
    # padding-specific step: remove output projection bias from padded entries for fair comparison
    padded_result[i, entry_length:, :] = 0.0

print(
    "Max difference between vanilla and nested result",
    (padded_result - padded_nested_result).abs().max().item(),
)
print(f"Nested speedup: {(padded_time/nested_time):.2f}")
print(
    f"Nested peak memory reduction {((padded_peak_memory - nested_peak_memory)/1e9):.2f} GB"
)
Max difference between vanilla and nested result 0.0
Nested speedup: 4.98
Nested peak memory reduction 1.20 GB

A100 上的樣本輸出

Max difference between vanilla and nested result 0.0
Nested speedup: 4.01
Nested peak memory reduction 1.40 GB

完全遮蔽的行不再導致 NaNs

長期以來,nn.MultiheadAttentionscaled_dot_product_attention 存在一個問題,即如果一行被完全遮蔽,注意力層的輸出將是 NaN。參見該 issue。這是因為空集上的 softmax 是未定義的。

感謝此 PR,這種情況不再發生。相反,scaled_dot_product_attention 中對應於完全遮蔽行的輸出將為 0。對於 nn.MHA 不使用“fast-path”的情況,這也將適用。

強烈建議使用帶有 NJTs 的自定義 MHA 層,而不是現有 nn.MultiheadAttention 中的“fast-path”,因為 NJT 正確建模不規則性的能力使得能夠正確表達空序列。

FlexAttention + NJT

NJT 也可以與 FlexAttention 模組組合。這是對 MultiheadAttention 層的泛化,允許對注意力分數進行任意修改。下面的例子採用 ALiBi 的實現 alibi_mod,來自 attention gym,並將其與 nested 輸入張量一起使用。

from torch.nn.attention.flex_attention import flex_attention


def generate_alibi_bias(H: int):
    """Returns an alibi bias score_mod given the number of heads H
    Args:
        H: number of heads
    Returns:
        alibi_bias: alibi bias score_mod
    """

    def alibi_mod(score, b, h, q_idx, kv_idx):
        scale = torch.exp2(-((h + 1) * 8.0 / H))
        bias = (q_idx - kv_idx) * scale
        return score + bias

    return alibi_mod


query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
n_heads, D = 8, E_q // 8
alibi_score_mod = generate_alibi_bias(n_heads)
query = query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
value = value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
out_flex2 = flex_attention(query, key, value, score_mod=alibi_score_mod)

此外,也可以透過 create_nested_block_mask 函式將 FlexAttentionblock_mask 實用工具與 NJTs 一起使用。這對於利用 mask 的稀疏性加速注意力計算很有用。特別是,該函式會為 NJT 中所有可變長度序列合併到一起的“堆疊序列”建立一個稀疏的塊 mask,同時正確遮蔽序列間的注意力。在下面的例子中,我們展示瞭如何使用此實用工具建立一個因果塊 mask。

from torch.nn.attention.flex_attention import create_nested_block_mask


def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
block_mask = create_nested_block_mask(causal_mask, 1, 1, query, _compile=True)
query = query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
value = value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
out_flex = flex_attention(query, key, value, block_mask=block_mask)

Packed Projection

Packed projection 是一種技術,利用了當 projection 的輸入(矩陣乘法)相同時(自注意力)的特點,可以將 projection 權重和偏差打包到單個張量中。當單個 projection 受記憶體限制而非計算限制時,它特別有用。這裡我們將演示兩個示例

  • MultiheadAttention 的輸入 projection

  • Transformer 層前饋網路中的 SwiGLU activation

MultiheadAttention 的輸入 projection

在進行自注意力時,querykeyvalue 是同一個張量。這些張量中的每一個都透過一個 Linear(E_q, E_total) 層進行 projection。我們可以將這打包到一個層中,這正是我們在上面的 MultiheadAttention 層中所做的。

讓我們比較 packed projection 與常規方法的效能

class InputProjection(nn.Module):
    def __init__(self, E_q, E_total, bias=False, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.q_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
        self.k_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
        self.v_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)

    def forward(self, x):
        return self.q_proj(x), self.k_proj(x), self.v_proj(x)


class PackedInputProjection(nn.Module):
    def __init__(self, E_q, E_total, bias=False, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.packed_proj = nn.Linear(E_q, E_total * 3, bias=bias, **factory_kwargs)

    def forward(self, query):
        return torch.chunk(self.packed_proj(query), 3, dim=-1)


B, D, dtype = 256, 8192, torch.bfloat16

torch.set_float32_matmul_precision("high")
in_proj = torch.compile(InputProjection(D, D, device="cuda", dtype=torch.bfloat16))
packed_in_proj = torch.compile(
    PackedInputProjection(D, D, device="cuda", dtype=torch.bfloat16)
)

q, _, _, sequence_lengths = gen_batch(B, D, D, D, device="cuda", dtype=torch.bfloat16)

# warmup
in_proj(q)
packed_in_proj(q)

# benchmark
(q_out, k_out, v_out), time, _ = benchmark(in_proj, q)
(q_out, k_out, v_out), time_packed, _ = benchmark(packed_in_proj, q)
# On my A100 prints 1.05x speedup
print(
    f"InputProjection: {time:5f} s, PackedInputProjection: {time_packed:5f} s, speedup: {time/time_packed:.2f}x"
)
InputProjection: 0.034046 s, PackedInputProjection: 0.032757 s, speedup: 1.04x

Transformer 層的前饋網路中的 SwiGLU

Swish-Gated Linear Unit (SwiGLU) 是一種非線性啟用函式,在 Transformer 層的前饋網路中越來越受歡迎(例如 Llama)。帶有 SwiGLU 啟用的前饋網路定義如下

class SwiGLUFFN(nn.Module):
    def __init__(
        self,
        dim,
        hidden_dim,
        multiple_of,
        ffn_dim_multiplier=None,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False, **factory_kwargs)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False, **factory_kwargs)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False, **factory_kwargs)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

使用 packed projection 的另一種實現方法是

class PackedSwiGLUFFN(nn.Module):
    def __init__(
        self,
        dim,
        hidden_dim,
        multiple_of,
        ffn_dim_multiplier=None,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w13 = nn.Linear(dim, 2 * hidden_dim, bias=False, **factory_kwargs)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False, **factory_kwargs)

    def forward(self, x):
        x1, x3 = torch.chunk(self.w13(x), 2, dim=-1)
        return self.w2(F.silu(x1) * x3)

我們可以如下比較這兩種實現的效能。根據你的硬體,結果可能會有所不同。在 A100 上,我看到 D=128 時有 1.12 倍的加速。

D = 128

swigluffn = torch.compile(SwiGLUFFN(D, D * 4, 256, device="cuda", dtype=torch.bfloat16))
packed_swigluffn = torch.compile(
    PackedSwiGLUFFN(D, D * 4, 256, device="cuda", dtype=torch.bfloat16)
)

q, _, _, sentence_lengths = gen_batch(D, D, D, D, device="cuda", dtype=torch.bfloat16)

# warmup
swigluffn(q)
packed_swigluffn(q)

# benchmark
_, time, _ = benchmark(swigluffn, q)
_, time_packed, _ = benchmark(packed_swigluffn, q)
# On my A100 prints 1.08x speedup
print(
    f"SwiGLUFFN: {time} s, PackedSwiGLUFFN: {time_packed} s, speedup: {time/time_packed:.2f}x"
)
SwiGLUFFN: 0.0010205730000052426 s, PackedSwiGLUFFN: 0.0010395129997959884 s, speedup: 0.98x

擴充套件示例

我們計劃更新本教程,以演示更多如何使用各種高效能構建塊(如 KV-Caching、Grouped Query Attention 等)的示例。此外,還有一些很好的例子,展示瞭如何使用各種高效能構建塊來實現不同的 Transformer 架構。一些示例包括

結論

在本教程中,我們介紹了 PyTorch 提供的用於編寫 Transformer 層的底層構建塊,並演示瞭如何組合它們的示例。我們希望本教程能讓讀者瞭解 PyTorch 使用者可以多麼輕鬆地實現靈活且高效能的 Transformer 層。

指令碼總執行時間: ( 1 分 6.445 秒)

由 Sphinx-Gallery 生成的圖集

文件

訪問 PyTorch 的完整開發者文件

檢視文件

教程

獲取針對初學者和高階開發人員的深度教程

檢視教程

資源

查詢開發資源並獲取問題解答

檢視資源