• 教程 >
  • 編譯後的Autograd:為torch.compile捕獲更大的反向圖
快捷方式

編譯後的Autograd:為torch.compile捕獲更大的反向圖

建立日期:Oct 09, 2024 | 最後更新:Oct 23, 2024 | 最後驗證:Oct 09, 2024

作者: Simon Fan

你將學到什麼
  • 編譯後的autograd如何與torch.compile互動

  • 如何使用編譯後的autograd API

  • 如何使用TORCH_LOGS檢查日誌

先決條件

概述

編譯後的Autograd是PyTorch 2.4中引入的一個torch.compile擴充套件,它允許捕獲更大的反向圖。

雖然torch.compile確實會捕獲反向圖,但它是部分捕獲的。AOTAutograd元件會提前捕獲反向圖,但也存在某些限制:

  • 前向過程中的圖中斷會導致反向過程中的圖中斷

  • 反向鉤子不會被捕獲

編譯後的Autograd透過直接與autograd引擎整合來解決這些限制,使其能夠在執行時捕獲完整的反向圖。具有這兩個特徵的模型應該嘗試編譯後的Autograd,並有可能觀察到更好的效能。

然而,編譯後的Autograd也引入了自己的限制:

  • 在反向傳播開始時增加快取查詢的執行時開銷

  • 由於捕獲更大,更容易在dynamo中導致重新編譯和圖中斷

注意

編譯後的Autograd正在積極開發中,尚未相容所有現有的PyTorch功能。有關特定功能的最新狀態,請參閱編譯後的Autograd登陸頁面

設定

在本教程中,我們將基於這個簡單的神經網路模型構建示例。它接受一個10維的輸入向量,透過一個線性層進行處理,並輸出另一個10維的向量。

import torch

class Model(torch.nn.Module):
   def __init__(self):
      super().__init__()
      self.linear = torch.nn.Linear(10, 10)

   def forward(self, x):
      return self.linear(x)

基本用法

在呼叫torch.compile API之前,請確保將torch._dynamo.config.compiled_autograd設定為True

model = Model()
x = torch.randn(10)

torch._dynamo.config.compiled_autograd = True
@torch.compile
def train(model, x):
   loss = model(x).sum()
   loss.backward()

train(model, x)

在上面的程式碼中,我們建立了一個Model類的例項,並使用torch.randn(10)生成了一個隨機的10維張量x。我們定義了訓練迴圈函式train,並用@torch.compile裝飾它來最佳化其執行。當呼叫train(model, x)時:

  • Python直譯器呼叫Dynamo,因為此呼叫已用@torch.compile裝飾。

  • Dynamo攔截Python位元組碼,模擬其執行並將操作記錄到圖中。

  • AOTDispatcher停用鉤子並呼叫autograd引擎來計算model.linear.weightmodel.linear.bias的梯度,並將操作記錄到圖中。使用torch.autograd.Function,AOTDispatcher重寫train的前向和反向實現。

  • Inductor生成一個函式,該函式對應於AOTDispatcher前向和反向的最佳化實現。

  • Dynamo設定最佳化後的函式,由Python直譯器下一步評估。

  • Python直譯器執行最佳化後的函式,該函式執行loss = model(x).sum()

  • Python直譯器執行loss.backward(),呼叫autograd引擎,由於我們設定了torch._dynamo.config.compiled_autograd = True,因此會路由到編譯後的Autograd引擎。

  • 編譯後的Autograd計算model.linear.weightmodel.linear.bias的梯度,並將操作記錄到圖中,包括遇到的任何鉤子。在此過程中,它將記錄先前由AOTDispatcher重寫的反向傳播。然後,編譯後的Autograd生成一個新函式,該函式對應於loss.backward()的完全追蹤實現,並在推理模式下使用torch.compile執行它。

  • 同樣的步驟遞迴地應用於編譯後的Autograd圖,但這次AOTDispatcher將不再需要對圖進行分割槽。

檢查編譯後的autograd日誌

使用TORCH_LOGS環境變數執行指令碼

  • 僅列印編譯後的autograd圖,使用TORCH_LOGS="compiled_autograd" python example.py

  • 以犧牲效能為代價,列印包含更多張量元資料和重新編譯原因的圖,使用TORCH_LOGS="compiled_autograd_verbose" python example.py

重新執行上面的程式碼片段,編譯後的autograd圖現在應該被記錄到stderr中。某些圖節點的名稱將帶有aot0_字首,這些對應於先前在AOTAutograd反向圖0中提前編譯的節點,例如,aot0_view_2對應於id=0的AOT反向圖的view_2

在下面的圖片中,紅色框封裝了在沒有編譯後的Autograd情況下被torch.compile捕獲的AOT反向圖。

../_images/entire_verbose_log.png

注意

這是我們將呼叫torch.compile的圖,不是最佳化後的圖。編譯後的Autograd本質上生成一些未最佳化的Python程式碼來表示整個C++ autograd執行。

使用不同標誌編譯前向和反向傳播

你可以對兩次編譯使用不同的編譯器配置,例如,即使前向傳播中存在圖中斷,反向傳播也可能是一個全圖(fullgraph)。

def train(model, x):
    model = torch.compile(model)
    loss = model(x).sum()
    torch._dynamo.config.compiled_autograd = True
    torch.compile(lambda: loss.backward(), fullgraph=True)()

或者你可以使用上下文管理器,它將應用於其作用域內的所有autograd呼叫。

def train(model, x):
   model = torch.compile(model)
   loss = model(x).sum()
   with torch._dynamo.compiled_autograd.enable(torch.compile(fullgraph=True)):
      loss.backward()

編譯後的Autograd解決了AOTAutograd的某些限制

  1. 前向傳播中的圖中斷不再必然導致反向傳播中的圖中斷

@torch.compile(backend="aot_eager")
def fn(x):
   # 1st graph
   temp = x + 10
   torch._dynamo.graph_break()
   # 2nd graph
   temp = temp + 10
   torch._dynamo.graph_break()
   # 3rd graph
   return temp.sum()

x = torch.randn(10, 10, requires_grad=True)
torch._dynamo.utils.counters.clear()
loss = fn(x)

# 1. base torch.compile
loss.backward(retain_graph=True)
assert(torch._dynamo.utils.counters["stats"]["unique_graphs"] == 3)
torch._dynamo.utils.counters.clear()

# 2. torch.compile with compiled autograd
with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
   loss.backward()

# single graph for the backward
assert(torch._dynamo.utils.counters["stats"]["unique_graphs"] == 1)

在第一個torch.compile案例中,我們看到由於編譯函式fn中的2個圖中斷,產生了3個反向圖。而在使用編譯後的autograd的第二個torch.compile案例中,我們看到儘管有圖中斷,仍然追蹤到了一個完整的反向圖。

注意

Dynamo在追蹤由編譯後的Autograd捕獲的反向鉤子時,仍然可能發生圖中斷。

  1. 反向鉤子現在可以被捕獲

@torch.compile(backend="aot_eager")
def fn(x):
   return x.sum()

x = torch.randn(10, 10, requires_grad=True)
x.register_hook(lambda grad: grad+10)
loss = fn(x)

with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
   loss.backward()

圖中應該有一個call_hook節點,Dynamo稍後會將其內聯到以下內容

../_images/call_hook_node.png

編譯後的Autograd常見的重新編譯原因

  1. 由於loss值的autograd結構發生變化

torch._dynamo.config.compiled_autograd = True
x = torch.randn(10, requires_grad=True)
for op in [torch.add, torch.sub, torch.mul, torch.div]:
   loss = op(x, x).sum()
   torch.compile(lambda: loss.backward(), backend="eager")()

在上面的示例中,我們在每次迭代時呼叫不同的運算子,導致loss每次跟蹤不同的autograd歷史。你應該會看到一些重新編譯訊息:Cache miss due to new autograd node

../_images/recompile_due_to_node.png
  1. 由於張量形狀發生變化

torch._dynamo.config.compiled_autograd = True
for i in [10, 100, 10]:
   x = torch.randn(i, i, requires_grad=True)
   loss = x.sum()
   torch.compile(lambda: loss.backward(), backend="eager")()

在上面的示例中,x的形狀發生變化,編譯後的autograd會在第一次變化後將x標記為動態形狀張量。你應該會看到重新編譯訊息:Cache miss due to changed shapes

../_images/recompile_due_to_dynamic.png

結論

在本教程中,我們回顧了帶有編譯後的autograd的torch.compile的高層生態系統、編譯後的autograd的基礎知識以及一些常見的重新編譯原因。請關注dev-discuss上的深入探討。

文件

訪問PyTorch的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲取問題解答

檢視資源