FSDP 筆記¶
FSDP 預取細微差別¶
對於與 forward 計算重疊的 forward all-gather,有兩種可能的機制
- 隱式正向預取(始終啟用) 
- 顯式正向預取 ( - forward_prefetch=True)
隱式 forward 預取是指依賴從單獨的 CUDA 串流發出 all-gather,以便允許與在其之前發出的 forward 計算(從 CPU 的角度來看)重疊 all-gather。例如,如果我們有第 0 層 all-gather -> 第 0 層 forward 計算 -> 第 1 層 all-gather -> …,則第 1 層 all-gather 可以與第 0 層 forward 計算重疊,即使 CPU 執行緒是在之後發出的。(第一個 all-gather 將無法與任何東西重疊。)
顯式 forward 預取是指更改 CPU 執行緒的發出順序:例如,第 0 層 all-gather -> 第 1 層 all-gather -> 第 0 層 forward 計算 -> …。在急切模式下,在仍在第 0 層上執行時,通常無法知道下一層是哪一層(例如,範例中的第 1 層)。因此,顯式 forward 預取應僅用於執行順序在每次迭代中都固定的模型(我們有時稱之為「靜態圖」)。不滿足此約束的模型範例是 FLAVA)。
顯式 forward 預取僅節省發出層的 forward 計算核心所需的時間,但代價是在當前 all-gather 的輸出張量仍在使用時必須分配下一個 all-gather 的輸出張量。透過在下一個 all-gather 之前發出當前的 forward 計算核心,下一個 all-gather 可以在 GPU 上更快開始。對於大多數 LLM 工作負載,情況並非如此,因此沒有動機啟用 forward_prefetch=True。
相比之下,對於 backward,我們必須使用顯式 backward 預取,否則通訊和計算將 0 重疊。原因是因為我們對 all-gather 和 reduce-scatter 都使用單個 NCCL 程序組(部分原因是在早期的 NCCL 版本中,在相同的裝置上針對相同的等級同時使用多個程序組是不安全的)。單個 NCCL 程序組意味著單個內部 NCCL 串流,reduce-scatter 和 all-gather 在其上串聯執行。因此,除非我們明確地將 CPU 發出順序重新排序為下一個 all-gather -> 當前的 reduce-scatter,否則當前的 reduce-scatter 將阻塞下一個 all-gather,從而阻塞下一個 backward 計算,從而防止當前的 reduce-scatter 重疊。
通訊負載大小¶
在 FSDP 中,通訊是
- 在 - forward中對參數進行 all-gather
- 在 - backward中對參數進行 all-gather
- 在 - backward中對梯度進行 reduce-scatter
如果使用激活檢查點(checkpoint()),則不會有額外的通訊,因為在 backward 期間會預取參數。
在 FSDP 設計中,每個 rank 的通訊負載大小如下決定:每次呼叫 FullyShardedDataParallel 都會建立一個通訊群組,其中包含 module.parameters() 中的參數,但已分配給巢狀 FullyShardedDataParallel  執行個體的參數除外。例如,對於 Llama,如果您將 FullyShardedDataParallel 套用到每個 transformer 區塊以及根模組,則每個 transformer 區塊都有一個通訊群組,最後一個通訊群組包含初始嵌入和最終線性層。每個通訊群組對應於一個 all-gather 呼叫和一個 reduce-scatter 呼叫。這樣,您套用 FullyShardedDataParallel 的方式決定了通訊大小。一般來說,將 FSDP 套用到每個 transformer 區塊對於 LLM 來說是一個很好的啟發式方法,而且在目前的設計下很難做得更好。
讓我們考慮一個範例,我們有一個基於 Transformer 的模型,分佈在 8 個 GPU 上,其中分片僅發生在 transformer 區塊級別,每個 transformer 區塊包含 1.6B 個參數,並且參數採用 fp32 格式(每個 4 位元組)。這意味著分片後,每個 transformer 區塊在每個 rank 上將包含 0.2B 個參數。
- forward傳遞將以- 0.2*4 = 0.8GB的區塊進行 all-gather 通訊
- backward傳遞將通訊 2 次,每次- 0.8GB(1 次 all-gather 和 1 次 reduce-scatter)
換句話說,將有 3 次通訊,每次的負載為 0.8GB。如果模型由 10 個 transformer 區塊組成,則總共會有 30 次通訊,總計 30*0.8=24GB。
將每個 rank 每次通訊的負載大小形式化為 total_transformer_block_params_in_B*dtype_bytes/num_gpus(GB)。
請注意,在此範例中,我們沒有包含嵌入所需的額外通訊,這也應該考慮在內。計算方式取決於輸入和輸出嵌入是否綁定。如果它們沒有綁定,則通訊量將增加 2 倍。
FSDP 緩衝區大小¶
首先,讓我們介紹為通訊分配的緩衝區
forward 目前需要 2 倍的 all-gather 緩衝區大小。原因如下
如 FSDP 預取細微差異 中所述,在顯式 forward 預取(forward_prefetch=True`) 情況 下 ,層 0 all-gather -> 層 0 forward 計算 -> 層 1 all-gather,需要 2 個 all-gather 大小的 緩衝區,因為一個緩衝區用於當前的 ``forward,而另一個緩衝區用於進行預取。
雖然相同的順序在理論上只需要 1 個緩衝區,但在隱式 forward 預取(forward_prefetch=False,預設值)的情況下,實際上仍然是 2 倍的 all-gather 大小的緩衝區。原因是在平面參數 FSDP 設計中,我們沒有從 all-gather 緩衝區複製出來。用於計算的參數直接查看 all-gather 緩衝區(實際上,“平面參數”的主要優點正是這個原因)。在這種情況下,當“層 1 all-gather”與“層 0 forward 計算”重疊時,“層 0 forward 計算”使用查看“層 0 all-gather”緩衝區的參數。
那麼,什麼時候需要 forward_prefetch=False 呢?對於靜態圖模型(例如大多數 LLM),有一個主要的技術原因。實際上,我們是為了一些 CPU 密集型內部模型快速添加了此選項,並且沒有在單元測試中使用它測試每個代碼路徑,因此我們對它的信心較低。forward_prefetching=False 可能更容易推理,因為我們不必檢查記錄的 forward 順序作為可能的“故障模式”;模組的 all-gather 始終可以在其分析器追蹤中在其自身的 record_function 標籤下找到。
backward 目前至少需要 2 倍的 all-gather 緩衝區大小,並且可能會更多一些。原因如下
目前的 FSDP 設計使用 recordStream 來管理一個流中產生的分配,這些分配在另一個流中消耗,這可能導致比預期更多的記憶體使用量。多出的量可能是“不確定的”,因為它取決於 GPU 核心時間相對於 CPU 的時間。limit_all_gathers=True 參數是對此的一種緩解措施 - 有關更多詳細信息,請參閱 FSDP 和 CUDACachingAllocator 中的討論。
現有的 FSDP 如何與自動梯度一起工作
- 現有的 FSDP 對 - flat_param進行 all-gather,這是自動梯度的葉子。
- 它呼叫 - torch.split以獲得對應於其組成原始參數的- flat_param的一維視圖。
- 它在每個一維拆分上呼叫 - torch.view以返回到 ND 視圖。
- 這意味著在 - backward中,我們最終得到- ViewBackward(ND -> 1D)和- SplitWithSizesBackward(這是一個串聯操作)。特別是,每個單獨的梯度都作為一個單獨的分配進行計算,並且發生顯式串聯以構造 reduce-scatter 輸入緩衝區。這意味著在峰值記憶體點處,reduce-scatter 的緩衝區大小實際上是 2 倍。
總之,對於 backward,它是大約 2 倍的 reduce-scatter 緩衝區大小加上任何 recordStream 影響。
其次,讓我們討論其他緩衝區
從所有 rank 收集分片參數後,它們需要額外的 total_transformer_block_params_in_B*dtype_bytes 緩衝區來存放完整的參數 - 因此繼續前面的範例,如果每個 transformer 區塊是 1.6B 參數並且參數採用 fp32 格式,則它將是 1.6*4=6.4GB 緩衝區。
並且需要 2 個這樣的緩衝區,因為一個當前正在使用,另一個正在預取。
總之,我們有
- 2 倍的通訊緩衝區,大小為 - total_transformer_block_params_in_B*dtype_bytes/num_gpus
- 2 倍的未分片 transformer 區塊參數緩衝區,大小為 - ``total_transformer_block_params_in_B*dtype_bytes
或者,如果您一直在關注這個例子
- 2*1.6*4/8=1.6GB
- 2**1.6*4=12.8GB
總計 14.4GB。
現在讓我們簡要討論一下嵌入會發生什麼,因為我們在計算中忽略了它們
鑑於我們討論過的規則,您在從“通訊緩衝區大小如下決定”開始的註釋中包含了該規則,我們可以分析如下
- 假設我們將 FSDP 套用到根模組(例如 - Transformer類)。假設我們進一步將 FSDP 套用到每個 transformer 區塊(例如- TransformerBlock類)。
- 最常見的情況是,嵌入和最終線性投影是根 - Transformer類的直接子級。
- 根據我們的規則,這意味著嵌入和最終線性投影被分配給根 - Transformer的平面參數。
- 我們還有_另一個_特殊規則,即根在 forward 後不會釋放其參數,因為它們無論如何都會在 backward 中立即進行 all-gather。 
- 綜上所述,這意味著包含嵌入和最終投影的根的平面參數在開始 forward 時進行 all-gather,並在 backward 結束之前一直保留在 GPU 記憶體中。 
- 如果嵌入和最終線性層沒有權重綁定,那麼我們_可以_進一步將 FSDP 套用到嵌入和最終線性層。對於權重綁定的參數,我們要求它們是同一個平面參數的一部分(否則它會被重複計算)。這將允許在 forward 中使用嵌入後釋放嵌入,並且僅在 backward 結束時進行 all-gather。 
- 希望這能讓您更好地理解 - 每個 FSDP 模組都會在其 - module.parameters中分配參數,但已分配給另一個嵌套 FSDP 模組的參數除外,並且 FSDP 模組的- forward定義了其參數的“活動”間隔。因此,嵌套的- nn.Module結構會影響 all-gather/free 時間表,從而影響記憶體/吞吐量性能。