
我們很高興向 PyTorch 生態系統介紹新專案 depyf,旨在幫助使用者理解、學習和適應 torch.compile!
動機
torch.compile 是 PyTorch 2.x 的基石,只需一行程式碼即可為訓練和推理加速機器學習工作流。僅僅包含 @torch.compile 就可以顯著提升程式碼效能。然而,找到 torch.compile 的最佳插入點並不容易,更不用說調整各種引數以實現最大效率的複雜性了。
torch.compile 堆疊的複雜性,包括 Dynamo、AOTAutograd、Inductor 等,帶來了陡峭的學習曲線。這些元件對於深度學習效能最佳化至關重要,但如果沒有紮實的基礎,它們可能會令人望而生畏。
注意:有關 torch.compile 如何工作的介紹性示例,請參閱此演練解釋。
一個常用工具:TORCH_COMPILE_DEBUG
為了揭開 torch.compile 的神秘面紗,常用的方法是利用 TORCH_COMPILE_DEBUG 環境變數。雖然它提供了更多資訊,但解讀輸出仍然是一項艱鉅的任務。
例如,當我們有以下程式碼時
# test.py
import torch
from torch import _dynamo as torchdynamo
from typing import List
@torch.compile
def toy_example(a, b):
x = a / (torch.abs(a) + 1)
if b.sum() < 0:
b = b * -1
return x * b
def main():
for _ in range(100):
toy_example(torch.randn(10), torch.randn(10))
if __name__ == "__main__":
main()
並使用 TORCH_COMPILE_DEBUG=1 python test.py 執行它,我們將得到一個名為 torch_compile_debug/run_2024_02_05_23_02_45_552124-pid_9520 的目錄,其中包含這些檔案
.
├── torchdynamo
│ └── debug.log
└── torchinductor
├── aot_model___0_debug.log
├── aot_model___10_debug.log
├── aot_model___11_debug.log
├── model__4_inference_10.1
│ ├── fx_graph_readable.py
│ ├── fx_graph_runnable.py
│ ├── fx_graph_transformed.py
│ ├── ir_post_fusion.txt
│ ├── ir_pre_fusion.txt
│ └── output_code.py
├── model__5_inference_11.2
│ ├── fx_graph_readable.py
│ ├── fx_graph_runnable.py
│ ├── fx_graph_transformed.py
│ ├── ir_post_fusion.txt
│ ├── ir_pre_fusion.txt
│ └── output_code.py
└── model___9.0
├── fx_graph_readable.py
├── fx_graph_runnable.py
├── fx_graph_transformed.py
├── ir_post_fusion.txt
├── ir_pre_fusion.txt
└── output_code.py
生成的檔案和日誌常常提出比解答更多的問題,讓開發人員對資料中的含義和關係感到困惑。TORCH_COMPILE_DEBUG 的常見困惑包括
model__4_inference_10.1是什麼意思?- 我有一個函式,但目錄中有三個
model__xxx.py,它們之間有什麼對應關係? debug.log中的那些LOAD_GLOBAL是什麼?
一個更好的工具:depyf 助你解決問題
讓我們看看 depyf 如何幫助開發人員解決上述挑戰。要使用 depyf,只需執行 pip install depyf 或按照專案頁面https://github.com/thuml/depyf 安裝最新版本,然後將主要程式碼包圍在 with depyf.prepare_debug 中。
# test.py
import torch
from torch import _dynamo as torchdynamo
from typing import List
@torch.compile
def toy_example(a, b):
x = a / (torch.abs(a) + 1)
if b.sum() < 0:
b = b * -1
return x * b
def main():
for _ in range(100):
toy_example(torch.randn(10), torch.randn(10))
if __name__ == "__main__":
import depyf
with depyf.prepare_debug("depyf_debug_dir"):
main()
執行 python test.py 後,depyf 將生成一個名為 depyf_debug_dir(prepare_debug 函式的引數)的目錄。該目錄下將包含這些檔案
.
├── __compiled_fn_0 AFTER POST GRAD 0.py
├── __compiled_fn_0 Captured Graph 0.py
├── __compiled_fn_0 Forward graph 0.py
├── __compiled_fn_0 kernel 0.py
├── __compiled_fn_3 AFTER POST GRAD 0.py
├── __compiled_fn_3 Captured Graph 0.py
├── __compiled_fn_3 Forward graph 0.py
├── __compiled_fn_3 kernel 0.py
├── __compiled_fn_4 AFTER POST GRAD 0.py
├── __compiled_fn_4 Captured Graph 0.py
├── __compiled_fn_4 Forward graph 0.py
├── __compiled_fn_4 kernel 0.py
├── __transformed_code_0_for_torch_dynamo_resume_in_toy_example_at_8.py
├── __transformed_code_0_for_toy_example.py
├── __transformed_code_1_for_torch_dynamo_resume_in_toy_example_at_8.py
└── full_code_for_toy_example_0.py
並且有兩個明顯的優點
- 冗長難懂的
torchdynamo/debug.log不見了。其內容被清理並以人類可讀的原始碼形式顯示在full_code_for_xxx.py和__transformed_code_{n}_for_xxx.py中。值得注意的是,depyf最繁瑣和困難的工作是將torchdynamo/debug.log中的位元組碼反編譯成 Python 原始碼,將開發人員從 Python 令人望而生畏的內部細節中解放出來。 - 函式名與計算圖之間的對應關係得到了尊重。例如,在
__transformed_code_0_for_toy_example.py中,我們可以看到一個名為__compiled_fn_0的函式,我們將立即知道其對應的計算圖在__compiled_fn_0_xxx.py中,因為它們共享相同的__compiled_fn_0字首名。
從 full_code_for_xxx.py 開始,並遵循所涉及的函式,使用者將清楚地瞭解 torch.compile 對其程式碼做了什麼。
還有一件事:逐步除錯能力
使用偵錯程式逐行除錯程式碼是理解程式碼工作原理的好方法。然而,在 TORCH_COMPILE_DEBUG 下,這些檔案僅供使用者參考,無法用使用者關心的資料執行。
注意:“除錯”指的是檢查和改程序序的過程,而不是糾正有錯誤的程式碼。
depyf 的一個突出特點是它能夠促進 torch.compile 的逐步除錯:它生成的所有檔案都與 Python 直譯器內部的執行時程式碼物件關聯,我們可以在這些檔案中設定斷點。用法很簡單,只需新增一個上下文管理器 with depyf.debug(),它就能實現這一功能
# test.py
import torch
from torch import _dynamo as torchdynamo
from typing import List
@torch.compile
def toy_example(a, b):
x = a / (torch.abs(a) + 1)
if b.sum() < 0:
b = b * -1
return x * b
def main():
for _ in range(100):
toy_example(torch.randn(10), torch.randn(10))
if __name__ == "__main__":
import depyf
with depyf.prepare_debug("depyf_debug_dir"):
main()
with depyf.debug():
main()
只有一個注意事項:除錯 torch.compile 的工作流程偏離了標準除錯工作流程。使用 torch.compile,許多程式碼是動態生成的。因此,我們需要
- 啟動程式
- 當程式退出
with depyf.prepare_debug("depyf_debug_dir")時,程式碼將在depyf_debug_dir中可用。 - 當程式進入
with depyf.debug()時,它將自動在內部設定一個斷點,從而使程式暫停。 - 導航到
depyf_debug_dir設定斷點。 - 繼續執行程式碼,偵錯程式將命中這些斷點!

這是它看起來的截圖。所有程式碼和張量變數都是即時的,我們可以檢查任何變數,並逐步執行程式碼,就像我們現在日常的除錯工作流程一樣!唯一的區別是我們正在除錯 torch.compile 生成的程式碼而不是人工編寫的程式碼。
結論
torch.compile 是一個寶貴的工具,可以輕鬆加速 PyTorch 程式碼。對於那些希望深入研究 torch.compile 的人來說,無論是為了充分利用其潛力還是為了整合自定義操作,學習曲線都可能非常陡峭。depyf 旨在降低這一障礙,提供使用者友好的體驗來理解、學習和適應 torch.compile。
請探索 depyf 並親身體驗它的好處!該專案是開源的,可在https://github.com/thuml/depyf 上獲取。透過 pip install depyf 安裝非常簡單。我們希望 depyf 能夠增強每個人的 torch.compile 開發工作流程。