捷徑

PyTorch 2.0 NNModule 支援

作者Will Constable

torch.compile 對 torch.nn.Module 物件有特殊的處理方式,會以不同於追蹤任意 Python 類別的方式追蹤它們,目的是透過對結構做出假設來產生更快的程式碼。

本文件說明由於這種特殊化而產生的一些權衡或邊緣情況。

NNModule 鉤子支援

以前,torch.compile 不支援 nn.Modules 上的鉤子,如果註冊了鉤子,則在編譯後的程式中只會忽略它們。的確,許多使用者根本不使用 nn.Module 鉤子,或者只將它們用於偵錯工作流程,但有一些有效的用例可以將 nn.Module 鉤子與 torch.compile 組合使用。

透過 nn.Module.__call__ 實作協調的鉤子包括 _forward_pre_hooksforward_hooks_backward_pre_hooks_backward_hooks,以下稱為「呼叫鉤子」。torch.compile 部分支援這些鉤子,但有以下限制。

另一類鉤子包括 _state_dict_hooks 及其 preload_ 變體,torch.compile 仍不支援這些鉤子。

nn.Module.__call__ 鉤子用法和限制

根據預設,torch.compile 會追蹤 nn.Module.__call__ 的內容,這表示它會遇到並執行 forward/pre-forward 鉤子。如果在呼叫 torch.compile 之前安裝鉤子,然後之後沒有移除或更改鉤子,則預設情況下應該支援您的用例。

一般也支援 Backward/Pre-backward 鉤子,但有一些注意事項:目前在存取 backward_hooks 字典時,dynamo 中會發生圖形中斷,這可能可以透過一些工作來避免。圖形中斷也會影響觸發 backward 鉤子的時機,因為圖形區段是作為 autograd 函數執行的,這些函數會同時產生所有梯度。假設 dynamo 可以不因 backward-hooks 的存在而中斷圖形,我們仍然希望一系列模組的 backward 鉤子在整個編譯圖形的 backward 執行後一起觸發。

「允許模組」上的鉤子 torch.compile 會特殊處理常見的模組,例如 torch.conv,以及難以追蹤的模組,方法是允許在 dynamo 圖形中不透明地呼叫它們,而不是由 dynamo 追蹤到它們。對於此類模組,鉤子目前會觸發圖形中斷,以便受影響的模組在 dynamo 之外執行。根據模型的不同,這可能會導致顯著的效能下降,需要額外的工作來改進此支援。

skip_nnmodule_hook_guards 根據預設,torch._dynamo.config.skip_nnmodule_hook_guards 設定為 True,這表示不會在每個 nn.Module 鉤子字典上安裝防護,透過減少防護執行時間來提高執行階段效能,但代價是在編譯後不會注意到任何鉤子字典是否已更改。

如果您希望能夠在編譯後移除或修改鉤子,並讓 torch.compile 做出適當的反應(透過重新編譯),則需要設定 skip_nnmodule_hook_guards=False,並預期會增加執行階段的防護成本。

TODO:確認 backward/pre_backward 鉤子是否正常運作,並據此記錄。

state_dict 鉤子

torch.compile 尚未支援 State dict 鉤子。

TODO:如果在鉤子上中斷圖形,則發出一次警告。如果存在鉤子,則發出一次警告以指向本文件。

文件

存取 PyTorch 的完整開發人員文件

檢視文件

教學課程

取得適用於初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得問題解答

檢視資源