跳轉到主要內容
公告

PyTorch 2.2:FlashAttention-v2 整合,AOTInductor

作者: 2024 年 1 月 30 日2025 年 4 月 30 日暫無評論

我們很高興地宣佈 PyTorch® 2.2 釋出(發行說明)!PyTorch 2.2 透過整合 FlashAttention-v2,使 scaled_dot_product_attention 的效能提升約 2 倍,並引入了 AOTInductor,這是一款專為非 Python 伺服器端部署而構建的新型預先編譯和部署工具。

此版本還包括對 Optimizers 的 torch.compile 改進支援、多項新的 inductor 最佳化以及名為 TORCH_LOGS 的新日誌記錄機制。

請注意,我們正在棄用 macOS x86 支援,PyTorch 2.2.x 將是支援 macOS x64 的最後一個版本。

除了 2.2,我們還將釋出一系列針對 PyTorch 領域庫的更新。更多詳細資訊可在庫更新部落格中找到。

自 PyTorch 2.1 以來,此版本由 3,628 次提交和 521 位貢獻者組成。我們衷心感謝我們敬業的社群所做的貢獻。一如既往,我們鼓勵您試用這些功能並在我們改進 2.2 的過程中報告任何問題。有關如何開始使用 PyTorch 2 系列的更多資訊,請訪問我們的入門頁面。

總結

  • scaled_dot_product_attention (SDPA) 現在支援 FlashAttention-2,與以前的版本相比,速度提升約 2 倍。
  • PyTorch 2.2 引入了 TorchInductor 的新型預先擴充套件,名為 AOTInductor,旨在為非 Python 伺服器端編譯和部署 PyTorch 程式。
  • torch.distributed 支援一種用於初始化和表示 ProcessGroups 的新抽象,稱為 device_mesh
  • PyTorch 2.2 釋出了一個名為 TORCH_LOGS 的標準化、可配置的日誌記錄機制。
  • PyTorch 2.2 包含多項 torch.compile 改進,包括對編譯 Optimizers 的改進支援以及改進的 TorchInductor 融合和佈局最佳化。
  • 請注意,我們正在棄用 macOS x86 支援,PyTorch 2.2.x 將是支援 macOS x64 的最後一個版本。
穩定版 Beta 效能改進
  FlashAttention-2 整合 Inductor 最佳化
  AOTInductor aarch64 最佳化
  TORCH_LOGS  
  device_mesh  
  Optimizer 編譯  

*要檢視完整的公開功能提交列表,請點選此處

Beta 功能

[Beta] torch.nn.functional.scaled_dot_product_attention 中對 FlashAttention-2 的支援

torch.nn.functional.scaled_dot_product_attention (SDPA) 現在支援 FlashAttention-2,速度提升約 2 倍(與以前的版本相比),在 A100 GPU 上達到理論最大 FLOPs/s 的約 50-73%。

有關 FlashAttention-2 的更多資訊,請參閱這篇論文

有關如何使用 SDPA 的教程,請參閱此教程

[Beta] AOTInductor:用於 torch.export-ed 程式的預先編譯和部署

AOTInductor 是 TorchInductor 的一個擴充套件,旨在處理匯出的 PyTorch 模型,對其進行最佳化,並生成共享庫以及其他相關工件。這些編譯後的工件可以部署在非 Python 環境中,這些環境通常用於伺服器端的推理。請注意,AOTInductor 支援與 Inductor 相同的後端,包括 CUDA、ROCm 和 CPU。

有關更多資訊,請參閱 AOTInductor 教程

[Beta] 透過 TORCH_LOGS 進行細粒度可配置日誌記錄

PyTorch 現在提供了一個標準化、可配置的日誌記錄機制,可用於分析各種子系統(例如編譯和分散式操作)的狀態。

可以透過 TORCH_LOGS 環境變數啟用日誌。例如,要將 TorchDynamo 的日誌級別設定為 logging.ERROR,將 TorchInductor 的日誌級別設定為 logging.DEBUG,請將 TORCH_LOGS=”-dynamo,+inductor” 傳遞給 PyTorch。

有關更多資訊,請參閱日誌記錄文件教程

[Beta] torch.distributed.device_mesh

PyTorch 2.2 引入了一種用於表示分散式並行中涉及的 ProcessGroups 的新抽象,稱為 torch.distributed.device_mesh。此抽象允許使用者透過 N 維陣列表示節點間和節點內程序組,例如,一個維度可以是 FSDP 中的資料並行,而另一個維度可以表示 FSDP 中的張量並行。

有關更多資訊,請參閱 device_mesh 教程

[Beta] 改進 torch.compile 最佳化器

torch.compile 最佳化器已進行多項改進,包括更少的開銷和對 cuda graphs 的支援。

有關改進的更多技術細節,請訪問 dev-discuss,有關 torch.compile 最佳化器的配方可在此處獲得。

效能改進

Inductor 效能最佳化

TorchInductor 中添加了多項效能最佳化,包括 torch.concat 的水平融合支援改進的卷積佈局最佳化以及改進的 scaled_dot_product_attention 模式匹配

有關 inductor 最佳化的完整列表,請參閱發行說明

aarch64 效能最佳化

PyTorch 2.2 包含多項針對 aarch64 的效能增強,包括支援 mkldnn 權重預打包、改進的 ideep 原始快取以及透過對 OneDNN固定格式核心改進 提高推理速度。

有關 aarch64 最佳化的完整列表,請參閱發行說明