
去年,IBM 研究院開始與我們合作,為他們的大型基礎模型引入完全分片資料並行 (FSDP)。他們對此產生了興趣,因為 FSDP 是 PyTorch 原生提供的解決方案,用於在 IBM Cloud 上擴充套件他們的分散式訓練工作。
我們很高興地宣佈,透過與 IBM 的合作,我們已經為大型模型實現了顯著的檢查點加速(與原始 PyTorch 1.13 的儲存速度相比,提升了 72 倍),證明了模型和最佳化器檢查點可以擴充套件到 30 億引數,並支援使用 FSDP + 分散式檢查點在 S3 後端進行雲優先訓練。
什麼是分散式檢查點?
分散式檢查點是 PyTorch 原生解決方案,用於從多個 rank 儲存和載入 PyTorch 模型和最佳化器狀態,並支援在重新載入之間動態更改世界大小。

PyTorch 分散式檢查點 (DCP) API 在 PyTorch 1.13 中引入,並作為官方原型功能包含在 PyTorch 2.0 中。
分散式檢查點與 torch.save() 和 torch.load() 在幾個重要方面有所不同:
- DCP 每個檢查點生成多個檔案,每個 rank 至少一個檔案。
- DCP 原地操作,這意味著模型應首先分配其資料,然後分散式檢查點將使用該儲存空間。
從 1.13 到 2.0 的一個主要改進是增加了對分片 state_dict 的支援,用於檢查點 FSDP 模型。這允許對更大尺寸的模型進行檢查點,並增加了對載入時重新分片的支援。載入時重新分片允許在一個叢集拓撲中儲存,然後載入到另一個叢集拓撲中。此功能需求很高,因為它允許在某一個叢集上執行訓練作業,儲存後,然後可以在具有不同世界大小的不同叢集上繼續。
另一個主要變化是,我們將儲存層與檢查點規劃層解耦,併為這兩層分離了實現和介面。透過此更改,使用者現在可以在檢查點規劃階段指定其 state_dict 應如何分塊或轉換。此外,可定製的儲存層可以輕鬆適應不同的後端。
有關分散式檢查點包的更多資訊可以在此處找到。
與 IBM 合作,在生產環境中實現高效能分散式檢查點
IBM 在 Think 2023 上宣佈了其針對企業基礎模型開發和部署的 watsonx.ai 平臺。該平臺建立在混合雲之上,支援跨多種模式的用例,如自然語言處理、時間序列、天氣、化學、表格資料和網路安全,模型大小從數億到數百億引數不等。模型架構包括視覺 Transformer、多模態 RoBERTa 風格的特徵提取器,以及類似於 T5、GPT 和 Llama 的大規模生成式語言模型。
截至今天,IBM 已為 T5 風格的架構(最高 110 億引數)和解碼器架構(GPT 風格,最高 300 億引數)啟用了檢查點。
IBM 幫助我們認識到,這限制了 DCP 從記憶體和效能角度的擴充套件能力。根據他們的建議,我們增強了 FileSystemWriter,使其每個 rank 生成單個檢查點,以減少讀寫開銷。
有了這個新預設選項,DCP 現在在檢查點儲存期間為每個 rank 建立一個檔案,然後在載入時讀取引數時進行切片。
透過將 sharded_state_dict 支援與每個 rank 的單個檔案寫入器相結合,分散式檢查點能夠將檢查點儲存時間加速 72 倍以上(與原始 PyTorch 1.13 的儲存速度相比),並使超過 150 億引數的模型能夠進行快速檢查點,而這些模型以前會簡單地超時。
“回想起來,我們所看到的這些模型的訓練速度提升著實令人震驚。我們從 PyTorch 1.13 中編寫一個 110 億引數的檢查點需要將近半小時,到現在能夠處理一個 300 億引數的模型,包括最佳化器和資料載入器狀態——這相當於超過八倍的原始資料——僅需 3 分鐘多一點。這極大地提高了我們作業的穩定性和效率,因為我們將訓練擴充套件到數百個 GPU。”—— Davis Wertheimer,IBM 研究院
IBM 的採用也幫助我們在真實的、大規模的訓練環境中驗證和改進了我們的解決方案。例如,IBM 發現 DCP 在單個節點上使用多個 GPU 時執行良好,但在多個節點上使用時卻出錯。
在調查此問題時,我們意識到我們假設寫入 NFS 類的共享檔案系統,這假定強讀寫一致性。帶有檔案系統 API 的物件儲存(例如 S3FS)提供最終一致性語義,因此導致在此類設定中的分散式檢查點失敗。我們與 IBM 合作,發現了這個問題並透過 一行程式碼更改 修復了它,併為 DCP 啟用了物件儲存後端!這種儲存方法通常比共享檔案系統便宜一個數量級,因此可以實現更細粒度的檢查點。
尋求合作
如果您有興趣嘗試分散式檢查點,請隨時與我們聯絡!
如果您在嘗試時遇到任何問題,可以在我們的 Github 倉庫中提出問題。
致謝
如果沒有許多合作者的幫助,這個專案是不可能實現的。我們要感謝 Yanli Zhao、Andrew Gu、Rohan Varma 對 FSDP 的支援。感謝 Pritam Damania、Junjie Zhao 和 Wanchao Liang 對 ShardedTensor 的支援。