跳轉到主要內容
部落格

使用 float8 和 FSDP2 超級訓練

作者: 2024年11月25日2025年5月5日暫無評論

IBM:Tuan Hoang Trong、Alexei Karve、Yan Koyfman、Linsong Chu、Divya Kumari、Shweta Salaria、Robert Walkup、Praneet Adusumilli、Nirmit Desai、Raghu Ganti、Seetharami Seelam
Meta:Less Wright、Wei Feng、Vasiliy Kuznetsov、Driss Guesseous

在這篇部落格中,我們將展示如何透過利用 FSDP2、DTensor 和帶 torchao float8 的 torch.compile(透過線性層更新(計算)和 float8 all_gathers 進行權重通訊),在訓練中實現高達 50% 的吞吐量加速,同時在損失和評估基準方面與 FSDP1 bf16 訓練持平。我們在各種 Meta LLaMa 模型架構尺寸上展示了這些改進,從小型 1.8B 模型到 405B 模型,使訓練比以往任何時候都更快。

我們使用 Meta Llama3 架構演示了這些改進,然後在兩個規模上進行了模型質量研究:8B 模型規模下 100B 詞元,以及 70B 模型規模下 50B 詞元,這提供了 float8 和 bf16 訓練損失曲線的精確比較。我們證明,與 bf16 對應的模型訓練執行相比,這些損失曲線導致了相同的損失收斂。此外,我們使用 FineWeb-edu 資料集將一個 3B 模型訓練到 1T 詞元,並執行標準評估基準,以確保模型質量完好無損,並與 bf16 執行相當。

在 IBM Research,我們計劃採用這些功能進行資料消融,以在給定的 GPU 預算內增加我們可以執行的實驗數量。從長遠來看,我們將進行更大規模的模型執行,以演示 float8 訓練的端到端可行性。

什麼是 Float8?

用於訓練模型的 float8 格式由 NVIDIA、ARM 和 Intel 在 2022 年的一篇論文中引入,該論文證明了使用更低精度的 float8 訓練的可行性,而不會犧牲模型質量。隨著 NVIDIA Hopper 系列等新型 GPU 的引入,FP8 訓練變得可行,由於原生 float8 Tensor Core 支援,訓練吞吐量有可能提高 2 倍以上。實現這一承諾存在一些挑戰:
(i) 在 float8 中啟用核心模型操作,如 matmulattention
(ii) 在分散式框架中啟用 float8 訓練,以及
(iii) 在 float8 中啟用 GPU 之間的權重通訊。
雖然 float8 matmul 已由 NVIDIA 庫啟用,但後兩者已在 FSDP2torchao 的最新更新中提供。

在這篇部落格中,我們使用 torchtitan 作為訓練的入口點,IBM 的確定性資料載入器,來自 torchaofloat8 線性層實現,以及與 FSDP2 結合的最新 PyTorch nightly 版本中的 float8 all gather。對於此次訓練,我們使用 float8 逐張量(tensorwise)縮放粒度,而不是逐行。我們利用 torch.compile 來確保獲得最大的效能增益。我們正在使用 SDPA 以 bf16 格式計算 attention,目前也正在努力將其轉換為 float8。

實驗

我們進行了各種實驗來證明 float8 訓練的優勢。首先是為了確保模型質量不被犧牲。為了驗證這一點,我們訓練了一個 8B 模型和一個 70B 模型幾千步,並比較了 float8 和 bf16 訓練執行之間的損失曲線。我們的實驗在三個不同的 H100 叢集上進行,分別配置了 128、256 和 512 個 H100 GPU,在非常不同的環境中進行,以證明其可復現性。第一個叢集是在 Meta 的 Grand Teton 上定製的,採用 400Gbps 定製互連;第二個是 IBM 研究叢集,採用 3.2Tbps InfiniBand 互連;第三個是 IBM Cloud 叢集,採用 3.2Tbps RoCE 互連用於 GPU 間通訊。

首先,我們繪製了這兩個模型在下面圖表中的損失曲線比較,以演示幾千步的損失持平。

Figure 1: (a) 8B model loss parity for 2k steps, (b) 70B loss parity for 1k steps
Figure 1: (a) 8B model loss parity for 2k steps, (b) 70B loss parity for 1k steps

圖 1:(a) 8B 模型 2k 步損失持平,(b) 70B 模型 1k 步損失持平

我們觀察到,在這些不同的模型和不同環境中,我們獲得了小規模詞元的損失持平。接下來,我們描述了從 1.8B 到 405B 四種不同模型尺寸的吞吐量增益。我們探索了 float8 和 bf16 訓練執行的最佳批次大小和啟用檢查點方案,以確定 tokens/sec/GPU (wps) 指標並報告效能增益。對於 405B 模型,我們利用 DTensor 進行 FSDP2 的張量並行訓練。我們所有測量都使用 8K 的序列長度。

模型大小 wps (bf16) wps (float8) 增益百分比
1.8B 29K 35K 18%
8K 8K 10K 28%
70B 956 1430 50%
405B (TP4) 149 227 52%

表 1:相對於 bf16 的效能增益(bf16 和 float8 都使用 torch.compile)

我們從表 1 中觀察到,較大模型(70B 和 405B)的增益高達 50%,而較小模型的增益在 20% 到 30% 之間。在進一步的實驗中,我們觀察到 float8 all_gather 的新增在 float8 計算本身之外還能帶來約 5% 的提升,這與 這篇部落格中的觀察結果一致。

其次,為了證明 FP8 模型的有效性,我們使用 Hugging Face 的 FineWeb-edu 資料集,按照 Llama3 架構訓練了一個 3B 模型,共 1T 詞元。我們使用 lm-eval-harness 框架進行了評估,並在下表中展示了這些結果的一小部分。我們觀察到 bf16 的效能略優於 float8 的分數(大約百分之一)。雖然某些分數在 bf16 下明顯更好(例如,MMLU 高出 3 分),但我們預計在選擇正確的超引數和更大規模的訓練執行中,這些差距會消失(例如,bf16 執行的批次大小減半,眾所周知,較小的批次大小執行可以提高評估分數)。

基準 分數 (float8) 分數 (bf16)
MMLU (5-shot) 0.26 0.29
ARC-e 0.73 0.73
ARC-c 0.43 0.46
Hellaswag 0.65 0.67
sciq 0.89 0.88
OpenBook QA 0.43 0.43
PIQA 0.76 0.76
Winogrande 0.60 0.65
平均 0.59 0.60

表 2:在 FP16 中進行評估的 float8 訓練模型(在 1T 詞元的 FineWeb 預訓練後)的基準分數。

最後,我們將實驗規模擴大到 IBM Cloud 叢集上的 512 個 H100 GPU。即使在 512 GPU 規模下,我們也能夠重現我們觀察到的結果和加速。我們僅在下表中總結了大型模型(70B 和 405B)的這些結果。

模型大小 wps (bf16) wps (float8) 增益百分比
70B 960 1448 51%
405B (TP4) 152 217 43%

表 3:在 512 GPU 規模下相對於 bf16 的效能增益(bf16 和 float8 都使用 torch.compile)

未來工作

我們還在評估其他形式的並行性,例如上下文並行性。我們計劃評估所有這些功能,以展示大規模模型訓練的可組合性和選擇能力。

致謝

我們感謝 IBM Research 的 Davis Wertheimer 啟用了 torchtitan 執行的資料載入器,使我們能夠在多次執行中以相同的順序重放資料。我們還要感謝 IBM Cloud 為我們提供了 H100 叢集的早期測試訪問。