跳轉到主要內容
部落格

一個更好的 Transformer,用於快速 Transformer 推理

作者: 2022 年 7 月 12 日2024 年 11 月 15 日暫無評論

概要: Transformer 在自然語言處理(NLP)領域取得了最先進的效能,並正成為其他眾多工的熱門選擇。它們計算成本高昂,這阻礙了其廣泛的生產化。PyTorch 1.12 釋出了 BetterTransformer,它為 Transformer 編碼器推理實現了一個向後相容的 torch.nn.TransformerEncoder 快速路徑,並且不需要模型作者修改他們的模型。BetterTransformer 的改進在許多常見的執行場景中可以將速度和吞吐量提高兩倍以上。要使用 BetterTransformer,請安裝 PyTorch 1.12,並立即開始使用 PyTorch API 的高質量、高效能 Transformer 模型。

Transformer 編碼器架構圖(來自“Attention Is All You Need”)。在推理過程中,整個模組將作為單個 PyTorch 原生函式執行。

在這篇博文中,我們將分享以下主題——效能改進、向後相容性和利用快速路徑。請在下面瞭解更多這些主題。

效能改進

BetterTransformer 推出了針對 CPU 和 GPU 的 MultiHeadAttention 和 TransformerEncoderLayer 的加速原生實現。這些快速路徑已整合到標準的 PyTorch Transformer API 中,並將加速 TransformerEncoderTransformerEncoderLayerMultiHeadAttention nn.module。這些新模組實現了兩種型別的最佳化:(1) 融合核心將通常用於實現 Transformer 的多個獨立運算子組合在一起,以提供更高效的實現;(2) 利用輸入中的稀疏性,避免對填充令牌執行不必要的操作。在許多用於自然語言處理的 Transformer 模型中,填充令牌經常佔輸入批次很大一部分。

向後相容性

值得慶幸的是,無需修改模型即可受益於 BetterTransformer 帶來的效能提升。要受益於快速路徑執行,輸入和操作條件必須滿足一些訪問條件(見下文)。雖然 Transformer API 的內部實現已更改,但 PyTorch 1.12 保持與以前版本中釋出的 Transformer 模組的嚴格相容性,使 PyTorch 使用者可以使用使用以前 PyTorch 版本建立和訓練的模型,同時受益於 BetterTransformer 的改進。

除了啟用 PyTorch nn.Modules,BetterTransformer 還為 PyTorch 庫提供了改進。效能優勢將透過兩種不同的啟用路徑實現:

  1. 透明加速: MultiHeadAttention 等 PyTorch nn.Modules 以及更高級別的 Transformer 元件的現有使用者將自動受益於新 nn.Modules 的改進效能。一個例子是 torchvision 庫中使用的視覺 Transformer (ViT) 實現(程式碼連結)。
  2. Torchtext 庫加速:作為該專案的一部分,我們優化了 Torchtext,使其基於 PyTorch 核心 API 構建,以受益於 BetterTransformer 的增強,同時保持與以前的庫版本以及使用以前 Torchtext 版本訓練的模型嚴格和透明的相容性。在 Torchtext 中使用 PyTorch Transformer 還確保 Torchtext 將受益於 PyTorch Transformer 實現未來預期的增強。

利用快速路徑

BetterTransformer 是 PyTorch Transformer API 的快速路徑。快速路徑是用於 CPU 和 GPU 的關鍵 Transformer 函式的本機專用實現,適用於常見的 Transformer 用例。

為了利用輸入稀疏性(即填充)來加速模型(參見圖 2),在例項化 TransformerEncoder 時將關鍵字引數 enable_nested_tensor=True 設定為 true,並在推理期間傳入 src_key_padding_mask 引數(表示填充令牌)。這要求填充掩碼是連續的,這是典型情況。

目前,BetterTransformer 的加速僅適用於推理中使用的 Transformer 編碼器模型。為了受益於快速路徑執行,模型必須由以下任何元件組成:TransformerEncoderTransformerEncoderLayerMultiheadAttention (MHA)。快速路徑執行還受某些條件限制。最重要的是,模型必須在推理模式下執行,並且在不收集梯度帶資訊(例如,使用 torch.no_grad 執行)的輸入張量上操作。條件的完整列表可以在 nn.MultiHeadAttentionnn.TransformerEncoder 的這些連結中找到。如果不滿足條件,控制將流向舊的 PyTorch 1.11 Transformer 實現,該實現具有相同的 API,但缺少快速路徑效能提升。

其他使用 PyTorch MultiheadAttention 模組的 Transformer 模型(例如解碼器模型)將受益於 BetterTransformer 快速路徑。未來計劃的工作是將端到端 BetterTransformer 快速路徑擴充套件到基於 TransformerDecoder 的模型,以支援流行的 seq2seq 和僅解碼器(例如,OPT)模型架構,以及用於訓練。

加速

以下圖表顯示了 BERT-base 模型在小規模和大規模輸入下實現的效能

圖 1:PyTorch 1.12 與 BetterTransformer 快速路徑執行的改進

圖 2:PyTorch 1.12 與 BetterTransformer 快速路徑執行的改進
透過 enable_nested_tensor=True 啟用稀疏性最佳化

BetterTransformer 包含兩種型別的最佳化:(1) 融合核心,在一個核心中更有效地實現多項操作;(2) 透過避免對填充令牌進行不必要的處理來利用稀疏性。小輸入尺寸的效能增強主要受益於融合核心實現,並且無論填充量如何,都顯示出持續的效能改進。雖然大輸入仍然受益於融合核心,但計算密集型處理限制了融合核心可能獲得的收益,因為基線效能已經接近理論峰值。然而,隨著填充量的增加,效能顯著提高,因為透過利用 NLP 工作負載中填充引入的稀疏性,可以避免越來越多的計算。

未來工作

作為我們正在進行的 PyTorch BetterTransformer 工作的一部分,我們正在努力將 BetterTransformer 的改進擴充套件到 Transformer 解碼器。我們旨在將範圍從推理擴充套件到訓練。

我們正在合作在 FairSeq、MetaSeq 和 HuggingFace 等其他庫上啟用 BetterTransformer,以惠及所有基於 Transformer 的 PyTorch 模型。作為本部落格系列的一部分,我們將提供有關 BetterTransformer 加速在更廣泛的 PyTorch 生態系統中進展的未來更新。

致謝:作者衷心感謝 Lin Qiao、Ajit Mathews、Andrew Tulloch、Dmytro Dzhulgakov、Natalia Gimelshein、Emad El-Haraty、Mark Saroufim、Adnan Aziz、Geeta Chauhan 和 Hamid Shojanazeri 在本專案期間以及本部落格的準備過程中給予的支援、貢獻和許多有益的建議。