PyTorch 2.0 剛剛釋出。其最重要的新特性是 torch.compile(),只需一行程式碼修改,便有望自動提升程式碼庫的效能。我們之前已經在 Hugging Face Transformers 和 TIMM 模型中驗證了這一承諾,並深入探討了其動機、架構和未來發展。
儘管 torch.compile() 很重要,但 PyTorch 2.0 的內容遠不止於此。值得注意的是,PyTorch 2.0 包含了多種加速 Transformer 塊的策略,這些改進對於擴散模型也至關重要。例如,FlashAttention 等技術因其能夠顯著加速 Stable Diffusion 並實現更大批次大小的能力,在擴散社群中變得非常流行,現在它們已成為 PyTorch 2.0 的一部分。
在這篇文章中,我們將討論 PyTorch 2.0 中注意力層是如何最佳化的,以及這些最佳化如何應用於流行的 🧨 Diffusers 庫。最後,我們將透過基準測試展示使用 PyTorch 2.0 和 Diffusers 如何立即在不同硬體上帶來顯著的效能提升。
更新(2023 年 6 月):新增了一個章節,展示了在修正 diffusers 程式碼庫中的圖中斷之後,使用最新版 PyTorch (2.0.1) 的 torch.compile() 帶來了顯著的效能提升。關於如何查詢和修復圖中斷的更詳細分析將在另一篇文章中釋出。
加速 Transformer 塊
PyTorch 2.0 在 torch.nn.functional 中包含了一個縮放點積注意力函式。該函式包含多種實現,可根據輸入和所使用的硬體進行應用。在 PyTorch 2.0 之前,您必須搜尋第三方實現並安裝單獨的軟體包,才能利用記憶體最佳化演算法,如 FlashAttention。可用的實現有:
- FlashAttention,來自官方的 FlashAttention 專案。
- 記憶體高效注意力,來自 xFormers 專案。
- 適用於非 CUDA 裝置或需要高精度時的原生 C++ 實現。
所有這些方法預設可用,PyTorch 將嘗試透過使用新的縮放點積注意力 (SDPA) API 自動選擇最佳方法。您也可以單獨切換它們以進行更精細的控制,詳情請參見文件。
在 diffusers 中使用縮放點積注意力
將加速 PyTorch 2.0 Transformer 注意力整合到 Diffusers 庫中,是透過使用 set_attn_processor 方法實現的,該方法允許配置可插拔的注意力模組。在這種情況下,建立了一個新的注意力處理器,它在 PyTorch 2.0 可用時預設啟用。為了清晰起見,以下是手動啟用它的方法(但通常沒有必要,因為 diffusers 會自動處理):
from diffusers import StableDiffusionPipeline
from diffusers.models.cross_attention import AttnProcessor2_0
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.to("cuda")
pipe.unet.set_attn_processor(AttnProcessor2_0())
prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
Stable Diffusion 基準測試
我們使用 PyTorch 2.0 中的加速點積注意力在 Diffusers 中運行了多項測試。我們從 pip 安裝了 diffusers,並使用了 PyTorch 2.0 的每晚版本,因為我們的測試是在官方釋出之前進行的。我們還使用了 torch.set_float32_matmul_precision('high') 來啟用額外的快速矩陣乘法演算法。
我們將結果與 diffusers 中傳統的注意力實現(下面稱為 vanilla)以及 PyTorch 2.0 之前效能最佳的解決方案:安裝了 xFormers 包 (v0.0.16) 的 PyTorch 1.13.1 進行了比較。
結果在未編譯(即,完全沒有程式碼更改)和對 UNet 模組進行一次 torch.compile() 呼叫(編譯)的情況下進行測量。我們沒有編譯影像解碼器,因為大部分時間都花在執行 UNet 評估的 50 次去噪迭代中。
float32 結果

下圖探討了不同代各種代表性 GPU 的效能改進與批次大小的關係。我們收集了每種組合的資料,直到達到最大記憶體利用率。Vanilla 注意力比 xFormers 或 PyTorch 2.0 更早地耗盡記憶體,這解釋了較大批次大小缺少柱狀圖的原因。同樣,A100(我們使用的是 40 GB 版本)能夠執行 64 的批次大小,但在我們的測試中,其他 GPU 只能達到 32。




我們發現,即使不使用 torch.compile(),與 vanilla 注意力相比,效能也有非常顯著的提升。開箱即用的 PyTorch 2.0 和 diffusers 安裝在 A100 上帶來了大約 50% 的加速,在 4090 GPU 上根據批次大小的不同,加速介於 35% 和 50% 之間。效能改進在 Ada (4090) 或 Ampere (A100) 等現代 CUDA 架構上更為顯著,但對於仍在雲服務中大量使用的舊架構,它們仍然非常顯著。
除了更快的速度,PyTorch 2.0 中加速的 Transformer 實現允許使用更大的批次大小。單個 40GB A100 GPU 在批次大小為 10 時記憶體不足,而 24 GB 高階消費級顯示卡(如 3090 和 4090)無法一次生成 8 張影像。使用 PyTorch 2.0 和 diffusers,我們可以在 3090 和 4090 上實現 **48** 的批次大小,在 A100 上實現 **64** 的批次大小。這對於雲服務和應用程式意義重大,因為它們可以一次高效地處理更多影像。
與 PyTorch 1.13.1 + xFormers 相比,新的加速 Transformer 實現仍然更快,並且無需額外的軟體包或依賴項。在這種情況下,我們發現在 A100 或 T4 等資料中心卡上,速度適度提升了高達 2%,但在兩代最新的消費級卡上表現出色:3090 上速度提升高達 20%,4090 上根據批次大小的不同,速度提升介於 10% 和 45% 之間。
當使用 torch.compile() 時,我們在之前的改進基礎上額外獲得了(通常)2% 到 3% 的效能提升。由於編譯需要一些時間,這更適合面向使用者的推理服務或訓練。**更新**:當圖中斷最小時,torch.compile() 實現的改進要大得多,詳情請參閱新章節。
float16 結果




當我們考慮 float16 推理時,PyTorch 2.0 中加速 Transformer 實現的效能提升,在所有我們測試的 GPU 上,與標準注意力相比,介於 20% 到 28% 之間,但 4090 除外,它屬於更現代的 Ada 架構。當使用 PyTorch 2.0 每晚版本時,這款 GPU 受益於顯著的效能提升。至於最佳化的 SDPA 與 xFormers 相比,除了 4090,大多數 GPU 的結果通常持平。將 torch.compile() 新增到其中,將整體效能再提升了幾個百分點。
最小化圖中斷後 torch.compile() 的效能
在前面的章節中,我們看到使用 PyTorch 2.0 的加速 Transformer 實現相對於 PyTorch 的早期版本(無論是否使用 xFormers)提供了重要的效能改進。然而,torch.compile() 只帶來了適度的邊際改進。在 PyTorch 團隊的幫助下,我們發現這些適度改進的原因是 diffusers 原始碼中的某些操作導致了圖中斷,這阻止了 torch.compile() 充分利用圖最佳化。
修復圖中斷後(詳情請參見這些PR),我們測量了 torch.compile() 相對於 PyTorch 2 未編譯版本的額外改進,並看到了非常重要的增量效能提升。下圖是使用 2023 年 5 月 1 日下載的 PyTorch 2 每晚版本獲得的結果,它顯示了大多數工作負載的改進範圍約為 13% 到 22%。對於現代 GPU 系列,效能提升更好,A100 的提升超過 30%。圖中還有兩個異常值。首先,我們看到 T4 在批次大小為 16 時效能下降,這給該卡帶來了巨大的記憶體壓力。另一方面,我們看到 A100 在批次大小僅為 1 時效能提升超過 100%,這很有趣,但並不代表具有如此大記憶體的 GPU 的實際使用情況——能夠服務多個客戶的更大批次大小通常對 A100 上的服務部署更有意義。

再次強調,這些效能提升是**額外**的,是在遷移到 PyTorch 2 並使用加速 Transformer 縮放點積注意力實現的基礎上實現的。我們建議在生產環境中部署 diffusers 時使用 torch.compile()。
結論
PyTorch 2.0 帶來了多項功能,可以最佳化基礎 Transformer 塊的關鍵元件,並且可以透過使用 torch.compile 進一步改進。這些最佳化為擴散模型帶來了顯著的記憶體和時間改進,並且不再需要安裝第三方庫。
要利用這些速度和記憶體改進,您只需升級到 PyTorch 2.0 並使用 diffusers >= 0.13.0。
有關更多示例和詳細的基準測試資料,請參閱 PyTorch 2.0 與 Diffusers 文件。
致謝
作者感謝 PyTorch 團隊開發出如此優秀的軟體。