編譯後的Autograd:為torch.compile捕獲更大的反向圖¶
建立日期:Oct 09, 2024 | 最後更新:Oct 23, 2024 | 最後驗證:Oct 09, 2024
作者: Simon Fan
編譯後的autograd如何與
torch.compile互動如何使用編譯後的autograd API
如何使用
TORCH_LOGS檢查日誌
PyTorch 2.4
閱讀PyTorch 2.x入門中的TorchDynamo和AOTAutograd部分
概述¶
編譯後的Autograd是PyTorch 2.4中引入的一個torch.compile擴充套件,它允許捕獲更大的反向圖。
雖然torch.compile確實會捕獲反向圖,但它是部分捕獲的。AOTAutograd元件會提前捕獲反向圖,但也存在某些限制:
前向過程中的圖中斷會導致反向過程中的圖中斷
反向鉤子不會被捕獲
編譯後的Autograd透過直接與autograd引擎整合來解決這些限制,使其能夠在執行時捕獲完整的反向圖。具有這兩個特徵的模型應該嘗試編譯後的Autograd,並有可能觀察到更好的效能。
然而,編譯後的Autograd也引入了自己的限制:
在反向傳播開始時增加快取查詢的執行時開銷
由於捕獲更大,更容易在dynamo中導致重新編譯和圖中斷
注意
編譯後的Autograd正在積極開發中,尚未相容所有現有的PyTorch功能。有關特定功能的最新狀態,請參閱編譯後的Autograd登陸頁面。
設定¶
在本教程中,我們將基於這個簡單的神經網路模型構建示例。它接受一個10維的輸入向量,透過一個線性層進行處理,並輸出另一個10維的向量。
import torch
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 10)
def forward(self, x):
return self.linear(x)
基本用法¶
在呼叫torch.compile API之前,請確保將torch._dynamo.config.compiled_autograd設定為True
model = Model()
x = torch.randn(10)
torch._dynamo.config.compiled_autograd = True
@torch.compile
def train(model, x):
loss = model(x).sum()
loss.backward()
train(model, x)
在上面的程式碼中,我們建立了一個Model類的例項,並使用torch.randn(10)生成了一個隨機的10維張量x。我們定義了訓練迴圈函式train,並用@torch.compile裝飾它來最佳化其執行。當呼叫train(model, x)時:
Python直譯器呼叫Dynamo,因為此呼叫已用
@torch.compile裝飾。Dynamo攔截Python位元組碼,模擬其執行並將操作記錄到圖中。
AOTDispatcher停用鉤子並呼叫autograd引擎來計算model.linear.weight和model.linear.bias的梯度,並將操作記錄到圖中。使用torch.autograd.Function,AOTDispatcher重寫train的前向和反向實現。Inductor生成一個函式,該函式對應於AOTDispatcher前向和反向的最佳化實現。
Dynamo設定最佳化後的函式,由Python直譯器下一步評估。
Python直譯器執行最佳化後的函式,該函式執行
loss = model(x).sum()。Python直譯器執行
loss.backward(),呼叫autograd引擎,由於我們設定了torch._dynamo.config.compiled_autograd = True,因此會路由到編譯後的Autograd引擎。編譯後的Autograd計算
model.linear.weight和model.linear.bias的梯度,並將操作記錄到圖中,包括遇到的任何鉤子。在此過程中,它將記錄先前由AOTDispatcher重寫的反向傳播。然後,編譯後的Autograd生成一個新函式,該函式對應於loss.backward()的完全追蹤實現,並在推理模式下使用torch.compile執行它。同樣的步驟遞迴地應用於編譯後的Autograd圖,但這次AOTDispatcher將不再需要對圖進行分割槽。
檢查編譯後的autograd日誌¶
使用TORCH_LOGS環境變數執行指令碼
僅列印編譯後的autograd圖,使用
TORCH_LOGS="compiled_autograd" python example.py以犧牲效能為代價,列印包含更多張量元資料和重新編譯原因的圖,使用
TORCH_LOGS="compiled_autograd_verbose" python example.py
重新執行上面的程式碼片段,編譯後的autograd圖現在應該被記錄到stderr中。某些圖節點的名稱將帶有aot0_字首,這些對應於先前在AOTAutograd反向圖0中提前編譯的節點,例如,aot0_view_2對應於id=0的AOT反向圖的view_2。
在下面的圖片中,紅色框封裝了在沒有編譯後的Autograd情況下被torch.compile捕獲的AOT反向圖。
注意
這是我們將呼叫torch.compile的圖,不是最佳化後的圖。編譯後的Autograd本質上生成一些未最佳化的Python程式碼來表示整個C++ autograd執行。
使用不同標誌編譯前向和反向傳播¶
你可以對兩次編譯使用不同的編譯器配置,例如,即使前向傳播中存在圖中斷,反向傳播也可能是一個全圖(fullgraph)。
def train(model, x):
model = torch.compile(model)
loss = model(x).sum()
torch._dynamo.config.compiled_autograd = True
torch.compile(lambda: loss.backward(), fullgraph=True)()
或者你可以使用上下文管理器,它將應用於其作用域內的所有autograd呼叫。
def train(model, x):
model = torch.compile(model)
loss = model(x).sum()
with torch._dynamo.compiled_autograd.enable(torch.compile(fullgraph=True)):
loss.backward()
編譯後的Autograd解決了AOTAutograd的某些限制¶
前向傳播中的圖中斷不再必然導致反向傳播中的圖中斷
@torch.compile(backend="aot_eager")
def fn(x):
# 1st graph
temp = x + 10
torch._dynamo.graph_break()
# 2nd graph
temp = temp + 10
torch._dynamo.graph_break()
# 3rd graph
return temp.sum()
x = torch.randn(10, 10, requires_grad=True)
torch._dynamo.utils.counters.clear()
loss = fn(x)
# 1. base torch.compile
loss.backward(retain_graph=True)
assert(torch._dynamo.utils.counters["stats"]["unique_graphs"] == 3)
torch._dynamo.utils.counters.clear()
# 2. torch.compile with compiled autograd
with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
loss.backward()
# single graph for the backward
assert(torch._dynamo.utils.counters["stats"]["unique_graphs"] == 1)
在第一個torch.compile案例中,我們看到由於編譯函式fn中的2個圖中斷,產生了3個反向圖。而在使用編譯後的autograd的第二個torch.compile案例中,我們看到儘管有圖中斷,仍然追蹤到了一個完整的反向圖。
注意
Dynamo在追蹤由編譯後的Autograd捕獲的反向鉤子時,仍然可能發生圖中斷。
反向鉤子現在可以被捕獲
@torch.compile(backend="aot_eager")
def fn(x):
return x.sum()
x = torch.randn(10, 10, requires_grad=True)
x.register_hook(lambda grad: grad+10)
loss = fn(x)
with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
loss.backward()
圖中應該有一個call_hook節點,Dynamo稍後會將其內聯到以下內容
編譯後的Autograd常見的重新編譯原因¶
由於loss值的autograd結構發生變化
torch._dynamo.config.compiled_autograd = True
x = torch.randn(10, requires_grad=True)
for op in [torch.add, torch.sub, torch.mul, torch.div]:
loss = op(x, x).sum()
torch.compile(lambda: loss.backward(), backend="eager")()
在上面的示例中,我們在每次迭代時呼叫不同的運算子,導致loss每次跟蹤不同的autograd歷史。你應該會看到一些重新編譯訊息:Cache miss due to new autograd node。
由於張量形狀發生變化
torch._dynamo.config.compiled_autograd = True
for i in [10, 100, 10]:
x = torch.randn(i, i, requires_grad=True)
loss = x.sum()
torch.compile(lambda: loss.backward(), backend="eager")()
在上面的示例中,x的形狀發生變化,編譯後的autograd會在第一次變化後將x標記為動態形狀張量。你應該會看到重新編譯訊息:Cache miss due to changed shapes。
結論¶
在本教程中,我們回顧了帶有編譯後的autograd的torch.compile的高層生態系統、編譯後的autograd的基礎知識以及一些常見的重新編譯原因。請關注dev-discuss上的深入探討。