Dynamo 概覽¶
在閱讀本節之前,請先閱讀 torch.compiler。
TorchDynamo(或簡稱 Dynamo)是一個 Python 層級的即時 (JIT) 編譯器,旨在讓未修改的 PyTorch 程式更快。Dynamo 會掛鉤到 CPython 中的框架評估 API (PEP 523),以便在執行之前動態修改 Python 位元組碼。它會重寫 Python 位元組碼,以將 PyTorch 操作序列提取到 FX 圖形 中,然後使用可自訂的後端進行編譯。它透過位元組碼分析建立此 FX 圖形,旨在將 Python 執行與已編譯的後端混合,以獲得兩全其美的優勢:易用性和效能。
Dynamo 讓您可以輕鬆試驗不同的編譯器後端,只需使用一行裝飾器 torch._dynamo.optimize() 即可讓 PyTorch 程式碼更快,而該裝飾器由 torch.compile() 進行包裝以方便使用。
下圖顯示 PyTorch 在使用 torch.compile 和不使用的情況下是如何工作的。
 
TorchInductor 是 Dynamo 圖形 支援的後端之一,可將其編譯為 GPU 的 Triton 或 CPU 的 C++/OpenMP。我們有一個 訓練效能儀表板,可提供不同訓練後端的效能比較。您可以在 PyTorch dev-discuss 上的 TorchInductor 文章 中閱讀更多資訊。
若要深入瞭解,請閱讀以下章節、觀看深入探討影片,並查看 dev-discuss 主題。
Dynamo 內部¶
**作者:**Jason Ansel 和 游凱超
本節將探討 Dynamo 的一些內部機制,並示範 Dynamo 在幕後是如何運作的。
什麼是守衛?¶
Dynamo 即時運作,並根據動態屬性對圖形進行專門化。以下是如何使用 Dynamo 的基本範例。可以使用 torchdynamo.optimize 裝飾函數或方法來啟用 Dynamo 優化。
from typing import List
import torch
from torch import _dynamo as torchdynamo
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()
    return gm.forward  # return a python callable
@torchdynamo.optimize(my_compiler)
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b
for _ in range(100):
    toy_example(torch.randn(10), torch.randn(10))
例如,上面的第一個圖形具有以下守衛。
GUARDS:
hasattr(L['a'], '_dynamo_dynamic_indices') == False
hasattr(L['b'], '_dynamo_dynamic_indices') == False
utils_device.CURRENT_DEVICE == None
___skip_backend_check() or ___current_backend() == ___lookup_backend(140355900538256)
check_tensor(L['a'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[10], stride=[1])
check_tensor(L['b'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[10], stride=[1])
如果其中任何一個守衛失敗,圖形將被重新擷取並重新編譯。其中一個有趣的守衛是 check_tensor,它會檢查以下 torch.Tensor 屬性。
- 張量的 Python 類別(張量子類別等) 
- dtype 
- 裝置 
- requires_grad 
- dispatch_key(套用執行緒本機包含/排除) 
- ndim 
- sizes* 
- strides* 
完整專業化模式允許後端編譯器假設圖形完全是靜態的。不幸的是,大多數後端都需要這樣做。當不在動態形狀模式時,返回動態形狀的運算子將觸發圖形中斷。
Dynamo 在做什麼?¶
如果您想更好地瞭解 Dynamo 在做什麼,可以使用以下命令執行您的程式碼。
TORCH_LOGS="+dynamo,guards,bytecode"
如果您不熟悉 Python 位元組碼,可以新增一個反編譯器鉤子,將位元組碼反編譯成人類可讀的原始程式碼。一個可用的工具是 depyf。如果您尚未安裝 depyf,請執行 pip install depyf。然後,在執行任何程式碼之前,新增以下程式碼以安裝反編譯鉤子。
import depyf
depyf.install()
此程式碼會觸發有用的(但會產生大量垃圾訊息的)列印輸出。
例如,toy_example 中第一個圖形的列印輸出如下。
__compiled_fn_0 <eval_with_key>.1
opcode         name     target                                                  args              kwargs
-------------  -------  ------------------------------------------------------  ----------------  --------
placeholder    a        a                                                       ()                {}
placeholder    b        b                                                       ()                {}
call_function  abs_1    <built-in method abs of type object at 0x7f9ca082f8a0>  (a,)              {}
call_function  add      <built-in function add>                                 (abs_1, 1)        {}
call_function  truediv  <built-in function truediv>                             (a, add)          {}
call_method    sum_1    sum                                                     (b,)              {}
call_function  lt       <built-in function lt>                                  (sum_1, 0)        {}
output         output   output                                                  ((truediv, lt),)  {}
ORIGINAL BYTECODE toy_example example.py line 12
 14           0 LOAD_FAST                0 (a)
              2 LOAD_GLOBAL              0 (torch)
              4 LOAD_METHOD              1 (abs)
              6 LOAD_FAST                0 (a)
              8 CALL_METHOD              1
             10 LOAD_CONST               1 (1)
             12 BINARY_ADD
             14 BINARY_TRUE_DIVIDE
             16 STORE_FAST               2 (x)
 15          18 LOAD_FAST                1 (b)
             20 LOAD_METHOD              2 (sum)
             22 CALL_METHOD              0
             24 LOAD_CONST               2 (0)
             26 COMPARE_OP               0 (<)
             28 POP_JUMP_IF_FALSE       19 (to 38)
 16          30 LOAD_FAST                1 (b)
             32 LOAD_CONST               3 (-1)
             34 BINARY_MULTIPLY
             36 STORE_FAST               1 (b)
 17     >>   38 LOAD_FAST                2 (x)
             40 LOAD_FAST                1 (b)
             42 BINARY_MULTIPLY
             44 RETURN_VALUE
MODIFIED BYTECODE toy_example example.py line 12
 12           0 LOAD_GLOBAL              3 (__compiled_fn_0)
              2 LOAD_FAST                0 (a)
              4 LOAD_FAST                1 (b)
              6 CALL_FUNCTION            2
              8 UNPACK_SEQUENCE          2
             10 STORE_FAST               2 (x)
             12 POP_JUMP_IF_FALSE       12 (to 24)
             14 LOAD_GLOBAL              4 (__resume_at_30_1)
             16 LOAD_FAST                1 (b)
             18 LOAD_FAST                2 (x)
             20 CALL_FUNCTION            2
             22 RETURN_VALUE
        >>   24 LOAD_GLOBAL              5 (__resume_at_38_2)
             26 LOAD_FAST                1 (b)
             28 LOAD_FAST                2 (x)
             30 CALL_FUNCTION            2
             32 RETURN_VALUE
possible source code:
def toy_example(a, b):
    __temp_1 = __compiled_fn_0(a, b)
    x = __temp_1[0]
    if __temp_1[1]:
        return __resume_at_30_1(b, x)
    return __resume_at_38_2(b, x)
If you find the decompiled code is wrong,please submit an issue at https://github.com/youkaichao/depyf/issues.
在頂部,您可以看到 FX 圖形。接下來,您可以看到函數的原始位元組碼,然後是 Dynamo 生成的修改後位元組碼,以及用於參考的反編譯原始程式碼。最後,您可以看到我們上面提到的守衛。
在修改後的位元組碼中,__compiled_fn_0 是 my_compiler()(已編譯的圖形)的返回值。__resume_at_30_1 和 __resume_at_38_2 都是生成的延續函數,它們在圖形中斷後(在位元組碼偏移量 30 和 38 處)繼續執行。這些函數都採用以下形式。
__resume_at_<offset>:
    ... restore stack state if needed ...
    JUMP_ABSOLUTE <offset> into toy_example
    ... original bytecode of toy_example ...
透過生成這個 resume_at 函數,我們強制函數的其餘部分在新 Python 框架中執行,該框架會在執行第一次到達該點時遞迴地觸發 Dynamo 重新開始其擷取。
如何檢查 Dynamo 生成的構件?¶
要檢查 Dynamo 生成的構件,可以使用 API torch._dynamo.eval_frame._debug_get_cache_entry_list 從函數的 __code__ 物件中檢索已編譯的程式碼和守衛條件。一個已編譯的函數可以有多個快取條目,每個快取條目都包含一個用於檢查守衛條件的生成函數,以及一個 types.CodeType 物件,用於在滿足守衛條件時保留要執行的程式碼。
from torch._dynamo.eval_frame import _debug_get_cache_entry_list, innermost_fn
cache_entries = _debug_get_cache_entry_list(innermost_fn(toy_example))
cache_entry = cache_entries[0]
guard, code = cache_entry.check_fn, cache_entry.code
# the guard takes the local variables of an input frame, and tells whether a re-compilation should be triggered.
import dis
dis.dis(guard)
dis.dis(code)
如果您了解 Python 位元組碼,則可以理解上述輸出。
對於守衛函數,不需要檢查位元組碼。我們可以直接訪問其守衛條件
for code_part in guard.code_parts:
    print(code_part)
輸出為
___guarded_code.valid
___check_global_state()
hasattr(L['a'], '_dynamo_dynamic_indices') == False
hasattr(L['b'], '_dynamo_dynamic_indices') == False
utils_device.CURRENT_DEVICE == None
___skip_backend_check() or ___current_backend() == ___lookup_backend(140215810860528)
___check_tensors(L['a'], L['b'], tensor_check_names=tensor_check_names)
只有當所有條件都滿足時,守衛函數才會返回 true,並且才會執行已編譯的程式碼。
對於已編譯的程式碼,我們無法直接訪問其原始碼,而必須對其進行反編譯。
from depyf import decompile
print(decompile(code))
輸出為
def toy_example(a, b):
    __temp_1 = __compiled_fn_0(a, b)
    x = __temp_1[0]
    if __temp_1[1]:
        return __resume_at_30_1(b, x)
    return __resume_at_38_2(b, x)
程式碼中引用的一些名稱是
- 已編譯的函數,儲存在包含原始函數 - toy_example的模組的全域命名空間中。這些名稱包括- __compiled_fn_0/- __resume_at_30_1/- __resume_at_38_2。
- 用於檢查守衛條件的閉包變數。可以從 - guard.__code__.co_freevars訪問這些名稱,並且這些值儲存在- guard.__closure__中。這些名稱包括- ___guarded_code/- ___is_grad_enabled/- ___are_deterministic_algorithms_enabled/- ___is_torch_function_enabled/- utils_device/- ___check_tensors/- tensor_check_names。
- guard函數的參數- L。這是一個將- toy_example的參數名稱映射到其值的字典。這僅在呼叫函數時可用,其中框架評估 API 發揮作用。簡而言之,- L是一個結構為- {'a': value_a, 'b': value_b}的- dict。因此,您可以看到程式碼使用- L['a']來引用輸入變數- a。
圖形中斷顯示在已編譯的 toy_example 的程式碼中,我們必須使用 Python 直譯器來選擇要執行的以下圖形。
請注意,我們傳遞了一個簡單的 my_compiler 函數作為後端編譯器,因此子圖程式碼 __resume_at_38_2、__resume_at_30_1 和 __compiled_fn_0 仍然是 Python 程式碼。這也可以檢查(請忽略函數名稱,僅使用函數簽章和函數主體程式碼)
print("source code of __compiled_fn_0:")
print(innermost_fn(__compiled_fn_0).__self__.code)
print("=" * 60)
print("source code of __resume_at_30_1:")
print(decompile(__resume_at_30_1))
print("=" * 60)
print("source code of __resume_at_38_2:")
print(decompile(__resume_at_38_2))
source code of __compiled_fn_0:
def forward(self, L_a_ : torch.Tensor, L_b_ : torch.Tensor):
    l_a_ = L_a_
    l_b_ = L_b_
    abs_1 = torch.abs(l_a_)
    add = abs_1 + 1;  abs_1 = None
    truediv = l_a_ / add;  l_a_ = add = None
    sum_1 = l_b_.sum();  l_b_ = None
    lt = sum_1 < 0;  sum_1 = None
    return (truediv, lt)
# To see more debug info, please use ``graph_module.print_readable()``
============================================================
source code of __resume_at_30_1:
def <resume in toy_example>(b, x):
    b = b * -1
    return x * b
============================================================
source code of __resume_at_38_2:
def <resume in toy_example>(b, x):
    return x * b
但是,如果我們使用其他後端(如內建的 inductor),則子圖程式碼將被編譯為 GPU 的 CUDA 核心或 CPU 的 C++ 程式碼。
總之,已編譯的程式碼在概念上等效於以下程式碼
def compiled_example(a, b):
    L = {'a': a, 'b': b}
    for guard, code in get_cache_entries():
        if guard(L):
            return code(a, b)
    recompile_and_add_another_cache_entry()
下圖展示了 torch.compile 如何轉換和最佳化使用者編寫的程式碼:它首先從使用者編寫的函數中提取計算圖,並將這些圖編譯成最佳化的函數,然後將它們組裝成一個新的函數,該函數在功能上等效於使用者編寫的程式碼,但經過最佳化以具有良好的計算速度。
 
如需深入瞭解所有這些內部實作方式,請參閱Dynamo 深入探討。