跳轉到主要內容
社群

介紹 depyf:輕鬆掌握 torch.compile

作者: 2024 年 5 月 11 日2024 年 11 月 21 日暫無評論
depyf logo

我們很高興向 PyTorch 生態系統介紹新專案 depyf,旨在幫助使用者理解、學習和適應 torch.compile

動機

torch.compile 是 PyTorch 2.x 的基石,只需一行程式碼即可為訓練和推理加速機器學習工作流。僅僅包含 @torch.compile 就可以顯著提升程式碼效能。然而,找到 torch.compile 的最佳插入點並不容易,更不用說調整各種引數以實現最大效率的複雜性了。

torch.compile 堆疊的複雜性,包括 Dynamo、AOTAutograd、Inductor 等,帶來了陡峭的學習曲線。這些元件對於深度學習效能最佳化至關重要,但如果沒有紮實的基礎,它們可能會令人望而生畏。

注意:有關 torch.compile 如何工作的介紹性示例,請參閱此演練解釋

一個常用工具:TORCH_COMPILE_DEBUG

為了揭開 torch.compile 的神秘面紗,常用的方法是利用 TORCH_COMPILE_DEBUG 環境變數。雖然它提供了更多資訊,但解讀輸出仍然是一項艱鉅的任務。

例如,當我們有以下程式碼時

# test.py
import torch
from torch import _dynamo as torchdynamo
from typing import List

@torch.compile
def toy_example(a, b):
   x = a / (torch.abs(a) + 1)
   if b.sum() < 0:
       b = b * -1
   return x * b

def main():
   for _ in range(100):
       toy_example(torch.randn(10), torch.randn(10))

if __name__ == "__main__":
   main()

並使用 TORCH_COMPILE_DEBUG=1 python test.py 執行它,我們將得到一個名為 torch_compile_debug/run_2024_02_05_23_02_45_552124-pid_9520 的目錄,其中包含這些檔案

.
├── torchdynamo
│   └── debug.log
└── torchinductor
   ├── aot_model___0_debug.log
   ├── aot_model___10_debug.log
   ├── aot_model___11_debug.log
   ├── model__4_inference_10.1
   │   ├── fx_graph_readable.py
   │   ├── fx_graph_runnable.py
   │   ├── fx_graph_transformed.py
   │   ├── ir_post_fusion.txt
   │   ├── ir_pre_fusion.txt
   │   └── output_code.py
   ├── model__5_inference_11.2
   │   ├── fx_graph_readable.py
   │   ├── fx_graph_runnable.py
   │   ├── fx_graph_transformed.py
   │   ├── ir_post_fusion.txt
   │   ├── ir_pre_fusion.txt
   │   └── output_code.py
   └── model___9.0
       ├── fx_graph_readable.py
       ├── fx_graph_runnable.py
       ├── fx_graph_transformed.py
       ├── ir_post_fusion.txt
       ├── ir_pre_fusion.txt
       └── output_code.py

生成的檔案和日誌常常提出比解答更多的問題,讓開發人員對資料中的含義和關係感到困惑。TORCH_COMPILE_DEBUG 的常見困惑包括

  • model__4_inference_10.1 是什麼意思?
  • 我有一個函式,但目錄中有三個 model__xxx.py,它們之間有什麼對應關係?
  • debug.log 中的那些 LOAD_GLOBAL 是什麼?

一個更好的工具:depyf 助你解決問題

讓我們看看 depyf 如何幫助開發人員解決上述挑戰。要使用 depyf,只需執行 pip install depyf 或按照專案頁面https://github.com/thuml/depyf 安裝最新版本,然後將主要程式碼包圍在 with depyf.prepare_debug 中。

# test.py
import torch
from torch import _dynamo as torchdynamo
from typing import List

@torch.compile
def toy_example(a, b):
   x = a / (torch.abs(a) + 1)
   if b.sum() < 0:
       b = b * -1
   return x * b

def main():
   for _ in range(100):
       toy_example(torch.randn(10), torch.randn(10))

if __name__ == "__main__":
   import depyf
   with depyf.prepare_debug("depyf_debug_dir"):
       main()

執行 python test.py 後,depyf 將生成一個名為 depyf_debug_dirprepare_debug 函式的引數)的目錄。該目錄下將包含這些檔案

.
├── __compiled_fn_0 AFTER POST GRAD 0.py
├── __compiled_fn_0 Captured Graph 0.py
├── __compiled_fn_0 Forward graph 0.py
├── __compiled_fn_0 kernel 0.py
├── __compiled_fn_3 AFTER POST GRAD 0.py
├── __compiled_fn_3 Captured Graph 0.py
├── __compiled_fn_3 Forward graph 0.py
├── __compiled_fn_3 kernel 0.py
├── __compiled_fn_4 AFTER POST GRAD 0.py
├── __compiled_fn_4 Captured Graph 0.py
├── __compiled_fn_4 Forward graph 0.py
├── __compiled_fn_4 kernel 0.py
├── __transformed_code_0_for_torch_dynamo_resume_in_toy_example_at_8.py
├── __transformed_code_0_for_toy_example.py
├── __transformed_code_1_for_torch_dynamo_resume_in_toy_example_at_8.py
└── full_code_for_toy_example_0.py

並且有兩個明顯的優點

  1. 冗長難懂的 torchdynamo/debug.log 不見了。其內容被清理並以人類可讀的原始碼形式顯示在 full_code_for_xxx.py__transformed_code_{n}_for_xxx.py 中。值得注意的是,depyf 最繁瑣和困難的工作是將 torchdynamo/debug.log 中的位元組碼反編譯成 Python 原始碼,將開發人員從 Python 令人望而生畏的內部細節中解放出來。
  2. 函式名與計算圖之間的對應關係得到了尊重。例如,在 __transformed_code_0_for_toy_example.py 中,我們可以看到一個名為 __compiled_fn_0 的函式,我們將立即知道其對應的計算圖在 __compiled_fn_0_xxx.py 中,因為它們共享相同的 __compiled_fn_0 字首名。

full_code_for_xxx.py 開始,並遵循所涉及的函式,使用者將清楚地瞭解 torch.compile 對其程式碼做了什麼。

還有一件事:逐步除錯能力

使用偵錯程式逐行除錯程式碼是理解程式碼工作原理的好方法。然而,在 TORCH_COMPILE_DEBUG 下,這些檔案僅供使用者參考,無法用使用者關心的資料執行。

注意:“除錯”指的是檢查和改程序序的過程,而不是糾正有錯誤的程式碼。

depyf 的一個突出特點是它能夠促進 torch.compile 的逐步除錯:它生成的所有檔案都與 Python 直譯器內部的執行時程式碼物件關聯,我們可以在這些檔案中設定斷點。用法很簡單,只需新增一個上下文管理器 with depyf.debug(),它就能實現這一功能

# test.py
import torch
from torch import _dynamo as torchdynamo
from typing import List

@torch.compile
def toy_example(a, b):
   x = a / (torch.abs(a) + 1)
   if b.sum() < 0:
       b = b * -1
   return x * b

def main():
   for _ in range(100):
       toy_example(torch.randn(10), torch.randn(10))

if __name__ == "__main__":
   import depyf
   with depyf.prepare_debug("depyf_debug_dir"):
       main()
   with depyf.debug():
       main()

只有一個注意事項:除錯 torch.compile 的工作流程偏離了標準除錯工作流程。使用 torch.compile,許多程式碼是動態生成的。因此,我們需要

  1. 啟動程式
  2. 當程式退出 with depyf.prepare_debug("depyf_debug_dir") 時,程式碼將在 depyf_debug_dir 中可用。
  3. 當程式進入 with depyf.debug() 時,它將自動在內部設定一個斷點,從而使程式暫停。
  4. 導航到 depyf_debug_dir 設定斷點。
  5. 繼續執行程式碼,偵錯程式將命中這些斷點!
depyf screenshot

這是它看起來的截圖。所有程式碼和張量變數都是即時的,我們可以檢查任何變數,並逐步執行程式碼,就像我們現在日常的除錯工作流程一樣!唯一的區別是我們正在除錯 torch.compile 生成的程式碼而不是人工編寫的程式碼。

結論

torch.compile 是一個寶貴的工具,可以輕鬆加速 PyTorch 程式碼。對於那些希望深入研究 torch.compile 的人來說,無論是為了充分利用其潛力還是為了整合自定義操作,學習曲線都可能非常陡峭。depyf 旨在降低這一障礙,提供使用者友好的體驗來理解、學習和適應 torch.compile

請探索 depyf 並親身體驗它的好處!該專案是開源的,可在https://github.com/thuml/depyf 上獲取。透過 pip install depyf 安裝非常簡單。我們希望 depyf 能夠增強每個人的 torch.compile 開發工作流程。