快捷方式

CommDebugMode 入門

建立日期:2024 年 8 月 19 日 | 最後更新:2024 年 10 月 8 日 | 最後驗證:2024 年 11 月 5 日

作者: Anshul Sinha

在本教程中,我們將探討如何在分散式訓練環境中使用 CommDebugMode 配合 PyTorch 的 DistributedTensor (DTensor) 來跟蹤集合操作進行除錯。

先決條件

  • Python 3.8 - 3.11

  • PyTorch 2.2 或更高版本

CommDebugMode 是什麼以及為何有用

隨著模型規模不斷增大,使用者正尋求利用各種並行策略組合來擴充套件分散式訓練。然而,現有解決方案之間的互操作性不足帶來了巨大挑戰,這主要是由於缺乏能夠橋接這些不同並行策略的統一抽象。為了解決這個問題,PyTorch 提出了 DistributedTensor(DTensor),它抽象了分散式訓練中張量通訊的複雜性,提供了無縫的使用者體驗。然而,在使用現有並行解決方案以及利用 DTensor 等統一抽象開發並行解決方案時,底層集合通訊的內容和發生時間缺乏透明度,這可能導致高階使用者難以識別和解決問題。為了應對這一挑戰,CommDebugMode(一個 Python 上下文管理器)將作為 DTensor 的主要除錯工具之一,使使用者能夠檢視在使用 DTensor 時集合操作的發生時間和原因,從而有效解決此問題。

使用 CommDebugMode

您可以按如下方式使用 CommDebugMode

# The model used in this example is a MLPModule applying Tensor Parallel
comm_mode = CommDebugMode()
    with comm_mode:
        output = model(inp)

# print the operation level collective tracing information
print(comm_mode.generate_comm_debug_tracing_table(noise_level=0))

# log the operation level collective tracing information to a file
comm_mode.log_comm_debug_tracing_table_to_file(
    noise_level=1, file_name="transformer_operation_log.txt"
)

# dump the operation level collective tracing information to json file,
# used in the visual browser below
comm_mode.generate_json_dump(noise_level=2)

這是 MLPModule 在噪聲級別 0 時的輸出示例

Expected Output:
    Global
      FORWARD PASS
        *c10d_functional.all_reduce: 1
        MLPModule
          FORWARD PASS
            *c10d_functional.all_reduce: 1
            MLPModule.net1
            MLPModule.relu
            MLPModule.net2
              FORWARD PASS
                *c10d_functional.all_reduce: 1

要使用 CommDebugMode,您必須將執行模型的程式碼包裝在 CommDebugMode 中,並呼叫您想要用來顯示資料的 API。您還可以使用 noise_level 引數來控制顯示資訊的詳細程度。以下是每個噪聲級別顯示的內容:

0. 列印模組級別的集合計數。
1. 列印 DTensor 操作(不包括微不足道的操作)、模組分片資訊。
2. 列印張量操作(不包括微不足道的操作)。
3. 列印所有操作。

在上面的示例中,您可以看到集合操作 all_reduce 在 MLPModule 的前向傳播中發生了一次。此外,您可以使用 CommDebugMode 精確定位到 all-reduce 操作發生在 MLPModule 的第二個線性層中。

下面是互動式模組樹視覺化工具,您可以使用它上傳自己的 JSON dump 檔案

CommDebugMode 模組樹 - PyTorch 框架
將檔案拖到此處

結論

在本 Recipe 中,我們學習瞭如何使用 CommDebugMode 來除錯 Distributed Tensors 以及使用 PyTorch 中通訊集合的並行解決方案。您可以在嵌入的視覺化瀏覽器中使用您自己的 JSON 輸出。

有關 CommDebugMode 的更多詳細資訊,請參閱 comm_mode_features_example.py

文件

檢視 PyTorch 的完整開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源