快捷方式

PyTorch 2.0 NNModule 支援

作者: Will Constable

torch.compile 對 torch.nn.Module 物件有特殊處理,對其追蹤方式與追蹤任意 Python 類不同,目的是透過對結構進行假設來生成更快的程式碼。

本文件描述了這種特殊處理所帶來的一些權衡和邊緣情況。

NNModule 鉤子支援

以前,torch.compile 不支援 nn.Module 上的鉤子,如果註冊了鉤子,它們在編譯後的程式中會被直接忽略。事實上,許多使用者根本不使用 nn.Module 鉤子,或僅將其用於除錯工作流程,但將 nn.Module 鉤子與 torch.compile 結合使用存在有效的使用場景。

透過 nn.Module.__call__ 實現編排的鉤子包括 _forward_pre_hooksforward_hooks_backward_pre_hooks_backward_hooks,這些鉤子將被稱為“呼叫鉤子”。torch.compile 部分支援這些鉤子,但存在下述限制。

另一類鉤子包括 _state_dict_hooks 及其 preload_ 變體,這些鉤子仍不受 torch.compile 支援。

nn.Module.__call__ 鉤子使用及限制

預設情況下,torch.compile 會追蹤 nn.Module.__call__ 的內容,這意味著它會遇到並執行前向/前向預鉤子。如果您在呼叫 torch.compile 之前安裝鉤子,並且之後不移除或更改這些鉤子,則您的使用場景預設應受支援。

反向/反向預鉤子通常也受支援,但也存在類似注意事項:當前在 Dynamo 中訪問 backward_hooks 字典時會發生圖中斷 (graph-breaks),這可能透過一些工作可以避免。圖中斷還會影響反向鉤子的觸發時機,因為圖段會作為自動微分函式 (autograd-functions) 執行,它們同時產生所有梯度 (grads)。假設 Dynamo 可以在存在反向鉤子的情況下不發生圖中斷,我們仍然期望一系列模組的反向鉤子在整個編譯圖的反向傳播執行後一起觸發。

“允許模組”(allowed modules) 上的鉤子 torch.compile 會特殊處理常見模組(例如 torch.conv)以及難以追蹤的模組,方法是允許它們在 Dynamo 圖中以不透明方式呼叫,而不是被 Dynamo 追蹤進入。對於此類模組,鉤子目前會觸發圖中斷 (graph-break),導致受影響的模組在 Dynamo 外部執行。根據模型不同,這可能會導致顯著的效能下降 (performance regression),需要額外工作來改進此支援。

skip_nnmodule_hook_guards 預設情況下,torch._dynamo.config.skip_nnmodule_hook_guards 設定為 True,這意味著不會在每個 nn.Module 鉤子字典上安裝守衛 (guards),從而透過減少守衛執行時間來提高執行時效能,代價是編譯後如果任何鉤子字典發生更改,將不會被注意到。

如果您希望在編譯後能夠移除或修改鉤子,並讓 torch.compile 做出適當反應(透過重新編譯),則需要將 skip_nnmodule_hook_guards=False,並預計因新增守衛而產生的執行時開銷 (runtime penalty)。

TODO: 確認反向/反向預鉤子是否工作,並相應地更新文件

state_dict 鉤子

torch.compile 尚不支援 state dict 鉤子。

TODO: 如果在鉤子上發生圖中斷,則警告一次。如果存在鉤子,則警告一次並指向本文件。

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源