跳轉到主要內容
部落格

使用 PyTorch FSDP 和 Torch.compile 最大化訓練吞吐量

作者: 2024 年 5 月 21 日2024 年 11 月 13 日暫無評論

最近,我們展示瞭如何使用 FSDP 和選擇性啟用檢查點,在 A100 GPU 上訓練 7B 模型時實現 57% 的 MFU(模型浮點運算利用率)。我們還展示瞭如何訓練出一個高質量模型,並將其作為 Granite 7B 基礎模型 在 Hugging Face Hub 上以 Apache v2.0 許可證開源。

我們繼續透過利用 torch.compile 來提高 GPU 利用率。結合 torch.compile 和我們之前工作中的選擇性啟用檢查點,我們在 A100 GPU 上為 7B 模型實現了 68% 的 MFU!torch.compile 將各種模型大小的訓練 MFU 提高了 10% 到 23%。

本部落格分為三個部分:(1)為使用 torch.compile 進行訓練而解決的挑戰,(2)compile 與 no-compile 的數值一致性,以及(3)MFU 報告。

我們已將所有程式碼開源並更新到 fms-fsdp 倉庫 中。我們還在與 Meta 的 PyTorch 團隊合作,將這些貢獻給新發布的 torch titan 倉庫,用於預訓練。

使用 torch.compile 的挑戰

torch.compile 是一種圖編譯技術,可以提高 GPU 利用率。有關 torch.compile 工作原理的詳細資訊,我們建議讀者參考最近的 PyTorch 論文和相關教程。讓 torch.compile 表現出色的一個關鍵挑戰是最大程度地減少(或消除)圖中斷。我們最初從 Meta 提供的 Llama 實現開始,但編譯它導致了太多圖中斷,從而降低了訓練吞吐量。

模型架構的幾個部分必須進行修改,其中最重要的是位置嵌入層(RoPE)。典型的 RoPE 實現使用複數,這在測試時 torch.compile 尚不支援。我們使用 einops 實現了 RoPE,同時保持與原始模型架構實現的一致性。我們必須正確快取頻率,這樣就不會在 RoPE 實現中遇到圖中斷。

編譯 FSDP 模型確實會導致圖中斷,Meta 的 PyTorch 團隊正在努力消除這些中斷。然而,截至 PyTorch 2.3,這些圖中斷髮生在 FSDP 單元邊界,並且不會顯著影響吞吐量。

使用自定義核心時,我們需要透過將其 API 暴露給 torch.compile 來包裝每個核心。這包括指示哪些引數是原地修改的、如何修改以及它們的返回值將根據輸入具有什麼形狀和步長。在我們的案例中,SDPA Flash attention 已適當整合,我們能夠讓該核心與 torch.compile 協同工作,且沒有圖中斷。

我們還注意到,當資料量從 2T 增加到 6T tokens 時,資料載入器成為了瓶頸。一個主要原因是,我們之前在資料載入器中天真地實現了文件混洗,每個 worker 都維護一個混洗文件指標列表。

對於更大的資料集,這些指標列表增長到每個 worker 數十萬條目。在這種規模下維護指標列表變得非常昂貴,以至於 CPU 爭用扼殺了我們的訓練吞吐量。我們使用 線性同餘生成器 重新實現了文件混洗,而無需任何指標列表。LCG 是一種偽隨機數生成器演算法,它實現了一個在群體上的隨機遊走,提供了無放回抽樣。

我們利用相同的思想來生成從有序到混洗文件索引的隱式雙射對映。這使我們能夠將那些煩人的數十萬指標列表縮小為 LCG 的單個整數狀態。這消除了 80% 的瓶頸,並顯著提升了我們的效能。我們將專門撰寫一篇部落格,詳細介紹我們高效能預訓練資料載入器的所有細節。

torch.compile 與 torch.no-compile 的數值一致性

我們之前在 compile 和 no-compile 選項訓練時觀察到一致性問題,其中一個與 SDPA 的使用有關。經過 Meta 和 IBM 的 PyTorch 團隊幾天緊張的除錯會話後,我們能夠在 PyTorch compile 和 no-compile 模式之間實現一致性。為了記錄和驗證這種一致性,我們採用一個 1.4B 大小的 mini-Llama 模型架構,並以四種變體對其進行 100B tokens 的訓練——no-compile,compile 不帶啟用檢查點,compile 帶選擇性啟用檢查點,以及 compile 帶完全啟用檢查點。

我們繪製了這些選項的損失曲線和梯度範數如下:

Figure 1: Loss curve and gradient norm for various compile options

圖 1:各種編譯選項的損失曲線和梯度範數

此外,我們運行了 lm-evaluation-harness,並比較了各種模型在不同基準上的分數,觀察到 compile 和 no-compile 之間沒有重大差異,如下所示。

Figure 2: lm-evaluation-harness comparison of various benchmarks between compile and no-compile

圖 2:lm-evaluation-harness 對 compile 和 no-compile 之間各種基準的比較

從所有這些結果中我們觀察到,compile 及其所有變體與 no-compile 選項相等,從而證明了 compile 和 no-compile 之間的一致性。

MFU 報告

最後,像我們之前的部落格一樣,我們在兩個叢集上計算了四種不同模型大小的 MFU。一個叢集是 128 個 A100 GPU,具有 400 Gbps 的節點間連線;另一個叢集是 464 個 H100 GPU,具有 3.2 Tbps 的節點間連線。除了 compile,我們還使用了 之前部落格 中介紹的選擇性啟用檢查點。結果記錄在下表中。

模型大小批次大小無編譯 MFU編譯 MFU增益百分比 (%)
7B20.570.6820
13B20.510.6017
34B20.470.5415
70B20.500.5510

表 1:Llama2 模型架構在 128 個 A100 80GB GPU 上,採用 400Gbps 節點間互連的編譯和非編譯 MFU 結果

模型大小批次大小無編譯 MFU編譯 MFU增益百分比
7B20.370.4521
13B20.350.4323
34B20.320.3819
70B20.320.3819

表 2:Llama2 模型架構在 464 個 H100 80GB GPU 上,採用 3.2Tbps 節點間互連的編譯和非編譯 MFU 結果

我們還在 448 個 GPU 上使用 Llama2 7B 架構進行了內部生產執行。使用 compile 和選擇性啟用檢查點,全域性批次大小為 3.7M,我們在 13 天 10 小時內訓練了 4T token!

在訓練期間,資料中心冷卻系統不得不啟動額外的空調,我們的訓練團隊也收到了警報,因為我們有效地使用了 GPU ☺

從表 1 和表 2 中一個關鍵的觀察是 MFU 數值並非隨模型大小線性擴充套件。我們正在積極調查兩種可能的解釋,一是隨著模型大小的增加 FSDP 的可伸縮性以及何時需要啟用張量並行以更有效地使用 GPU,二是批次大小,可以進一步增加以獲得更好的 MFU。我們計劃探索 FSDP v2 和選擇性運算子檢查點以及張量並行功能,以研究 FSDP 隨模型大小的縮放定律。

未來工作

我們計劃開始測試 FSDP v2,它將作為 PyTorch 2.4 的一部分發布。FSDP2 提供了按引數分片和選擇性運算子檢查點功能,這可能會提供更好的記憶體-計算權衡。

我們還與 Meta 的 PyTorch 團隊合作,評估新的非同步檢查點功能,該功能可以透過減少寫入檢查點的時間來進一步提高 GPU 利用率。

我們正在探索擴充套件目前在推理中使用的各種 Triton 核心,以執行反向操作,從而獲得超越僅推理的加速。

最後,隨著 fp8 使用的最新工作不斷湧現,我們計劃探索如何使用這種承諾 2 倍加速的新資料型別進一步加速模型訓練。

致謝

有幾個團隊參與了實現這一證明點,我們要感謝 Meta 和 IBM 的所有團隊。特別地,我們向 Meta PyTorch 分散式和編譯器團隊以及 IBM Research 表示感謝。

多位人員廣泛參與了實現 torch.compile 與我們模型數值一致性的工作,我們希望感謝參與這項工作的關鍵人員;Meta 的 Animesh Jain 和 Less Wright,以及 IBM Research 的 Linsong Chu、Davis Wertheimer、Brian Vaughan、Antoni i Viros Martin、Mudhakar Srivatsa 和 Raghu Ganti。

特別感謝 Stas Bekman,他提供了大量反饋並幫助改進了這篇部落格。他們的見解對於突出最佳化訓練和探索進一步增強的關鍵方面非常有價值。