Dynamo 深度探索¶
TorchDynamo(或簡稱為 Dynamo)是 torch.compile 中的跟蹤器,它通常是那些瘋狂回溯的罪魁禍首。然而,我們不能盲目地將這些錯誤歸咎於 Dynamo。為了給使用者提供所需的靈活性,Dynamo 承擔了理解任何 Python 程式的艱鉅任務。特別是,Dynamo 必須在內部實現 Python 程式語言的很大一部分!
在這篇文章中,我們將從頭開始介紹 Dynamo 的內部設計。我們將討論它提供的功能以及如何實現。讀完這篇文章後,你將更好地理解當你使用 torch.compile 編譯 PyTorch 程式時出現錯誤,或者編譯成功但加速效果不如預期時,究竟是哪裡出了問題。
Dynamo 溫和入門¶
在深入瞭解所有實現細節之前,我們先來討論 Dynamo 的作用。
Dynamo 是一個跟蹤器。這意味著,給定一個函式及其輸入,它會執行該函式並將一系列線性指令(無控制流)記錄到一個圖中。例如,考慮以下程式
import torch
@torch.compile
def mse(x, y):
z = (x - y) ** 2
return z.sum()
x = torch.randn(200)
y = torch.randn(200)
mse(x, y)
如果我們將此程式儲存到檔案 example.py 中並執行
TORCH_LOGS=graph_code python example.py
我們會看到 Dynamo 跟蹤的輸出
def forward(l_x_: torch.Tensor, l_y_: torch.Tensor):
# File: example.py:5, code: z = (x - y) ** 2
sub = l_x_ - l_y_
z = sub ** 2
# File: example.py:6, code: return z.sum()
sum_1 = z.sum()
return (sum_1,)
我們將這稱為給定輸入的函式圖(或跟蹤)。它透過 FX 圖表示。我們可以簡單地將 FX 圖視為一個儲存函式呼叫列表的容器。
我們首先應該注意到的是,圖是 PyTorch 操作的線性序列。1 Dynamo 記錄所有 PyTorch 操作並按順序儲存。例如,它將 z = (x - y) ** 2 拆分為兩個組成操作:sub = l_x_ - l_y_ 和 z = sub ** 2。
當說跟蹤是線性的時,意味著沒有分支或任何控制流。為了驗證這一點,考慮
import torch
@torch.compile
def fn(x, n):
y = x ** 2
if n >= 0:
return (n + 1) * y
else:
return y / n
x = torch.randn(200)
fn(x, 2)
當使用 TORCH_LOGS=graph_code 執行時,會返回
def forward(l_x_: torch.Tensor):
# File: example.py:5, code: y = x ** 2
y = l_x_ ** 2
# File: example.py:7, code: return (n + 1) * y
mul = 3 * y
return (mul,)
我們看到 Dynamo 完全從跟蹤中移除了 if 語句,只記錄了使用輸入執行的操作。
因此,應該清楚的是,函式的跟蹤取決於輸入。特別是,這意味著跟蹤不是在我們編寫 @torch.compile 時生成的,而是在我們使用實際引數執行函式 fn(x, 2) 時生成的。
這裡另一個值得注意的有趣之處是,Dynamo 移除了函式的第二個引數。相反,它將其視為常量並在圖中記錄了操作 n + 1 的結果。這是 Dynamo 的另一個特性:Dynamo 會將除整數以外的任何非張量值視為常量。現在來看看整數為何特別。
Dynamo 的最後一個決定性特性是它知道如何處理動態形狀。符號形狀是指 Dynamo 跟蹤形狀(更普遍地說,整數)的能力,而不是將其視為常量。這有助於避免重新編譯,並在生產環境中部署適用於任何尺寸的通用模型。出現動態形狀的主要例子是批處理大小,我們可能會使用固定批處理大小訓練模型,但隨後對任意批處理大小執行推理;或者處理文字或音訊時遇到的變長序列。
我們可以透過多執行幾次上面的示例來看到這一點
import torch
@torch.compile
def fn(x, n):
y = x ** 2
if n >= 0:
return (n + 1) * y
else:
return y / n
x = torch.randn(200)
fn(x, 2)
fn(x, 3)
fn(x, -2)
在這種情況下,TORCH_LOGS=graph_code 生成另外兩個圖
# Graph for n==2 omitted
def forward(self, l_x_: torch.Tensor, l_n_: torch.SymInt):
# File: a.py:5, code: y = x ** 2
y = l_x_ ** 2
# File: a.py:7, code: return (n + 1) * y
add = l_n_ + 1
mul = add * y
return (mul,)
def forward(self, l_x_: torch.Tensor, l_n_: torch.SymInt):
# File: a.py:5, code: y = x ** 2
y = l_x_ ** 2
# File: a.py:9, code: return y / n
truediv = y / l_n_
return (truediv,)
Dynamo 檢測到第一個呼叫後一個整數改變了其值,並開始跟蹤它。我們看到這些圖是通用的,透過型別為 SymInt 的物件符號性地跟蹤變數 n。
如果在這些呼叫之後,我們呼叫 fn(x, 4),Dynamo 不會重新編譯,而是重用已經跟蹤的圖。
總結一下: 1. Dynamo 是一個 Python 跟蹤器 2. 給定一些輸入,它返回一個包含已執行 PyTorch 函式的 FX 圖 3. 如果檢測到整數在呼叫之間發生了變化,它也可以跟蹤整數 4. 它會特殊化除張量或標量之外的任何其他值
當然,Dynamo 還做了更多事情,比如判斷何時需要重新跟蹤、重寫函式的位元組碼、實現圖中斷等…… 為了使介紹簡短,我們將在後續內容中逐步討論所有這些。
PEP 523: 為 CPython 新增一個幀評估 API¶
現在想象一下,我們接到了實現 Dynamo 的任務。我們甚至從哪裡開始呢?對我們來說相當方便的是,PEP 523 隨 Python 3.6 釋出。這個 PEP 旨在允許第三方為 Python 建立 JIT 編譯器。來看看如何實現。
關於 CPython 的說明:CPython 內部實現為一個棧機。Python 程式被編譯成位元組碼,然後由該直譯器執行。要了解更多關於這些位元組碼的資訊,請參閱標準庫中的 dis 模組。另請參閱開發者文件,瞭解 CPython 直譯器的介紹。我們假設讀者熟悉棧機的概念。
PEP 523 暴露了一個 API,使用者可以新增一個自定義的按函式直譯器。然後,CPython 將使用此直譯器而不是其自己的直譯器來執行該函式。為了能夠執行函式,在進入時,CPython 會向自定義直譯器提供以下資訊: - 函式的位元組碼 - 函式引數的值(即區域性變數)及其名稱 - 全域性變數的值及其名稱 - 內建函式,例如 abs 或 print
總之,CPython 為使用者的直譯器提供了執行函式所需的所有資訊。3
有了這個 API,我們可以透過實現一個執行程式碼並將執行過程中發生的所有 PyTorch 操作記錄到圖中的直譯器來實現跟蹤器。這正是 Dynamo 所做的。
Dynamo 使用這個 CPython API 來解析所有這些物件,並將它們打包到一個 Python 結構中。完成這些後……它就從 C 回到 Python 了。除了這部分與 CPython 通訊的程式碼外,Dynamo 完全是用 Python 實現的。
應該清楚的是,裝飾器 @torch.compile 的作用是安裝必要的支架,以便在函式呼叫時將位元組碼、引數、全域性變數等傳遞給 Dynamo。再次強調,@torch.compile 本身實際上不編譯任何東西。
在 Python 中實現 CPython¶
所以,我們回到了 Python 世界。我們有了函式的位元組碼,以及執行它所需的所有上下文。特別是,我們抵達了 _convert_frame_assert。這是裝飾器 torch.compile 返回的函式!我們從 _dynamo.optimize 到達此函式。裝飾器 torch.compile 只是 _dynamo.optimize 的一個便捷 API。
在開始實現 Python 直譯器之前,我們想定義一個 IR(中間表示)。特別是,我們想將所有區域性變數和全域性變數封裝在我們自己的內部類中。這使我們能夠更好地跟蹤這些物件,並將 Dynamo 看來可以以相同方式處理的物件分組在一起。
內部類結構的父類是 VariableTracker,它代表 Dynamo 理解的不同物件。例如,ListVariable 代表一個 list 物件,並在內部維護一個 VariableTrackers 列表。另一個 VariableTracker 的例子是 ConstantVariable。ConstantVariable 封裝了所有被 Dynamo 視為常量的物件。我們還為需要特別關注的物件設定了特殊的子類,例如 TensorVariable。所有這些內部類都在 torch/_dynamo/variables 資料夾中定義。
Python 物件在 VariableBuilder._wrap 中被封裝到其對應的 VariableTracker 類中。此函式只是一個非常長的 elif 鏈,它嘗試將 Python 輸入遞迴地模式匹配到適當的 VariableTracker 型別。
除錯技巧。當我們從 dynamo 獲得意外結果時,有時是由於構建器引起的。如果構建器的邏輯錯誤,有時 Dynamo 可能會將變數封裝到不正確的 VariableTracker 型別中,這可能導致後續問題。檢視錯誤中出現的 VariableTracker 型別以及遇到 Dynamo 錯誤時丟擲異常的 VariableTracker 方法非常有用。特別是,有時我們會發現一個物件被跟蹤為 UserDefinedObjectVariable(這是 Dynamo 的包羅永珍類),而它本應被跟蹤為更具體的型別。在這些情況下,通常是 SourceBuilder.__call__ 的邏輯問題。
除錯技巧。當使用 TORCH_LOGS=dynamo 執行程式時,輸出的其中一個資訊是以下形式的行
TRACE LOAD_GLOBAL y [TorchInGraphFunctionVariable(<built-in method any>), TensorVariable()]
這是原始程式的位元組碼以及當時棧的狀態。這對於查詢物件未被正確跟蹤到 VariableTracker 中的位置非常有用。
好的,我們現在有了跟蹤器的 IR,現在我們只需要重新實現 CPython 的棧機。這在 symbolic_convert.py 中的 InstructorTranslatorBase 中實現。
InstructionTranslatorBase 大約有 200 個方法,實現了幾乎所有的 Python 位元組碼。例如,我們可以看看 BUILD_LIST 的實現
def BUILD_LIST(self, inst):
items = self.popn(inst.argval)
self.push(ListVariable(items, mutation_type=ValueMutationNew()))
這是由 l = [2, 3, 4] 這樣的結構生成的位元組碼。在這種情況下,由於有三個元素,生成的位元組碼是 BUILD_LIST 3。這意味著我們彈出棧頂的 3 個元素,並將由這三個元素形成的新列表物件壓入棧頂。
生成輸出圖¶
有了符號性執行 Python 程式碼的方法,我們就可以提取給定輸入的程式在符號性執行過程中發生的 PyTorch 操作。這在 Dynamo 中透過 OutputGraph 物件實現。OutputGraph 物件繫結到一個 `InstructionTranslator 物件,它跟蹤建立 Dynamo 將返回的 FX 圖所需的所有資料。
FX 圖的所有輸入和中間元素都是 fx.Node。在 Dynamo 中,fx.Node 被封裝在 fx.Proxy 中。fx.Proxy 用於構建 FX 圖。特別是,它們會將在其上執行的每個 PyTorch 操作記錄到圖中。你可以透過呼叫 create_proxy 建立一個新的操作新增到圖中。然後,我們可以透過函式 wrap_fx_proxy 將其新增到圖中。
一個圖儲存張量上的操作……以及符號整數上的操作。我們稍後會討論符號整數,但首先我們討論 Dynamo 如何解決一個相當重要的正確性問題。
使 Dynamo 健全:Guards¶
至此,我們有了一種完全忽略控制流來跟蹤程式的方法。為此,我們重新實現了整個 CPython……如果這聽起來有點過度,那是因為確實如此。torch.jit.trace 已經實現了這一點,而無需所有這些機制,那這是為什麼呢?
torch.jit.trace 的問題在於,正如其文件中所警告的,它只適用於跟蹤的程式不依賴於資料的情況。換句話說,只有程式本身是線性的時候它才起作用。這意味著編寫程式時不能使用 if-else、for-while 迴圈、異常。更甚者,我們使用的任何庫都不能使用任何控制流!總而言之,在一個像 Python 這樣動態的語言中不使用控制流,實際上是一個巨大的限制。
JAX 透過始終重新跟蹤並在重新跟蹤後快取圖來解決這個問題。而 Dynamo 則使用 guards 來避免每次都重新跟蹤整個程式。
一個 guard 是為了針對一組示例輸入特殊化(specialize)一個幀而做出的假設(關於輸入的布林表示式)。只有當這些假設在新輸入上仍然成立時,重用該圖才有效。
例如,函式中的任何常量輸入,比如字串,都會安裝一個 guard,表明該輸入必須是型別 str 且等於我們傳遞的字串。執行
import torch
@torch.compile
def fn(a, b):
return a * len(b)
fn(torch.arange(10), "Hello")
使用 TORCH_LOGS=guards 會打印出(以及其他 guards)
___check_type_id(L['b'], 94334122025024)
L['b'] == 'Hello'
這可以解讀為“區域性變數 b 應該具有特定型別(在本例中為 str,由常量 9433... 表示)且其值應為 'Hello'”。如果我們隨後再次執行該函式並傳遞不同的引數
import torch
@torch.compile
def fn(a, b):
return a * len(b)
fn(torch.arange(10), "Hello")
fn(torch.arange(10), "Hi")
我們可以透過執行 TORCH_LOGS=recompiles 來檢視失敗的 guard
Recompiling function fn in script.py:3
triggered by the following guard failure(s):
- L['b'] == 'Hello'
Guards 在構建器中封裝函式輸入和程式執行期間累積。我們將在下一節展示更多 guards 的例子,但首先讓我們討論 sources。
一個 source 跟蹤如何從進入當前幀時存在的原始區域性變數或全域性變數重構一個變數。特別是,它跟蹤原始區域性物件和全域性物件以及它們包含的任何物件。在
def foo(x: Tensor, y: List[Tensor]):
a = x * y[0]
return a * x
x 和 y 的 source 是 LocalSource,而 y[0] 的 source 是 GetItemSource,後者內部儲存一個 LocalSource。另一方面,a 沒有 source,因為它是一個只存在於 FX 圖中的中間變數。
所有這些都定義在 torch/_dynamo/source.py 中。我們可以在下面的示例中看到 GetItemSource 生成的 guard
import torch
@torch.compile
def fn(x, l):
return x * len(l[0])
fn(torch.randn(8), ["Hi", "Hello"])
生成以下 guards
___check_type_id(L['l'], 94439025877664)
len(L['l']) == 2
___check_type_id(L['l'][0], 94439025840192)
L['l'][0] == 'Hi'
___check_type_id(L['l'][1], 94439025840192)
L['l'][1] == 'Hello'
這裡,我們看到 GetItemSource ([0] 和 [1]) 生成的程式碼,它封裝了一個 LocalSource (L['l'])。
至此,有了 sources 和 guards,我們就能夠實現一個快取系統,避免每次都重新跟蹤,從而避免重新編譯。我們將在後續內容中更詳細地討論這個快取系統。
細心的讀者會注意到,這並沒有解釋為什麼我們需要對 Python 直譯器進行如此精細的控制,以至於不得不重新實現它。我們展示的 guards 示例依賴於輸入物件,因此我們仍然可以在執行函式之前計算這些 guards。換句話說,我們可以在 torch.jit.trace 的基礎上實現這個 guard 系統,並以少得多的精力獲得相同的功能…… 這就需要引入符號形狀了。
符號形狀¶
我們在介紹中討論的另一點是 Dynamo 知道如何跟蹤整數。為了實現這一點,我們使用一個符號類 torch.SymInt,它表現得像一個 int,但在輸出 FX 圖中記錄了對其執行的所有操作。4 我們在介紹符號整數跟蹤時已經在介紹中看到了這個類。
現在讓我們討論定義 Dynamo 中符號形狀跟蹤的三個屬性,以及如何實現它們。
預設靜態¶
Dynamo 假定每個整數,無論是輸入還是張量的形狀,預設都是靜態的。換句話說,在函式的第一次執行中,不會追蹤任何整數。只有當 Dynamo 檢測到整數或形狀值在執行過程中發生了變化時,它才會對其進行追蹤,並生成一個針對該變數的通用圖。
我們已經在介紹中使用整數看到了這種行為。現在讓我們看一個使用張量形狀的例子。
import torch
@torch.compile
def fn(a, b):
return a.shape[0] * a * b
fn(torch.randn(4, 3), torch.randn(4, 3))
fn(torch.randn(8, 3), torch.randn(8, 3))
使用 TORCH_LOGS=graph_code 執行此程式,我們看到這兩個呼叫被追蹤為
def forward(self, l_a_: torch.Tensor, l_b_: torch.Tensor):
mul = 4 * l_a_
mul_1 = mul * l_b_
return (mul_1,)
def forward(self, s0: torch.SymInt, l_a_: torch.Tensor, l_b_: torch.Tensor):
size = l_a_.size()
getitem = size[0]
mul = getitem * l_a_
mul_1 = mul * l_b_
return (mul_1,)
在第一個圖中,形狀被追蹤為一個常量,但一旦它發生變化,它就會使用 SymInt 符號化地追蹤它。通常,檢視中間值形狀的更簡單方法是使用 TORCH_LOGS=graph_sizes 執行程式
TRACED GRAPH TENSOR SIZES
===== __compiled_fn_1 =====
l_a_: (s0, 3)
l_a_ (concrete): (8, 3)
l_b_: (s0, 3)
l_b_ (concrete): (8, 3)
mul: (s0, 3)
mul (concrete): (8, 3)
mul_1: (s0, 3)
mul_1 (concrete): (8, 3)
在這裡我們可以看到,由於它由 s0 變量表示,因此兩個張量引數的第一個維度是動態的。
我們可以透過執行 TORCH_LOGS=guards 來了解 Dynamo 如何實現這一點
# Guards first call
check_tensor(L['a'], torch.float32, device=None, requires_grad=False, size=[4, 3], stride=[3, 1])
check_tensor(L['b'], torch.float32, device=None, requires_grad=False, size=[4, 3], stride=[3, 1])
# Guards second call
check_tensor(L['a'], torch.float32, device=None, requires_grad=False, size=[None, 3], stride=[3, 1])
check_tensor(L['b'], torch.float32, device=None, requires_grad=False, size=[None, 3], stride=[3, 1])
L['b'].size()[0] == L['a'].size()[0]
2 <= L['a'].size()[0]
我們看到在第一次呼叫時,guards 檢查張量是否具有固定的尺寸和步長。這些 guards 在第二次執行中失敗,因此它會重新追蹤。由於失敗的是一個 int guard,因此在第二次迭代中,它會對這個 int 進行符號化追蹤,並在更通用的 kernel 上安裝更通用的 guards。
編譯效能提示。如果你知道某個維度的大小會變化,可以在呼叫 torch.compile 之前透過呼叫 torch._dynamo.mark_dynamic 將其標記為動態。這將避免第一次使用靜態形狀的編譯。還有其他有用的實用函式,如 maybe_mark_dynamic 或 mark_static。你還可以透過呼叫 torch.compile(dynamic=True) 來追蹤所有整數和形狀。這主要用於除錯目的。
0、1 總是會被特殊化¶
無論我們是否將某個維度標記為動態,如果我們傳入一個該維度為 0 或 1 的輸入,Dynamo 都會將其追蹤為非動態,併為其生成一個特定的圖。這就是為什麼在上面的例子中我們發現 guards 的形式是 2 <= L['a'].size()[0]。
做出這個選擇有幾個原因。其中兩個尤其重要:- 當且僅當張量的任一維度為零時,該張量為空。- 當且僅當張量的步長之一為一時,該張量才能是連續的。
此策略決定不適用於普通的 Python int;如果我們認為 Python int 應該動態編譯,我們預設不會將其特殊化;相反,它是否被特殊化取決於其用法。
“鴨子”形狀 (Duck shaping)¶
Dynamo 執行我們所謂的“鴨子”形狀 (duck shaping)。如果在追蹤時兩個動態整數具有相同的值,我們將假定它們相等並進行守衛 (guard)。實際上,這意味著在上面的示例中,我們不是擁有兩個符號 s0、s1,而是將它們統一為 s0 並設定守衛 L['b'].size()[0] == L['a'].size()[0]。這使得能夠在編譯器內執行融合,同時能夠生成足夠通用的 kernel。
符號整數上的守衛 (Guards on symbolic ints)¶
我們現在在高層次上理解了符號形狀是如何實現的以及它們具有的屬性。那麼,為什麼符號形狀迫使我們走上控制 CPython 直譯器的棘手道路呢?考慮以下示例:
import torch
@torch.compile(dynamic=True)
def fn(a):
if a.shape[0] * 2 < 16:
return a
else:
return a + 1
fn(torch.randn(8))
此程式碼有一個形式為 2*L['a'].size()[0] >= 16 的守衛。這是一個在函式輸入方面非平凡的守衛,但在程式執行過程中註冊。更重要的是,我們直到看到依賴於 SymNodeVariable 引數的 if 語句條件時,才知道需要這個守衛。這些條件對於 torch.jit.trace 是不可見的,需要對 Python 程式碼進行深入分析。
除錯技巧 使用 TORCH_LOGS=dynamo 執行此程式碼可以告訴我們這個守衛是在哪裡新增的
eval 2*s0 >= 16 [guard added] at script.py:5 in fn (_dynamo/variables/tensor.py:812 in evaluate_expr)
在那裡設定一個斷點並查看回溯對於理解守衛來自何處非常有用。
使 Dynamo 完整:圖中斷 (Graph Breaks)¶
有了我們討論過的所有工具,我們現在擁有一個能夠追蹤張量和整數上的 PyTorch 操作的追蹤器,並且它具有一個快取系統,知道何時可以重用之前追蹤的圖以及何時需要重新追蹤。所有這些都能執行任意 Python 程式碼!
但這有一個小問題。“執行任意 Python 程式碼”的說法可能過於寬泛了。Dynamo 實現了 Python 的大部分功能,但它是否實現了更復雜的部分,比如協程 (coroutines) 或非同步 (async)?它是否實現了整個 Python 標準庫?NumPy 也有 Python API。torch.compile 是否也能理解 NumPy?還有 Django? 5
Python 的生態系統非常龐大,其中很大一部分是用 C++ 或 Rust 等效能更高的語言編寫的,並且只暴露了 Python 繫結。Dynamo 無法追蹤透過 C++ 實現的 Python 物件。當追蹤器遇到它不理解的操作時,它能做什麼?
機器學習追蹤器處理這個問題通常的方式是告知使用者它們在哪個操作上遇到了困難,並完全放棄追蹤。這在 PyTorch 中會帶來實際的可用性問題,因為 PyTorch 的使用者習慣了它提供的靈活性。舉一個現實世界的例子,doctr_det_predictor 模型使用了 NumPy 和 cv2 庫來對模型結果進行後處理。
這是另一個訪問 CPython 很有意義的地方。Dynamo 不會報錯,而是可以讓 CPython 執行那段有問題程式碼!為此,Dynamo 在追蹤時生成一個包含有問題程式碼之前所有操作的圖,以及一個包含有問題程式碼之後所有操作的圖。6 然後,在執行時,它將委託給 CPython 執行第一個圖,然後是有問題的程式碼,最後是第二個圖。停止追蹤並生成多個圖的過程稱為 圖中斷 (graph break)。
一個小小的坦白:我在整個介紹和前幾節中都在撒謊。Dynamo 生成的不是一個圖,而是 多個圖!實際上,將圖中斷後重新開始追蹤視為開始追蹤一個新的函式。圖中斷後的新圖將有自己的 guards、新的區域性變數集等等。
要討論如何實現圖中斷,我們需要首先回顧 Dynamo 如何與 CPython 互動。使用 PEP 523,CPython 允許使用者使用自己的幀評估機制。我們之前沒有討論的是,CPython 也暴露了自己的幀評估供其他人使用。Dynamo 利用這一點,讓快速的 CPython 直譯器執行編譯後的程式碼。對於一個沒有圖中斷的函式,程式呼叫該函式兩次且引數相同時的整個追蹤/執行過程如下所示:
在第一次呼叫函式時
Dynamo 將函式追蹤成一個 FX 圖
FX 圖由編譯器 (Inductor) 編譯成高效的底層程式碼……但這又是另一天的故事了
它重寫函式的位元組碼,使其只需呼叫編譯後的函式
它將這段新的位元組碼交給 CPython 並要求它執行 [此處]
在第二次呼叫函式時
這個過程本身看起來過於複雜。為什麼生成新的位元組碼並要求 CPython 執行,而不是簡單地建立一個到編譯函式的 C++ 繫結並執行它呢?嗯,這種模式使我們能夠實現圖中斷!由圖中斷生成的位元組碼具有以下結構:
執行第一個圖的位元組碼
使棧狀態與 CPython 執行第一個圖後相同的位元組碼。它還會重放在此刻可見的區域性或全域性變數的任何修改
導致 Dynamo 圖中斷的位元組碼
執行第二個圖的位元組碼
讓我們透過一個簡單的例子來看看
import torch
@torch.compile
def fn(a):
b = a + 2
print("Hi")
return b + a
fn(torch.randn(4))
使用 TORCH_LOGS=bytecode 執行此程式會顯示初始位元組碼和修改後的位元組碼
MODIFIED BYTECODE fn script.py line 3
0 LOAD_GLOBAL 1 (__compiled_fn_0)
2 LOAD_FAST 0 (a)
4 CALL_FUNCTION 1
6 STORE_FAST 3 (graph_out_0)
8 LOAD_GLOBAL 0 (print)
10 LOAD_CONST 2 ('Hi')
12 LOAD_FAST 3 (graph_out_0)
14 LOAD_CONST 3 (0)
16 BINARY_SUBSCR
18 STORE_FAST 1 (b)
20 CALL_FUNCTION 1
22 LOAD_GLOBAL 2 (__resume_at_14_1)
24 ROT_TWO
26 LOAD_FAST 0 (a)
28 LOAD_FAST 1 (b)
30 CALL_FUNCTION 3
32 RETURN_VALUE
MODIFIED BYTECODE resume_in_fn script.py line 6
0 LOAD_GLOBAL 1 (__compiled_fn_2)
2 LOAD_FAST 2 (b)
4 LOAD_FAST 1 (a)
6 CALL_FUNCTION 2
8 UNPACK_SEQUENCE 1
10 RETURN_VALUE
我們可以看到修改後的位元組碼被分割成兩個函式:原始函式 fn,以及一個名為 resume_in_fn 的函式。第二個函式是 Dynamo 為了實現程式從圖中斷點開始執行而建立的函式。這通常被稱為延續函式 (continuation function)。這個延續函式只需使用正確的引數呼叫第二個編譯後的函式。初始函式的程式碼根據我們之前描述的策略進行了重寫
L0-4. 呼叫編譯後的函式 (
a + 2)。L6. 將其結果儲存在一個名為
graph_out_0的區域性變數中。graph_out_0是一個元組L8-18. 使棧在圖中斷點保持其應有的狀態
L20. 執行導致圖中斷的程式碼
L22-32. 呼叫編譯後的延續函式 (
a + b)
Dynamo 中棧的程式碼生成被委託給 VariableTracker 子類。Dynamo 中的每個 VariableTracker 物件都有一個 reconstruct 方法,該方法生成必要的位元組碼以在棧上建立它所代表的 Python 物件。
除錯技巧。圖中斷會影響效能,因此最好避免它們。使用 TORCH_LOGS=graph_breaks 執行程式是找出程式發生了多少次圖中斷的好方法。它返回的資訊是以 VariableTracker 物件的形式呈現的,因此上面的除錯技巧有時也有助於弄清楚是什麼導致了圖中斷。
結論¶
Dynamo 是一塊複雜的軟體。一旦你著手實現 CPython 直譯器,你就知道這將是一段不尋常的旅程。話雖如此,我們希望這篇文章能幫助你揭開它的一些神秘面紗。
Dynamo(大部分)是用 Python 實現的。我們留下了許多討論過的程式碼片段的連結。我們希望閱讀這些程式碼片段,並搜尋呼叫它們的地方,或在它們上面設定斷點並檢視呼叫棧,有助於理解其餘的程式碼庫。
當然,學習軟體如何工作的最佳方法是擴充套件它。在這種情況下,最好的方法是檢視 github 上的 Dynamo 未解決問題。其中許多隻需要對程式碼進行很小的更改,一旦你找到需要進行這些更改的地方。
腳註¶
- 1
在文獻中,這被稱為有向無環圖 (Directed Acyclical Graph, DAG)。
- 2
所有這些繫結程式碼都位於
torch/csrc/dynamo/eval_frame.c中。- 3
在 CPython 術語中,所有這些物件的集合稱為一個 frame。
- 4
還有
SymBool和SymFloat類。後一個在撰寫本文時用得不多。- 5
有趣的是,它確實理解 NumPy 程式碼!看看這篇部落格文章和文件。現在,這之所以可能,是因為我們使用 PyTorch 重新實現了 NumPy。不過,祝你將 Django 用 PyTorch 實現順利……
- 6
假設只有一段問題程式碼。如果問題程式碼更多,Dynamo 可以將程式碼分割成所需數量的圖。