IBM:Tuan Hoang Trong、Alexei Karve、Yan Koyfman、Linsong Chu、Divya Kumari、Shweta Salaria、Robert Walkup、Praneet Adusumilli、Nirmit Desai、Raghu Ganti、Seetharami Seelam
Meta:Less Wright、Wei Feng、Vasiliy Kuznetsov、Driss Guesseous
在這篇部落格中,我們將展示如何透過利用 FSDP2、DTensor 和帶 torchao float8 的 torch.compile(透過線性層更新(計算)和 float8 all_gathers 進行權重通訊),在訓練中實現高達 50% 的吞吐量加速,同時在損失和評估基準方面與 FSDP1 bf16 訓練持平。我們在各種 Meta LLaMa 模型架構尺寸上展示了這些改進,從小型 1.8B 模型到 405B 模型,使訓練比以往任何時候都更快。
我們使用 Meta Llama3 架構演示了這些改進,然後在兩個規模上進行了模型質量研究:8B 模型規模下 100B 詞元,以及 70B 模型規模下 50B 詞元,這提供了 float8 和 bf16 訓練損失曲線的精確比較。我們證明,與 bf16 對應的模型訓練執行相比,這些損失曲線導致了相同的損失收斂。此外,我們使用 FineWeb-edu 資料集將一個 3B 模型訓練到 1T 詞元,並執行標準評估基準,以確保模型質量完好無損,並與 bf16 執行相當。
在 IBM Research,我們計劃採用這些功能進行資料消融,以在給定的 GPU 預算內增加我們可以執行的實驗數量。從長遠來看,我們將進行更大規模的模型執行,以演示 float8 訓練的端到端可行性。
什麼是 Float8?
用於訓練模型的 float8 格式由 NVIDIA、ARM 和 Intel 在 2022 年的一篇論文中引入,該論文證明了使用更低精度的 float8 訓練的可行性,而不會犧牲模型質量。隨著 NVIDIA Hopper 系列等新型 GPU 的引入,FP8 訓練變得可行,由於原生 float8 Tensor Core 支援,訓練吞吐量有可能提高 2 倍以上。實現這一承諾存在一些挑戰:
(i) 在 float8 中啟用核心模型操作,如 matmul 和 attention,
(ii) 在分散式框架中啟用 float8 訓練,以及
(iii) 在 float8 中啟用 GPU 之間的權重通訊。
雖然 float8 matmul 已由 NVIDIA 庫啟用,但後兩者已在 FSDP2 和 torchao 的最新更新中提供。
在這篇部落格中,我們使用 torchtitan 作為訓練的入口點,IBM 的確定性資料載入器,來自 torchao 的 float8 線性層實現,以及與 FSDP2 結合的最新 PyTorch nightly 版本中的 float8 all gather。對於此次訓練,我們使用 float8 逐張量(tensorwise)縮放粒度,而不是逐行。我們利用 torch.compile 來確保獲得最大的效能增益。我們正在使用 SDPA 以 bf16 格式計算 attention,目前也正在努力將其轉換為 float8。
實驗
我們進行了各種實驗來證明 float8 訓練的優勢。首先是為了確保模型質量不被犧牲。為了驗證這一點,我們訓練了一個 8B 模型和一個 70B 模型幾千步,並比較了 float8 和 bf16 訓練執行之間的損失曲線。我們的實驗在三個不同的 H100 叢集上進行,分別配置了 128、256 和 512 個 H100 GPU,在非常不同的環境中進行,以證明其可復現性。第一個叢集是在 Meta 的 Grand Teton 上定製的,採用 400Gbps 定製互連;第二個是 IBM 研究叢集,採用 3.2Tbps InfiniBand 互連;第三個是 IBM Cloud 叢集,採用 3.2Tbps RoCE 互連用於 GPU 間通訊。
首先,我們繪製了這兩個模型在下面圖表中的損失曲線比較,以演示幾千步的損失持平。


圖 1:(a) 8B 模型 2k 步損失持平,(b) 70B 模型 1k 步損失持平
我們觀察到,在這些不同的模型和不同環境中,我們獲得了小規模詞元的損失持平。接下來,我們描述了從 1.8B 到 405B 四種不同模型尺寸的吞吐量增益。我們探索了 float8 和 bf16 訓練執行的最佳批次大小和啟用檢查點方案,以確定 tokens/sec/GPU (wps) 指標並報告效能增益。對於 405B 模型,我們利用 DTensor 進行 FSDP2 的張量並行訓練。我們所有測量都使用 8K 的序列長度。
| 模型大小 | wps (bf16) | wps (float8) | 增益百分比 |
| 1.8B | 29K | 35K | 18% |
| 8K | 8K | 10K | 28% |
| 70B | 956 | 1430 | 50% |
| 405B (TP4) | 149 | 227 | 52% |
表 1:相對於 bf16 的效能增益(bf16 和 float8 都使用 torch.compile)
我們從表 1 中觀察到,較大模型(70B 和 405B)的增益高達 50%,而較小模型的增益在 20% 到 30% 之間。在進一步的實驗中,我們觀察到 float8 all_gather 的新增在 float8 計算本身之外還能帶來約 5% 的提升,這與 這篇部落格中的觀察結果一致。
其次,為了證明 FP8 模型的有效性,我們使用 Hugging Face 的 FineWeb-edu 資料集,按照 Llama3 架構訓練了一個 3B 模型,共 1T 詞元。我們使用 lm-eval-harness 框架進行了評估,並在下表中展示了這些結果的一小部分。我們觀察到 bf16 的效能略優於 float8 的分數(大約百分之一)。雖然某些分數在 bf16 下明顯更好(例如,MMLU 高出 3 分),但我們預計在選擇正確的超引數和更大規模的訓練執行中,這些差距會消失(例如,bf16 執行的批次大小減半,眾所周知,較小的批次大小執行可以提高評估分數)。
| 基準 | 分數 (float8) | 分數 (bf16) |
| MMLU (5-shot) | 0.26 | 0.29 |
| ARC-e | 0.73 | 0.73 |
| ARC-c | 0.43 | 0.46 |
| Hellaswag | 0.65 | 0.67 |
| sciq | 0.89 | 0.88 |
| OpenBook QA | 0.43 | 0.43 |
| PIQA | 0.76 | 0.76 |
| Winogrande | 0.60 | 0.65 |
| 平均 | 0.59 | 0.60 |
表 2:在 FP16 中進行評估的 float8 訓練模型(在 1T 詞元的 FineWeb 預訓練後)的基準分數。
最後,我們將實驗規模擴大到 IBM Cloud 叢集上的 512 個 H100 GPU。即使在 512 GPU 規模下,我們也能夠重現我們觀察到的結果和加速。我們僅在下表中總結了大型模型(70B 和 405B)的這些結果。
| 模型大小 | wps (bf16) | wps (float8) | 增益百分比 |
| 70B | 960 | 1448 | 51% |
| 405B (TP4) | 152 | 217 | 43% |
表 3:在 512 GPU 規模下相對於 bf16 的效能增益(bf16 和 float8 都使用 torch.compile)
未來工作
我們還在評估其他形式的並行性,例如上下文並行性。我們計劃評估所有這些功能,以展示大規模模型訓練的可組合性和選擇能力。
致謝
我們感謝 IBM Research 的 Davis Wertheimer 啟用了 torchtitan 執行的資料載入器,使我們能夠在多次執行中以相同的順序重放資料。我們還要感謝 IBM Cloud 為我們提供了 H100 叢集的早期測試訪問。