注意
點選 此處 下載完整的示例程式碼
透過使用 Nested Tensors 和 torch.compile() 加速 PyTorch Transformer¶
瞭解 PyTorch 提供的用於構建自定義 Transformer 層的底層構建塊(nested tensors,
scaled_dot_product_attention,torch.compile()和FlexAttention)瞭解上述技術如何透過以 MultiHeadAttention 為例來改進記憶體使用和效能
探索使用上述構建塊進行高階定製
PyTorch v.2.6.0 或更高版本
在過去幾年裡,PyTorch 團隊開發了各種底層功能,這些功能組合在一起可以建立各種 Transformer 變體。這些功能包括
使用
torch.jagged佈局的 Nested Tensors(也稱為 NJTs)scaled_dot_product_attentiontorch.compile()FlexAttention
本教程將簡要概述上述技術,並演示如何組合它們以獲得靈活且高效能的 Transformer 層,同時改善使用者體驗。
你可能會注意到 torch.nn 模組目前提供了各種與 Transformer 相關的層。特別是,它包括 TransformerEncoderLayer, TransformerEncoder, TransformerDecoderLayer, TransformerDecoder, Transformer 和 MultiheadAttention。這系列層最初是按照 Attention is All You Need 論文實現的。本教程中討論的元件在現有 nn 層之上提供了更好的使用者體驗、靈活性和效能。
本教程適合我嗎?¶
如果你想了解 torch 庫為編寫自己的 Transformer 層提供了哪些構建塊以及最佳實踐,那麼你來對地方了。請繼續閱讀!
如果你正在尋找一個流行的 Transformer 架構的開箱即用實現,請注意有許多開源庫提供了它們,包括
如果你只對高效能的注意力分數修改感興趣,請檢視 FlexAttention 部落格,其中包含一個 mask 的 gym。
介紹構建塊¶
首先,我們將簡要介紹引言中提到的四項技術
Nested tensors 泛化了常規稠密張量的形狀,允許使用相同的張量使用者體驗表示大小不規則的資料。在 Transformer 的上下文中,我們可以將 nested tensors 視為一種表示可變序列長度的工具。它們消除了顯式 padding 和 masking(想想 nn.MultiHeadAttention 中的 key_padding_mask)這種易出錯實踐的必要性。
scaled_dot_product_attention 是一個用於計算 \(\text{softmax}(\frac{QK^T}{\sqrt{E}} + B)V\) 的原語,它可以分派到該操作的融合實現或回退實現。它在 eager 模式(即 PyTorch 的預設模式,操作會即時執行)下開箱即用,並且與 torch.compile() 無縫整合。截至 2.6 版本,它還將原生提供分組查詢注意力。
torch.compile() 是一個在 2.0 版本中引入的編譯器,能夠捕獲 PyTorch 程式碼圖並對其執行各種最佳化,例如融合一系列操作。使用 torch.jagged 佈局的 Nested tensors 和 scaled_dot_product_attention 可以與 compile 無縫協作。在 Transformer 的上下文中,將 compile 與 nested tensor 和 SDPA 結合使用的好處是 compile 可以消除 eager 模式下的框架開銷,並將 Transformer 中的一系列操作(例如 projection 和 activation)融合在一起。
FlexAttention 是一個原語,允許使用者在 softmax 操作之前修改注意力分數。它泛化了上述 scaled_dot_product_attention 的加性 B 項,允許進行任意計算。它需要 compile 才能獲得良好效能。
上述構建塊是“你所需要的一切”(截至 2024 年 10 月)¶
本節的主要前提是,大多數 Transformer 變體都是 GPT 風格的,由 Embedding、Positional Encoding、Attention Blocks 和 Feed Forward networks 等層組成。如果我們試圖對此領域的差異進行分類,可能會得出以下幾點
層型別(啟用函式,如
SwiGLU等,歸一化函式,如RMSNorm等,位置編碼,如 Sinusoidal, Rotary 等)層順序,例如在哪應用歸一化和位置編碼。
注意力分數修改,例如
ALiBi, Relative Positional Bias 等等。
在非編譯環境(pre-compiler environment)中,你可能編寫一個自定義 Transformer,注意到它可以正常工作但速度很慢。為了解決這個問題,你可能需要為特定的操作序列開發一個自定義的融合核心。在編譯環境(compiler environment)中,你只需執行第一步,然後進行編譯即可從改進的效能中受益。
MultiheadAttention¶
請記住,MultiheadAttention 接受 query、key 和 value 作為輸入,並由一個輸入 projection、一個 scaled_dot_product_attention 運算子和一個輸出 projection 組成。這裡我們想要展示的主要亮點是,用 nested tensors 替換 padded/masked 輸入所帶來的改進。改進有三個方面
使用者體驗 請記住,
nn.MultiheadAttention要求query、key和value是稠密的torch.Tensors。它還提供了一個key_padding_mask,用於遮蔽由於批處理中不同序列長度而產生的key中的 padding token。由於nn.MHA中沒有query_padding_mask,使用者必須小心地對輸出進行 mask/slice 以考慮 query 序列長度。NestedTensor清晰地消除了這種易出錯的 padding mask 的需求。記憶體 Nested tensors 允許你清晰地表示一批不同序列長度的資料,而無需例項化一個帶有
[B, S]padding mask(其中B是批大小,S是批處理中的最大序列長度,D是嵌入大小)的稠密[B, S, D]張量。因此,輸入和中間啟用將使用更少的記憶體。效能 由於未例項化 padding 且跳過了對 padding 的不必要計算,效能和記憶體使用都得到了改善。
我們將透過在 Nested Tensor 教程 中的 MultiheadAttention 層基礎上進行構建,並將其與 nn.MultiheadAttention 層進行比較,來演示上述優勢。
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
"""
Computes multi-head attention. Supports nested or padded tensors.
Args:
E_q (int): Size of embedding dim for query
E_k (int): Size of embedding dim for key
E_v (int): Size of embedding dim for value
E_total (int): Total embedding dim of combined heads post input projection. Each head
has dim E_total // nheads
nheads (int): Number of heads
dropout (float, optional): Dropout probability. Default: 0.0
bias (bool, optional): Whether to add bias to input projection. Default: True
"""
def __init__(
self,
E_q: int,
E_k: int,
E_v: int,
E_total: int,
nheads: int,
dropout: float = 0.0,
bias=True,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.nheads = nheads
self.dropout = dropout
self._qkv_same_embed_dim = E_q == E_k and E_q == E_v
if self._qkv_same_embed_dim:
self.packed_proj = nn.Linear(E_q, E_total * 3, bias=bias, **factory_kwargs)
else:
self.q_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
self.k_proj = nn.Linear(E_k, E_total, bias=bias, **factory_kwargs)
self.v_proj = nn.Linear(E_v, E_total, bias=bias, **factory_kwargs)
E_out = E_q
self.out_proj = nn.Linear(E_total, E_out, bias=bias, **factory_kwargs)
assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
self.E_head = E_total // nheads
self.bias = bias
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask=None,
is_causal=False,
) -> torch.Tensor:
"""
Forward pass; runs the following process:
1. Apply input projection
2. Split heads and prepare for SDPA
3. Run SDPA
4. Apply output projection
Args:
query (torch.Tensor): query of shape (``N``, ``L_q``, ``E_qk``)
key (torch.Tensor): key of shape (``N``, ``L_kv``, ``E_qk``)
value (torch.Tensor): value of shape (``N``, ``L_kv``, ``E_v``)
attn_mask (torch.Tensor, optional): attention mask of shape (``N``, ``L_q``, ``L_kv``) to pass to SDPA. Default: None
is_causal (bool, optional): Whether to apply causal mask. Default: False
Returns:
attn_output (torch.Tensor): output of shape (N, L_t, E_q)
"""
# Step 1. Apply input projection
if self._qkv_same_embed_dim:
if query is key and key is value:
result = self.packed_proj(query)
query, key, value = torch.chunk(result, 3, dim=-1)
else:
q_weight, k_weight, v_weight = torch.chunk(
self.packed_proj.weight, 3, dim=0
)
if self.bias:
q_bias, k_bias, v_bias = torch.chunk(
self.packed_proj.bias, 3, dim=0
)
else:
q_bias, k_bias, v_bias = None, None, None
query, key, value = (
F.linear(query, q_weight, q_bias),
F.linear(key, k_weight, k_bias),
F.linear(value, v_weight, v_bias),
)
else:
query = self.q_proj(query)
key = self.k_proj(key)
value = self.v_proj(value)
# Step 2. Split heads and prepare for SDPA
# reshape query, key, value to separate by head
# (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head)
query = query.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
# (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
key = key.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
# (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
value = value.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
# Step 3. Run SDPA
# (N, nheads, L_t, E_head)
attn_output = F.scaled_dot_product_attention(
query, key, value, dropout_p=self.dropout, is_causal=is_causal
)
# (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)
attn_output = attn_output.transpose(1, 2).flatten(-2)
# Step 4. Apply output projection
# (N, L_t, E_total) -> (N, L_t, E_out)
attn_output = self.out_proj(attn_output)
return attn_output
實用工具¶
在本節中,我們包含了一個實用工具,用於使用 Zipf 分佈生成半真實資料以獲取句子長度。這用於生成巢狀的 query、key 和 value 張量。我們還包含了一個基準測試實用工具。
import numpy as np
def zipf_sentence_lengths(alpha: float, batch_size: int) -> torch.Tensor:
# generate fake corpus by unigram Zipf distribution
# from wikitext-2 corpus, we get rank "." = 3, "!" = 386, "?" = 858
sentence_lengths = np.empty(batch_size, dtype=int)
for ibatch in range(batch_size):
sentence_lengths[ibatch] = 1
word = np.random.zipf(alpha)
while word != 3 and word != 386 and word != 858:
sentence_lengths[ibatch] += 1
word = np.random.zipf(alpha)
return torch.tensor(sentence_lengths)
# Generate a batch of semi-realistic data using Zipf distribution for sentence lengths
# in the form of nested tensors with the jagged layout.
def gen_batch(N, E_q, E_k, E_v, device, dtype=torch.float32, query_seq_len_1=False):
# generate semi-realistic data using Zipf distribution for sentence lengths
sentence_lengths = zipf_sentence_lengths(alpha=1.2, batch_size=N)
# Note: the torch.jagged layout is a nested tensor layout that supports a single ragged
# dimension and works with torch.compile. The batch items each have shape (B, S*, D)
# where B = batch size, S* = ragged sequence length, and D = embedding dimension.
if query_seq_len_1:
query = torch.nested.nested_tensor(
[torch.randn(1, E_q, dtype=dtype, device=device) for l in sentence_lengths],
layout=torch.jagged,
)
else:
query = torch.nested.nested_tensor(
[
torch.randn(l.item(), E_q, dtype=dtype, device=device)
for l in sentence_lengths
],
layout=torch.jagged,
)
key = torch.nested.nested_tensor(
[
torch.randn(s.item(), E_k, dtype=dtype, device=device)
for s in sentence_lengths
],
layout=torch.jagged,
)
value = torch.nested.nested_tensor(
[
torch.randn(s.item(), E_v, dtype=dtype, device=device)
for s in sentence_lengths
],
layout=torch.jagged,
)
return query, key, value, sentence_lengths
import math
import timeit
def benchmark(func, *args, **kwargs):
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
begin = timeit.default_timer()
output = func(*args, **kwargs)
torch.cuda.synchronize()
end = timeit.default_timer()
return output, (end - begin), torch.cuda.max_memory_allocated()
現在我們將演示在 MultiheadAttention 層中使用 nested tensors + compile 進行自注意力計算時的效能改進。我們將其與傳統的 nn.MultiheadAttention + compile(帶 padding 和 masking)進行比較。
N, E_q, E_k, E_v, E_total = 512, 512, 512, 512, 512
E_out = E_q
d_model = E_q
nheads = 8
dropout = 0.0
bias = True
device = "cuda"
torch.manual_seed(6)
query, key, value, sentence_lengths = gen_batch(N, E_q, E_k, E_v, device)
S = sentence_lengths.max().item()
print(
f"Total sequence length in nested query {sentence_lengths.sum().item()}, max sequence length {S}"
)
padded_query, padded_key, padded_value = (
t.to_padded_tensor(0.0) for t in (query, key, value)
)
torch.manual_seed(6)
mha_layer = MultiHeadAttention(
E_q, E_k, E_v, E_total, nheads, dropout=dropout, bias=bias, device="cuda"
)
torch.manual_seed(6)
vanilla_mha_layer = nn.MultiheadAttention(
E_q, nheads, dropout=dropout, batch_first=True, bias=bias, device="cuda"
)
# ``nn.MultiheadAttention`` uses a non conventional initialization for layers, so do this for exact parity :(
mha_layer.out_proj.weight = nn.Parameter(
vanilla_mha_layer.out_proj.weight.clone().detach()
)
mha_layer.packed_proj.weight = nn.Parameter(
vanilla_mha_layer.in_proj_weight.clone().detach()
)
mha_layer.out_proj.bias = nn.Parameter(vanilla_mha_layer.out_proj.bias.clone().detach())
mha_layer.packed_proj.bias = nn.Parameter(
vanilla_mha_layer.in_proj_bias.clone().detach()
)
new_mha_layer = torch.compile(mha_layer)
# warmup compile
nested_result_warmup = new_mha_layer(query, query, query, is_causal=True)
# benchmark
nested_result, nested_time, nested_peak_memory = benchmark(
new_mha_layer, query, query, query, is_causal=True
)
padded_nested_result = nested_result.to_padded_tensor(0.0)
# For the vanilla ``nn.MultiheadAttention``, we need to construct the ``key_padding_mask``
# Further, ``nn.MultiheadAttention`` forces one to materialize the ``attn_mask`` even if using ``is_causal``
src_key_padding_mask = torch.where(padded_query == 0.0, -math.inf, 0)[:, :, 0]
attn_mask = torch.empty((N, S, S), device=device).fill_(float("-inf"))
for i, s in enumerate(sentence_lengths):
attn_mask[i, :s, :s] = nn.Transformer.generate_square_subsequent_mask(s)
attn_mask = attn_mask.unsqueeze(1).expand(N, nheads, S, S).reshape(N * nheads, S, S)
vanilla_mha_layer = torch.compile(vanilla_mha_layer)
# warmup compile
warmup_vanilla_result = vanilla_mha_layer(
padded_query,
padded_query,
padded_query,
attn_mask=attn_mask,
key_padding_mask=src_key_padding_mask,
need_weights=False,
is_causal=True,
)
# benchmark
(padded_result, _), padded_time, padded_peak_memory = benchmark(
vanilla_mha_layer,
padded_query,
padded_query,
padded_query,
key_padding_mask=src_key_padding_mask,
need_weights=False,
attn_mask=attn_mask,
is_causal=True,
)
print(f"{padded_time=:.5f}, padded_peak_memory={padded_peak_memory/1e9:.2f} GB")
print(f"{nested_time=:.5f}, nested_peak_memory={nested_peak_memory/1e9:.2f} GB")
print(
"Max difference between vanilla and nested result",
(padded_result - padded_nested_result).abs().max().item(),
)
print(f"Nested speedup: {(padded_time/nested_time):.2f}")
print(
f"Nested peak memory reduction {((padded_peak_memory - nested_peak_memory)/1e9):.2f} GB"
)
Total sequence length in nested query 10436, max sequence length 128
padded_time=0.01608, padded_peak_memory=3.87 GB
nested_time=0.00254, nested_peak_memory=0.92 GB
Max difference between vanilla and nested result 0.0
Nested speedup: 6.33
Nested peak memory reduction 2.96 GB
作為參考,以下是在 A100 上的樣本輸出
padded_time=0.03454, padded_peak_memory=4.14 GB
nested_time=0.00612, nested_peak_memory=0.76 GB
Max difference between vanilla and nested result 0.0
Nested speedup: 5.65
Nested peak memory reduction 3.39 GB
我們也可以看到反向傳播的相同情況
for i, entry_length in enumerate(sentence_lengths):
# padding-specific step: remove output projection bias from padded entries for fair comparison
padded_result[i, entry_length:, :] = 0.0
_, padded_bw_time, padded_bw_peak_mem = benchmark(
lambda: padded_result.sum().backward()
)
_, nested_bw_time, nested_bw_peak_mem = benchmark(
lambda: padded_nested_result.sum().backward()
)
print(f"{padded_bw_time=:.5f}, padded_bw_peak_mem={padded_bw_peak_mem/1e9:.2f} GB")
print(f"{nested_bw_time=:.5f}, nested_bw_peak_mem={nested_bw_peak_mem/1e9:.2f} GB")
print(f"Nested backward speedup: {(padded_bw_time/nested_bw_time):.2f}")
print(
f"Nested backward peak memory reduction {((padded_bw_peak_mem - nested_bw_peak_mem)/1e9):.2f} GB"
)
print(
"Difference in out_proj.weight.grad",
(mha_layer.out_proj.weight.grad - vanilla_mha_layer.out_proj.weight.grad)
.abs()
.max()
.item(),
)
print(
"Difference in packed_proj.weight.grad",
(mha_layer.packed_proj.weight.grad - vanilla_mha_layer.in_proj_weight.grad)
.abs()
.max()
.item(),
)
print(
"Difference in out_proj.bias.grad",
(mha_layer.out_proj.bias.grad - vanilla_mha_layer.out_proj.bias.grad)
.abs()
.max()
.item(),
)
print(
"Difference in packed_proj.bias.grad",
(mha_layer.packed_proj.bias.grad - vanilla_mha_layer.in_proj_bias.grad)
.abs()
.max()
.item(),
)
padded_bw_time=1.62963, padded_bw_peak_mem=4.68 GB
nested_bw_time=0.06652, nested_bw_peak_mem=3.04 GB
Nested backward speedup: 24.50
Nested backward peak memory reduction 1.64 GB
Difference in out_proj.weight.grad 0.000396728515625
Difference in packed_proj.weight.grad 0.00146484375
Difference in out_proj.bias.grad 0.0
Difference in packed_proj.bias.grad 0.0029296875
A100 上的樣本輸出
padded_bw_time=2.09337, padded_bw_peak_mem=5.10 GB
nested_bw_time=0.01452, nested_bw_peak_mem=3.24 GB
Nested backward speedup: 144.13
Nested backward peak memory reduction 1.86 GB
Difference in out_proj.weight.grad 0.000244140625
Difference in packed_proj.weight.grad 0.001556396484375
Difference in out_proj.bias.grad 0.0
Difference in packed_proj.bias.grad 0.001953125
GPT 風格的層¶
基本的 GPT 風格 Transformer 層包括一個因果自注意力層,後接一個帶有 skip connections 的前饋網路 (FFN)。使用上面的 MultiheadAttention 層實現這一點相當簡單,並且與使用 is_causal=True 的 nn.TransformerEncoderLayer 的結果等效。
為簡潔起見,本教程省略了實現其他 nn 層的示例,你可以在此處找到它們。
更進一步¶
到目前為止,我們演示瞭如何實現遵循傳統 nn.MultiheadAttention 的高效能 MultiheadAttention 層。回到我們對 Transformer 架構修改的分類,請記住我們將修改分為層型別、層順序和注意力分數修改。我們相信改變層型別和層順序(例如將 LayerNorm 替換為 RMSNorm)是相當簡單的。
在本節中,我們將討論使用上述構建塊的各種功能,包括以下內容
交叉注意力
完全遮蔽的行不再導致 NaNs
修改注意力分數:使用 FlexAttention 和 NJT 的 ALiBi
Packed Projection
交叉注意力¶
交叉注意力是一種注意力形式,其中 query 和 key/value 張量來自不同的序列。
一個例子是在 nn.TransformerDecoderLayer 中,其中 query 來自 decoder,而 key/value 來自 encoder。
上述 MultiheadAttention 層使用 nested tensors 對 query 和 key/value 都能很好地推廣到這種情況。
query, _, _, q_len = gen_batch(N, E_q, E_k, E_v, device)
_, key, value, kv_len = gen_batch(N, E_q, E_k, E_v, device)
print(
f"Total sequence length in nested query {q_len.sum().item()}, max sequence length {q_len.max().item()}"
)
print(
f"Total sequence length in nested key/value {kv_len.sum().item()}, max sequence length {kv_len.max().item()}"
)
out = new_mha_layer(query, key, value, is_causal=False)
Total sequence length in nested query 10617, max sequence length 165
Total sequence length in nested key/value 10176, max sequence length 137
如上所述,我們可以將其與 vanilla 編譯的 nn.MultiheadAttention 進行比較。
torch.manual_seed(6)
query, _, _, q_len = gen_batch(N, E_q, E_k, E_v, device)
_, key, value, kv_len = gen_batch(N, E_q, E_k, E_v, device)
padded_query, padded_key, padded_value = (
t.to_padded_tensor(0.0) for t in (query, key, value)
)
key_padding_mask = torch.where(padded_key == 0.0, -math.inf, 0)[:, :, 0]
# warmup compile
warmup_nested_result = new_mha_layer(query, key, value, is_causal=False)
warmup_vanilla_result = vanilla_mha_layer(
padded_query,
padded_key,
padded_value,
key_padding_mask=key_padding_mask,
need_weights=False,
is_causal=False,
)
nested_result, nested_time, nested_peak_memory = benchmark(
new_mha_layer, query, key, value, is_causal=False
)
(padded_result, _), padded_time, padded_peak_memory = benchmark(
vanilla_mha_layer,
padded_query,
padded_key,
padded_value,
key_padding_mask=key_padding_mask,
need_weights=False,
is_causal=False,
)
padded_nested_result = nested_result.to_padded_tensor(0.0)
for i, entry_length in enumerate(q_len):
# padding-specific step: remove output projection bias from padded entries for fair comparison
padded_result[i, entry_length:, :] = 0.0
print(
"Max difference between vanilla and nested result",
(padded_result - padded_nested_result).abs().max().item(),
)
print(f"Nested speedup: {(padded_time/nested_time):.2f}")
print(
f"Nested peak memory reduction {((padded_peak_memory - nested_peak_memory)/1e9):.2f} GB"
)
Max difference between vanilla and nested result 0.0
Nested speedup: 4.98
Nested peak memory reduction 1.20 GB
A100 上的樣本輸出
Max difference between vanilla and nested result 0.0
Nested speedup: 4.01
Nested peak memory reduction 1.40 GB
完全遮蔽的行不再導致 NaNs¶
長期以來,nn.MultiheadAttention 和 scaled_dot_product_attention 存在一個問題,即如果一行被完全遮蔽,注意力層的輸出將是 NaN。參見該 issue。這是因為空集上的 softmax 是未定義的。
感謝此 PR,這種情況不再發生。相反,scaled_dot_product_attention 中對應於完全遮蔽行的輸出將為 0。對於 nn.MHA 不使用“fast-path”的情況,這也將適用。
強烈建議使用帶有 NJTs 的自定義 MHA 層,而不是現有 nn.MultiheadAttention 中的“fast-path”,因為 NJT 正確建模不規則性的能力使得能夠正確表達空序列。
FlexAttention + NJT¶
NJT 也可以與 FlexAttention 模組組合。這是對 MultiheadAttention 層的泛化,允許對注意力分數進行任意修改。下面的例子採用 ALiBi 的實現 alibi_mod,來自 attention gym,並將其與 nested 輸入張量一起使用。
from torch.nn.attention.flex_attention import flex_attention
def generate_alibi_bias(H: int):
"""Returns an alibi bias score_mod given the number of heads H
Args:
H: number of heads
Returns:
alibi_bias: alibi bias score_mod
"""
def alibi_mod(score, b, h, q_idx, kv_idx):
scale = torch.exp2(-((h + 1) * 8.0 / H))
bias = (q_idx - kv_idx) * scale
return score + bias
return alibi_mod
query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
n_heads, D = 8, E_q // 8
alibi_score_mod = generate_alibi_bias(n_heads)
query = query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
value = value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
out_flex2 = flex_attention(query, key, value, score_mod=alibi_score_mod)
此外,也可以透過 create_nested_block_mask 函式將 FlexAttention 的 block_mask 實用工具與 NJTs 一起使用。這對於利用 mask 的稀疏性加速注意力計算很有用。特別是,該函式會為 NJT 中所有可變長度序列合併到一起的“堆疊序列”建立一個稀疏的塊 mask,同時正確遮蔽序列間的注意力。在下面的例子中,我們展示瞭如何使用此實用工具建立一個因果塊 mask。
from torch.nn.attention.flex_attention import create_nested_block_mask
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
block_mask = create_nested_block_mask(causal_mask, 1, 1, query, _compile=True)
query = query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
value = value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
out_flex = flex_attention(query, key, value, block_mask=block_mask)
Packed Projection¶
Packed projection 是一種技術,利用了當 projection 的輸入(矩陣乘法)相同時(自注意力)的特點,可以將 projection 權重和偏差打包到單個張量中。當單個 projection 受記憶體限制而非計算限制時,它特別有用。這裡我們將演示兩個示例
MultiheadAttention 的輸入 projection
Transformer 層前饋網路中的 SwiGLU activation
MultiheadAttention 的輸入 projection¶
在進行自注意力時,query、key 和 value 是同一個張量。這些張量中的每一個都透過一個 Linear(E_q, E_total) 層進行 projection。我們可以將這打包到一個層中,這正是我們在上面的 MultiheadAttention 層中所做的。
讓我們比較 packed projection 與常規方法的效能
class InputProjection(nn.Module):
def __init__(self, E_q, E_total, bias=False, device=None, dtype=None):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.q_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
self.k_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
self.v_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
def forward(self, x):
return self.q_proj(x), self.k_proj(x), self.v_proj(x)
class PackedInputProjection(nn.Module):
def __init__(self, E_q, E_total, bias=False, device=None, dtype=None):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.packed_proj = nn.Linear(E_q, E_total * 3, bias=bias, **factory_kwargs)
def forward(self, query):
return torch.chunk(self.packed_proj(query), 3, dim=-1)
B, D, dtype = 256, 8192, torch.bfloat16
torch.set_float32_matmul_precision("high")
in_proj = torch.compile(InputProjection(D, D, device="cuda", dtype=torch.bfloat16))
packed_in_proj = torch.compile(
PackedInputProjection(D, D, device="cuda", dtype=torch.bfloat16)
)
q, _, _, sequence_lengths = gen_batch(B, D, D, D, device="cuda", dtype=torch.bfloat16)
# warmup
in_proj(q)
packed_in_proj(q)
# benchmark
(q_out, k_out, v_out), time, _ = benchmark(in_proj, q)
(q_out, k_out, v_out), time_packed, _ = benchmark(packed_in_proj, q)
# On my A100 prints 1.05x speedup
print(
f"InputProjection: {time:5f} s, PackedInputProjection: {time_packed:5f} s, speedup: {time/time_packed:.2f}x"
)
InputProjection: 0.034046 s, PackedInputProjection: 0.032757 s, speedup: 1.04x
Transformer 層的前饋網路中的 SwiGLU¶
Swish-Gated Linear Unit (SwiGLU) 是一種非線性啟用函式,在 Transformer 層的前饋網路中越來越受歡迎(例如 Llama)。帶有 SwiGLU 啟用的前饋網路定義如下
class SwiGLUFFN(nn.Module):
def __init__(
self,
dim,
hidden_dim,
multiple_of,
ffn_dim_multiplier=None,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = nn.Linear(dim, hidden_dim, bias=False, **factory_kwargs)
self.w2 = nn.Linear(hidden_dim, dim, bias=False, **factory_kwargs)
self.w3 = nn.Linear(dim, hidden_dim, bias=False, **factory_kwargs)
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
使用 packed projection 的另一種實現方法是
class PackedSwiGLUFFN(nn.Module):
def __init__(
self,
dim,
hidden_dim,
multiple_of,
ffn_dim_multiplier=None,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w13 = nn.Linear(dim, 2 * hidden_dim, bias=False, **factory_kwargs)
self.w2 = nn.Linear(hidden_dim, dim, bias=False, **factory_kwargs)
def forward(self, x):
x1, x3 = torch.chunk(self.w13(x), 2, dim=-1)
return self.w2(F.silu(x1) * x3)
我們可以如下比較這兩種實現的效能。根據你的硬體,結果可能會有所不同。在 A100 上,我看到 D=128 時有 1.12 倍的加速。
D = 128
swigluffn = torch.compile(SwiGLUFFN(D, D * 4, 256, device="cuda", dtype=torch.bfloat16))
packed_swigluffn = torch.compile(
PackedSwiGLUFFN(D, D * 4, 256, device="cuda", dtype=torch.bfloat16)
)
q, _, _, sentence_lengths = gen_batch(D, D, D, D, device="cuda", dtype=torch.bfloat16)
# warmup
swigluffn(q)
packed_swigluffn(q)
# benchmark
_, time, _ = benchmark(swigluffn, q)
_, time_packed, _ = benchmark(packed_swigluffn, q)
# On my A100 prints 1.08x speedup
print(
f"SwiGLUFFN: {time} s, PackedSwiGLUFFN: {time_packed} s, speedup: {time/time_packed:.2f}x"
)
SwiGLUFFN: 0.0010205730000052426 s, PackedSwiGLUFFN: 0.0010395129997959884 s, speedup: 0.98x
擴充套件示例¶
我們計劃更新本教程,以演示更多如何使用各種高效能構建塊(如 KV-Caching、Grouped Query Attention 等)的示例。此外,還有一些很好的例子,展示瞭如何使用各種高效能構建塊來實現不同的 Transformer 架構。一些示例包括
結論¶
在本教程中,我們介紹了 PyTorch 提供的用於編寫 Transformer 層的底層構建塊,並演示瞭如何組合它們的示例。我們希望本教程能讓讀者瞭解 PyTorch 使用者可以多麼輕鬆地實現靈活且高效能的 Transformer 層。
指令碼總執行時間: ( 1 分 6.445 秒)