Meta:Less Wright、Hamid Shojanazeri、Vasiliy Kuznetsov、Daniel Vega-Myhre、Gokul Nadathur、Will Constable、Tianyu Liu、Tristan Rice、Driss Guessous、Josh Fromm、Luca Wehrstedt、Jiecao Yu Crusoe:Ethan Petersen、Martin Cala、Chip Smith
透過與 Crusoe.AI 合作,我們獲得了他們位於冰島的新型 2K H200 叢集的訪問許可權。這使我們能夠利用 TorchTitan 的 HSDP2 和 TorchAO 新的 Float8 Rowwise,展示大規模訓練加速 34-43%,同時保持與 BF16 可比的收斂性和穩定性。
在這篇文章中,我們將詳細介紹 H200 與 PyTorch 新的 Float8 Rowwise 訓練在 TorchTitan 的 FSDP2/HSDP2 和大規模 CP 下的協同作用。
背景 – 什麼是 H200?
H200 是 H100 的“增強版”,提供與 H100 完全相同的計算能力,但有兩項額外改進。
- 更大的全域性記憶體,141GiB HBM3e,而標準為 80GiB HBM3。
- 記憶體頻寬快約 43%,達到 4.8TB/s,而標準為 3.35 TB/s。更快的記憶體傳輸對訓練速度有顯著影響,尤其是對於 PyTorch 的 AsyncTP。
什麼是 PyTorch Float8 Rowwise?
Float8 Rowwise 是 Float8 相對於之前“tensor wise”Float8 的更精細解析度。它旨在確保更精細的精度,以支援更大的工作負載,這些工作負載在規模擴大和訓練進展過程中往往對量化變得更敏感。
Float8 Rowwise 有兩個關鍵改進。
- 現在每行都保持自己的縮放因子,而不是整個張量使用單個縮放因子,從而提高了量化精度。每行更精細的縮放有助於減少異常值(極端值會迫使量化縮放因子拉伸並降低正態分佈值的精度)的影響,從而確保更好的精度。
- 縮放因子本身現在透過向下舍入到最接近的 2 的冪來實現。這已被證明有助於在乘以/除以縮放因子時減少量化誤差,並確保大值在正向和反向傳播中都縮放到相同的值。
請注意,其他大規模模型已使用 Float8 以 2K 規模進行訓練,結合了 1×128 groupwise 和 128×128 blockwise,並使用了 2 的冪次縮放因子。它們的目標相同,即提高 Float8 的精度以支援大規模訓練。
因此,Float8 Rowwise 提供了類似的承諾,可以在超大規模訓練中使用 Float8,但我們希望提供大規模穩定性和收斂性的證據,Crusoe H200 2K 叢集上的訓練為此提供了初步驗證。
展示 Float8 Rowwise 損失收斂與 BF16 在 1600 和 1920 GPU 規模下的比較
為了驗證可比的損失收斂性,我們使用 TorchTitan 和 Llama3 70B,在 1920 和 1600 (1.6k) GPU 規模下分別運行了兩次。1.6K GPU 執行設定為 2.5k 迭代,使用 TorchTitans 的 HSDP2 和 Context Parallel 來實現 2D 並行性。
損失收斂測試使用 Titan 的確定性模式執行——這種模式有效地凍結了每次執行之間大多數潛在的變異源,從而有助於確保唯一實質性的變化是我們想要測試的,即 BF16 與 Float8 Rowwise 的損失收斂和損失曲線。
請注意,確定性模式也會降低訓練速度,因為各種核心不會自動調優以最大化吞吐量(否則我們可能會在不同執行之間使用不同的核心並引入方差)。
完成了兩次執行,一次使用 BF16,另一次使用 Float8 Rowwise。
兩次執行都完成了分配的 2.5k 迭代,沒有出現問題,展示了 Crusoe 叢集的穩定性,FP8 在 24 小時內完成,BF16 在 31 小時 19 分鐘後完成。
| 資料型別 | 時間/迭代 | 損失 |
| BF16 | 24 小時 | 3.15453 |
| Float8 行式 | 24 小時 | 2.86386 |
| BF16 | 31 小時 19 分鐘 / 2.5K | 2.88109 |
| Float8 行式 | 24 小時 / 2.5K | 2.86386 |
在 24 小時標記處,Float8 完成了 2.5K 迭代,展示了 Float8 訓練的相對加速(即使在確定性模式下)。在 24 小時標記處,對於相同 24 小時的大規模訓練時間,Float8 相對於 BF16 在損失方面實現了 +9.21% 的相對改進。
31 小時 19 分鐘後,BF16 執行最終完成了 2.5k 迭代。
最終損失資料
BF16 = 2.88109 Float8 = 2.86386
從損失曲線來看,我們觀察到在前三分之一和後三分之一部分曲線非常相似,中間部分則出現了一個動盪區域,兩者都顯示出類似的峰值,但峰值的相對時間略有偏差。
因此,我們可以看到 PyTorch 的 Float8 Rowwise 提供了相似的收斂性,但在相同的訓練時間內,速度提高了 33% 以上。
Float8 行式長期訓練穩定性
除了展示可比的收斂性之外,我們還希望展示 Float8 的長期訓練穩定性,因此我們啟動了一個 4 天、15K 次執行,規模為 256。
如上所示,Float8 訓練運行了超過 100 小時,沒有出現任何問題,突顯了 Float8 Rowwise 的長期穩定性。
TorchTitan 中的確定性
為了驗證確定性並檢視較長時間執行中的尖峰是否來自規模,我們還進行了一次較小的執行,包括兩次 BF16 執行和一次 Float8 執行,規模為 256,並且僅使用 HSDP2(即不使用 2D Context Parallel)。
在這種情況下,兩次 BF16 執行都具有相同的曲線和最終損失,並且我們觀察到所有三次執行都出現了類似的尖峰區域。
在 2K 迭代標記處,Float8 和 BF16 都結束在幾乎相同的位置
BF16 *2 = 3.28538
Float8 行式 = 3.28203
上述結果證實,損失中的尖峰既不是 CP 也不是規模(2k)造成的,因為我們在 256 規模下也看到了類似的效果。損失尖峰最可能的原因可能是資料集中內容的分佈。
為了確定性,實驗使用序列化的 C4 資料集(未打亂)執行,這意味著尖峰可能來自資料集中遇到的新內容。
Float8 Rowwise 在各種規模下的實際加速
我們在不同的 GPU 規模下進行了較短的執行,以瞭解 Float8 Rowwise 在叢集規模擴大時訓練加速方面的表現。從 960 擴充套件到 1920,Float8 持續提供令人印象深刻的訓練加速,與 BF16 相比,效能提升幅度在 34-43% 之間。我們還想指出,從 1k 擴充套件到 2k GPU 時,通訊開銷可能開始顯現,我們觀察到 BF16 的吞吐量下降了 4%。
如上所述,在大規模更長的訓練執行中,Float8 Rowwise 提供了顯著的加速,並實現了相同甚至略有改進的損失終點,同時在 1920 (DeepSeek) 規模下實現了 34% 的加速。
如何在訓練中使用 Float8 Rowwise?
Float8 Rowwise 現已推出,可用於您的大規模訓練。它打包在 TorchAO 的最新構建(0.9 及更高版本)中,如果您想快速啟動並執行,它已原生整合到 TorchTitan 中。
在 TorchTitan 中啟用 Float8 Rowwise
首先在模型的 .toml 檔案中啟用模型轉換器,將 nn.linears 熱插拔為 float8 線性層 – 參見第 29 行

其次,指定“rowwise”float8 配方 – 參見第 72 行
請注意,“recipe_name”有三種選擇。
- rowwise,這是推薦的預設值,
- tensorwise(舊版 float8)和
- rowwise_with_gw_hp。
gw_hp 行式選項在反向傳播過程中將權重的梯度保持在 BF16 精度,這可以進一步提高 Float8 對極其敏感的工作負載的精度。但是,如果模型中大多數矩陣乘法的大小較小(在 H100 上估計臨界點約為 13-16K 維度),它反而可能比通用行式更具效能。
因此,雖然我們推薦將行式作為預設選項,但值得將其與 gw_hp 在您的模型上進行比較,以驗證哪種提供最佳效能,並有可能獲得更高的精度。
透過切換模型轉換器的開/關(使用 # ),您可以直接比較 BF16 和 Float8 Rowwise 之間的訓練加速,從而瞭解您自己訓練的潛在加速。
未來更新
我們將推出一項額外更新,展示管道並行和非同步分散式檢查點的多項改進,敬請期待。