Dynamo 概述¶
在閱讀本節之前,請先閱讀 torch.compiler。
TorchDynamo(簡稱 Dynamo)是一種 Python 級別的即時 (JIT) 編譯器,旨在加速未修改的 PyTorch 程式。Dynamo 透過鉤子介入 CPython 中的幀評估 API(PEP 523),在 Python 位元組碼執行前動態修改它。它重寫 Python 位元組碼,將 PyTorch 操作序列提取到一個 FX Graph 中,然後使用可定製的後端進行編譯。它透過位元組碼分析建立此 FX Graph,旨在將 Python 執行與編譯後端相結合,從而獲得兩者的優勢——可用性和效能。
Dynamo 使得使用不同的編譯器後端來加速 PyTorch 程式碼變得容易,只需一行裝飾器 torch._dynamo.optimize(),為了方便起見,它被包裝在 torch.compile() 中
下圖演示了 PyTorch 在使用 torch.compile 和不使用它時的區別
TorchInductor 是 Dynamo Graph 支援的後端之一,用於將圖轉換為適用於 GPU 的 Triton 或適用於 CPU 的 C++/OpenMP。我們有一個訓練效能儀表盤,提供了不同訓練後端之間的效能比較。您可以在PyTorch dev-discuss 上的 TorchInductor 文章中瞭解更多資訊。
要獲得深入概述,請閱讀以下部分,觀看深入講解影片,並檢視 dev-discuss 主題。
Dynamo 內部機制¶
作者: Jason Ansel 和 Kaichao You
本節將介紹一些 Dynamo 內部機制,並演示 Dynamo 如何在底層工作。
什麼是 Guard?¶
Dynamo 以即時方式執行,並根據動態屬性特化圖。下面是如何使用 Dynamo 的一個基本示例。可以使用 torchdynamo.optimize 裝飾器來裝飾函式或方法,以啟用 Dynamo 最佳化
from typing import List
import torch
from torch import _dynamo as torchdynamo
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
print("my_compiler() called with FX graph:")
gm.graph.print_tabular()
return gm.forward # return a python callable
@torchdynamo.optimize(my_compiler)
def toy_example(a, b):
x = a / (torch.abs(a) + 1)
if b.sum() < 0:
b = b * -1
return x * b
for _ in range(100):
toy_example(torch.randn(10), torch.randn(10))
例如,上面的第一個圖有以下 Guard
GUARDS:
hasattr(L['a'], '_dynamo_dynamic_indices') == False
hasattr(L['b'], '_dynamo_dynamic_indices') == False
utils_device.CURRENT_DEVICE == None
___skip_backend_check() or ___current_backend() == ___lookup_backend(140355900538256)
check_tensor(L['a'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[10], stride=[1])
check_tensor(L['b'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[10], stride=[1])
如果其中任何一個 Guard 失敗,圖將被重新捕獲並重新編譯。其中有趣的 Guard 是 check_tensor,它檢查以下 torch.Tensor 屬性
張量的 Python 類(張量子類化等)
dtype
device
requires_grad
dispatch_key(應用執行緒區域性包含/排除規則後)
ndim
sizes*
strides*
完全特化模式允許後端編譯器假定圖完全是靜態的。不幸的是,大多數後端都需要這樣做。在非動態形狀模式下,返回動態形狀的運算元將觸發圖中斷。
Dynamo 在做什麼?¶
如果您想更好地理解 Dynamo 在做什麼,可以執行以下程式碼
TORCH_LOGS="+dynamo,guards,bytecode"
如果您不熟悉 Python 位元組碼,可以新增反編譯鉤子將位元組碼反編譯為人類可讀的原始碼。一個可用的工具是 depyf。如果您尚未安裝 depyf,請執行 pip install depyf。然後,在執行任何程式碼之前新增以下程式碼來安裝反編譯鉤子。
import depyf
depyf.install()
此程式碼會觸發有用的(但可能刷屏的)列印輸出。
例如,toy_example 中第一個圖的列印輸出如下
__compiled_fn_0 <eval_with_key>.1
opcode name target args kwargs
------------- ------- ------------------------------------------------------ ---------------- --------
placeholder a a () {}
placeholder b b () {}
call_function abs_1 <built-in method abs of type object at 0x7f9ca082f8a0> (a,) {}
call_function add <built-in function add> (abs_1, 1) {}
call_function truediv <built-in function truediv> (a, add) {}
call_method sum_1 sum (b,) {}
call_function lt <built-in function lt> (sum_1, 0) {}
output output output ((truediv, lt),) {}
ORIGINAL BYTECODE toy_example example.py line 12
14 0 LOAD_FAST 0 (a)
2 LOAD_GLOBAL 0 (torch)
4 LOAD_METHOD 1 (abs)
6 LOAD_FAST 0 (a)
8 CALL_METHOD 1
10 LOAD_CONST 1 (1)
12 BINARY_ADD
14 BINARY_TRUE_DIVIDE
16 STORE_FAST 2 (x)
15 18 LOAD_FAST 1 (b)
20 LOAD_METHOD 2 (sum)
22 CALL_METHOD 0
24 LOAD_CONST 2 (0)
26 COMPARE_OP 0 (<)
28 POP_JUMP_IF_FALSE 19 (to 38)
16 30 LOAD_FAST 1 (b)
32 LOAD_CONST 3 (-1)
34 BINARY_MULTIPLY
36 STORE_FAST 1 (b)
17 >> 38 LOAD_FAST 2 (x)
40 LOAD_FAST 1 (b)
42 BINARY_MULTIPLY
44 RETURN_VALUE
MODIFIED BYTECODE toy_example example.py line 12
12 0 LOAD_GLOBAL 3 (__compiled_fn_0)
2 LOAD_FAST 0 (a)
4 LOAD_FAST 1 (b)
6 CALL_FUNCTION 2
8 UNPACK_SEQUENCE 2
10 STORE_FAST 2 (x)
12 POP_JUMP_IF_FALSE 12 (to 24)
14 LOAD_GLOBAL 4 (__resume_at_30_1)
16 LOAD_FAST 1 (b)
18 LOAD_FAST 2 (x)
20 CALL_FUNCTION 2
22 RETURN_VALUE
>> 24 LOAD_GLOBAL 5 (__resume_at_38_2)
26 LOAD_FAST 1 (b)
28 LOAD_FAST 2 (x)
30 CALL_FUNCTION 2
32 RETURN_VALUE
possible source code:
def toy_example(a, b):
__temp_1 = __compiled_fn_0(a, b)
x = __temp_1[0]
if __temp_1[1]:
return __resume_at_30_1(b, x)
return __resume_at_38_2(b, x)
If you find the decompiled code is wrong,please submit an issue at https://github.com/youkaichao/depyf/issues.
頂部是 FX 圖。接下來是函式的原始位元組碼,然後是 Dynamo 生成的修改後的位元組碼,以及反編譯後的原始碼供參考。最後,您會看到我們上面介紹的 Guard。
在修改後的位元組碼中,__compiled_fn_0 是 my_compiler() 的返回值(編譯後的圖)。__resume_at_30_1 和 __resume_at_38_2 都是生成的繼續函式,它們在圖中斷後(位元組碼偏移量 30 和 38 處)繼續執行。每個函式都採用以下形式
__resume_at_<offset>:
... restore stack state if needed ...
JUMP_ABSOLUTE <offset> into toy_example
... original bytecode of toy_example ...
透過生成這個 resume_at 函式,我們強制函式的其餘部分在一個新的 Python 幀中執行,當執行第一次到達該點時,這會遞迴地觸發 Dynamo 重新開始其捕獲。
如何檢查 Dynamo 生成的 Artifact?¶
為了檢查 Dynamo 生成的 Artifact,可以使用 API torch._dynamo.eval_frame._debug_get_cache_entry_list,該 API 從函式的 __code__ 物件中檢索編譯後的程式碼和 Guard。一個編譯後的函式可以有多個快取條目,每個快取條目包含一個用於檢查 Guard 的生成函式,以及一個 types.CodeType 物件,用於儲存滿足 Guard 條件時要執行的程式碼。
from torch._dynamo.eval_frame import _debug_get_cache_entry_list, innermost_fn
cache_entries = _debug_get_cache_entry_list(innermost_fn(toy_example))
cache_entry = cache_entries[0]
guard, code = cache_entry.check_fn, cache_entry.code
# the guard takes the local variables of an input frame, and tells whether a re-compilation should be triggered.
import dis
dis.dis(guard)
dis.dis(code)
如果您瞭解 Python 位元組碼,就可以理解上述輸出。
對於 Guard 函式,無需檢查位元組碼。我們可以直接訪問其 Guard 條件
for code_part in guard.code_parts:
print(code_part)
輸出如下
___guarded_code.valid
___check_global_state()
hasattr(L['a'], '_dynamo_dynamic_indices') == False
hasattr(L['b'], '_dynamo_dynamic_indices') == False
utils_device.CURRENT_DEVICE == None
___skip_backend_check() or ___current_backend() == ___lookup_backend(140215810860528)
___check_tensors(L['a'], L['b'], tensor_check_names=tensor_check_names)
只有當所有條件都滿足時,Guard 函式才返回 true,並執行編譯後的程式碼。
對於編譯後的程式碼,我們不能直接訪問其原始碼,而必須對其進行反編譯。
from depyf import decompile
print(decompile(code))
輸出如下
def toy_example(a, b):
__temp_1 = __compiled_fn_0(a, b)
x = __temp_1[0]
if __temp_1[1]:
return __resume_at_30_1(b, x)
return __resume_at_38_2(b, x)
程式碼中引用的某些名稱如下
編譯後的函式,儲存在包含原始函式
toy_example的模組的全域性名稱空間中。這些名稱包括__compiled_fn_0/__resume_at_30_1/__resume_at_38_2等。用於檢查 Guard 的閉包變數。名稱可以從
guard.__code__.co_freevars訪問,值儲存在guard.__closure__中。這些名稱包括___guarded_code/___is_grad_enabled/___are_deterministic_algorithms_enabled/___is_torch_function_enabled/utils_device/___check_tensors/tensor_check_names等。Guard 函式的引數
L。這是一個 dict,將toy_example的引數名稱對映到其值。這僅在呼叫函式時可用,此時幀評估 API 開始發揮作用。簡而言之,L是一個結構為{'a': value_a, 'b': value_b}的 dict。因此,您可以看到程式碼使用L['a']來引用輸入變數a。
圖中斷顯示在編譯後的 toy_example 程式碼中,我們需要使用 Python 直譯器來選擇要執行的後續圖。
注意,我們傳遞了一個簡單的 my_compiler 函式作為後端編譯器,因此子圖程式碼 __resume_at_38_2、__resume_at_30_1 和 __compiled_fn_0 仍然是 Python 程式碼。這也值得檢查(請忽略函式名稱,僅使用函式簽名和函式體程式碼)
print("source code of __compiled_fn_0:")
print(innermost_fn(__compiled_fn_0).__self__.code)
print("=" * 60)
print("source code of __resume_at_30_1:")
print(decompile(__resume_at_30_1))
print("=" * 60)
print("source code of __resume_at_38_2:")
print(decompile(__resume_at_38_2))
source code of __compiled_fn_0:
def forward(self, L_a_ : torch.Tensor, L_b_ : torch.Tensor):
l_a_ = L_a_
l_b_ = L_b_
abs_1 = torch.abs(l_a_)
add = abs_1 + 1; abs_1 = None
truediv = l_a_ / add; l_a_ = add = None
sum_1 = l_b_.sum(); l_b_ = None
lt = sum_1 < 0; sum_1 = None
return (truediv, lt)
# To see more debug info, please use ``graph_module.print_readable()``
============================================================
source code of __resume_at_30_1:
def <resume in toy_example>(b, x):
b = b * -1
return x * b
============================================================
source code of __resume_at_38_2:
def <resume in toy_example>(b, x):
return x * b
然而,如果我們使用其他後端,如內建的 inductor,子圖程式碼將被編譯為 GPU 的 CUDA 核心或 CPU 的 C++ 程式碼。
總結一下,編譯後的程式碼在概念上等同於以下程式碼
def compiled_example(a, b):
L = {'a': a, 'b': b}
for guard, code in get_cache_entries():
if guard(L):
return code(a, b)
recompile_and_add_another_cache_entry()
下圖演示了 torch.compile 如何轉換和最佳化使用者編寫的程式碼:它首先從使用者編寫的函式中提取計算圖,並將這些圖編譯成最佳化函式,然後將它們組裝成一個新函式,該函式在功能上等同於使用者編寫的程式碼,但經過最佳化後具有良好的計算速度。
要詳細瞭解所有這些內部實現,請參閱 Dynamo 深入講解。