快捷方式

FSDP 筆記

FSDP 預取細節

為了將 forward 的 all-gather 與 forward 計算重疊,有兩種可能的機制

  1. 隱式前向預取(始終啟用)

  2. 顯式前向預取(forward_prefetch=True

隱式 forward 預取是指依賴於從單獨的 CUDA 流發出 all-gather,以允許 all-gather 與之前發出的(從 CPU 視角看)forward 計算重疊。例如,如果我們有第 0 層 all-gather -> 第 0 層 forward 計算 -> 第 1 層 all-gather -> …,那麼即使 CPU 執行緒在其後發出,第 1 層 all-gather 也可以與第 0 層 forward 計算重疊。(第一個 all-gather 無法與任何其他操作重疊。)

顯式 forward 預取是指改變 CPU 執行緒的發出順序:例如,第 0 層 all-gather -> 第 1 層 all-gather -> 第 0 層 forward 計算 -> …。在 eager 模式下,通常無法在執行第 0 層時知道下一層(例如示例中的第 1 層)是哪一層。因此,顯式 forward 預取只能用於執行順序在迭代之間固定的模型(有時稱為“靜態圖”)。不滿足此約束的模型示例包括 FLAVA)。

顯式 forward 預取僅節省發出層級 forward 計算核函式所需的時間,代價是當前 all-gather 的輸出張量仍在使用時,必須分配下一個 all-gather 的輸出張量。透過在當前 forward 計算核函式之前發出下一個 all-gather,下一個 all-gather 可以在 GPU 上更快地開始。對於大多數 LLM 工作負載而言,情況並非如此,因此沒有理由啟用 forward_prefetch=True

相比之下,對於 backward,必須使用顯式 backward 預取,否則通訊和計算將完全沒有重疊。原因在於我們使用單個 NCCL 程序組進行 all-gather 和 reduce-scatter(部分原因是早期 NCCL 版本在同一裝置上使用相同的 ranks 併發使用多個程序組是不安全的)。單個 NCCL 程序組意味著 reduce-scatter 和 all-gather 在單個內部 NCCL 流上序列執行。因此,除非我們明確重新排序 CPU 發出順序為下一個 all-gather -> 當前 reduce-scatter,否則當前 reduce-scatter 將阻塞下一個 all-gather,從而阻塞下一個 backward 計算,阻止當前 reduce-scatter 的重疊。

通訊負載大小

在 FSDP 中,通訊包括

  1. forward 中的引數 all-gather

  2. backward 中的引數 all-gather

  3. backward 中的梯度 reduce-scatter

如果使用啟用檢查點 (checkpoint()),則沒有額外的通訊,因為引數在 backward 期間 anyway 會被預取。

在 FSDP 設計中,每個 rank 的通訊負載確定如下:每次呼叫 FullyShardedDataParallel 都會建立一個通訊組,該組由 module.parameters() 中的引數組成,但已分配給巢狀 FullyShardedDataParallel 例項的引數除外。例如,對於 Llama,如果您對每個 Transformer 塊以及根模組都應用 FullyShardedDataParallel,那麼每個 Transformer 塊都有一個通訊組,最後根模組有一個包含初始嵌入和最終線性層的通訊組。每個通訊組對應於一次 all-gather 呼叫和一次 reduce-scatter 呼叫。因此,您如何應用 FullyShardedDataParallel 決定了通訊大小。總的來說,對每個 Transformer 塊應用 FSDP 是 LLM 的一個好的啟發式方法,並且鑑於當前的設計,很難做得更好。

考慮一個例子,我們有一個基於 Transformer 的模型在 8 個 GPU 上分片,分片僅在 Transformer 塊級別進行,每個 Transformer 塊包含 1.6B 引數,引數為 fp32(每個 4 位元組)。這意味著分片後,每個 Transformer 塊在每個 rank 上包含 0.2B 引數。

  • forward 通訊將以 0.2*4 = 0.8GB 為塊進行 all-gather

  • backward 通訊將進行 2 次,每次 0.8GB(1 次 all-gather 和 1 次 reduce-scatter)

換句話說,總共有 3 次通訊,每次負載 0.8GB。如果模型包含 10 個 Transformer 塊,則總共有 30 次通訊,總計 30*0.8=24GB

正式地說,每次通訊每個 rank 的負載大小為 total_transformer_block_params_in_B*dtype_bytes/num_gpus (GBs)。

請注意,在此示例中,我們未包含嵌入所需的額外通訊,這也應予以考慮。並且計算方式取決於輸入和輸出嵌入是否繫結。如果未繫結,則通訊次數將是 2 倍。

FSDP 緩衝區大小

首先,讓我們來看為通訊分配的緩衝區

forward 目前需要 2 倍 all-gather 緩衝區大小。原因如下:

FSDP 預取細節 中所解釋的,在顯式 forward 預取(forward_prefetch=True)的情況下,即第 0 層 all-gather -> 第 0 層 forward 計算 -> 第 1 層 all-gather,需要 2 個 all-gather 大小的緩衝區,因為一個緩衝區用於當前的 forward,而另一個用於進行預取。

儘管隱式 forward 預取(forward_prefetch=False,預設)在理論上只需要 1 個緩衝區,但實際上仍需要 2 倍 all-gather 大小的緩衝區。原因在於,在扁平引數 FSDP 設計中,我們不會從 all-gather 緩衝區中複製出來。用於計算的引數直接作為 all-gather 緩衝區的檢視(實際上,“扁平引數”的主要好處正是這個原因)。在這種情況下,儘管“第 1 層 all-gather”與“第 0 層 forward 計算”重疊,但“第 0 層 forward 計算”正在使用作為“第 0 層 all-gather”緩衝區檢視的引數。

那麼一個很自然的問題是,何時會需要 forward_prefetch=False?對於靜態圖模型(如大多數 LLM),有一個主要的技術原因。更實際地說,我們快速為一些 CPU 繫結的內部模型添加了此選項,並且尚未在單元測試中測試其所有程式碼路徑,因此我們對其信心不足。forward_prefetching=False 可能更容易理解,因為我們不必檢查記錄的前向順序作為可能的“故障模式”;模組的 all-gather 始終可以在其自己的 record_function 標籤下在其 profiler 跟蹤中找到。

backward 目前至少需要 2 倍 all-gather 緩衝區大小,並且可能更多。原因如下:

當前的 FSDP 設計使用 recordStream 來管理在一個流中生成並在另一個流中使用的分配,這可能導致比預期更多的記憶體使用。增加多少可能“不確定”,因為它取決於 GPU 核函式計時相對於 CPU 的情況。limit_all_gathers=True 引數是對此的緩解措施 - 有關更多詳細資訊,請參閱此討論 FSDP & CUDACachingAllocator

現有 FSDP 與自動求導的協作方式

  • 現有 FSDP 對 flat_param 執行 all-gather 操作,flat_param 是自動求導的葉節點。

  • 它呼叫 torch.split 以獲取 flat_param 中與其組成的原始引數對應的 1D 檢視。

  • 它對每個 1D 分割呼叫 torch.view 以將其檢視恢復為 ND。

  • 這意味著在 backward 中,我們最終得到 ViewBackward(ND -> 1D)和 SplitWithSizesBackward(它是一個 concat)。特別是,每個單獨的梯度都計算為一個單獨的分配,並且會發生顯式 concat 來構建 reduce-scatter 輸入緩衝區。這意味著在該記憶體峰值點,reduce-scatter 的緩衝區大小實際上是 2 倍。

總而言之,對於 backward,緩衝區大小大約是 reduce-scatter 的 2 倍,再加上任何 recordStream 的影響。

其次,讓我們討論額外的緩衝區

一旦從所有 rank 收集了分片引數,它們需要一個額外的緩衝區來儲存完整引數,大小為 total_transformer_block_params_in_B*dtype_bytes - 所以繼續前面的例子,如果每個 Transformer 塊是 1.6B 引數,引數是 fp32,那麼緩衝區大小將是 1.6*4=6.4GB

需要兩個這樣的緩衝區,因為一個當前正在使用,另一個正在被預取。

總結一下,我們有

  1. 2 倍於 total_transformer_block_params_in_B*dtype_bytes/num_gpus 的通訊緩衝區

  2. 2 倍於未分片 Transformer 塊引數緩衝區 ``total_transformer_block_params_in_B*dtype_bytes

或者按照前面的例子

  1. 2*1.6*4/8=1.6GB

  2. 2*1.6*4=12.8GB

總計 14.4GB

現在讓我們簡要討論一下嵌入層會發生什麼,因為我們之前的計算中忽略了它們

根據我們討論過的規則,即在筆記中以“通訊緩衝區大小確定如下”開頭的部分,我們可以進行如下分析

  • 假設我們將 FSDP 應用於根模組(例如 Transformer 類)。假設我們進一步將 FSDP 應用於每個 Transformer 塊(例如 TransformerBlock 類)。

  • 通常,嵌入層和最終線性投影是根 Transformer 類的直接子模組。

  • 根據我們的規則,這意味著嵌入層和最終線性投影被分配給根 Transformer 的扁平引數。

  • 我們有_另一個_特殊規則,即根模組在前向傳播後不會釋放其引數,因為無論如何它們都會在反向傳播中立即進行 all-gather。

  • 綜合來看,這意味著包含嵌入層和最終投影層的根模組扁平引數在開始前向傳播時進行 all-gather,並保留在 GPU 記憶體中直到反向傳播結束。

  • 如果嵌入層和最終線性層不繫結權重,那麼我們_可以_進一步將 FSDP 應用於嵌入層和最終線性層。對於繫結權重的引數,我們要求它們屬於同一個扁平引數(否則會重複計數)。這將允許嵌入層在其前向使用後被釋放,並僅在反向傳播結束時進行 all-gather。

  • 希望這能提供更好的理解——每個 FSDP 模組都會被分配其 module.parameters 中的引數,除非這些引數已被分配給另一個巢狀的 FSDP 模組,並且 FSDP 模組的 forward 定義了其引數的“活躍”區間。因此,巢狀的 nn.Module 結構會影響 all-gather/free 排程,從而影響記憶體/吞吐量效能。

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源