動態形狀¶
另請參閱:動態形狀手冊
動機¶
深度學習編譯器通常只適用於靜態形狀,也就是說,它們產生的已編譯程式只適用於單一特定的輸入形狀配置,如果任何輸入形狀改變,就必須重新編譯。這種假設對於現今大多數常用的深度學習模型來說非常有效,但在某些情況下,它是不夠的
- 某些維度,例如批次大小或序列長度,可能會有所不同。例如,執行自適應批次的推論服務將根據其在批次視窗內收到的請求數量,以不同的批次大小執行推論請求。我們也可能希望考慮將可變大小的序列填充到批次內的最大序列長度,而最大序列長度可能會因批次而異。 
- 有些模型呈現出資料相關的輸出形狀,也就是說,它們的輸出和中間值的大小可能會取決於實際的輸入資料,而這些資料在不同的執行過程中可能會有所不同。例如,偵測模型可能會先產生可變數量的潛在邊界框,然後再執行更昂貴的影像辨識模型,以識別主體是否位於邊界框內。邊界框的數量取決於資料。 
- 在處理稀疏表示法(例如稀疏張量、不規則張量和圖形神經網路)時,會出現一個特別重要的資料相關形狀案例。在所有這些情況下,要處理的資料量取決於問題的稀疏結構,而這些結構通常會以資料相關的方式變化。 
在支援動態形狀時,我們選擇不支援動態秩程式,例如輸入張量在維度上發生變化的程式,因為這種模式在真實世界的深度學習程式中很少出現,而且它避免了對符號形狀清單進行歸納推理的需要。
簡要的公開 API¶
PyTorch 2.1 中的預設動態行為是
- PT2 預設假設所有東西都是靜態的 
- 如果我們因為大小改變而重新編譯,我們會嘗試將該大小重新編譯為動態的(已經改變的大小在未來可能會再次改變)。這種概括化可能會失敗(例如,因為使用者程式碼對相關大小進行了條件分支,或者 PT2 中缺少對動態形狀的支援)。如果您想瞭解為什麼 PT2 會過度特化某些程式碼,請使用 - TORCH_LOGS=dynamic執行,並尋找說明何時新增防護以及原因的「評估」條目。
- 如果您事先知道某些東西會是動態的,您可以使用 - torch._dynamo.mark_dynamic(tensor, dim)來跳過第一次重新編譯。如果您事先知道這個維度可以取值的- min和- max值,您可以指定- torch._dynamo.mark_dynamic(tensor, dim, min=min, max=max)
- 如果您說 - torch.compile(dynamic=False),我們將關閉重新編譯時的自動動態形狀,並始終針對每個不同的大小進行重新編譯。相反地,如果您說- torch.compile(dynamic=True),我們將盡可能地將所有東西都設為動態的。這對於小型運算子來說非常有用;如果您嘗試在大型模型上使用它,它可能會 (1) 導致 PT2 崩潰,以及 (2) 毫無理由地執行緩慢。
防護模型¶
在考慮如何為 TorchDynamo 和 TorchInductor 新增對動態形狀的支援時,我們做了一個重大的設計決策:為了重複使用以 Python/C++ 編寫並以 PyTorch API 為目標的分解和其他現有程式碼,我們必須能夠追蹤動態形狀。與可能擷取條件式兩個分支的完全符號系統不同,我們始終選擇一個分支,並在假設我們只在將來會為該分支做出相同選擇時才使用此追蹤的情況下,特化我們的追蹤。為此,我們為每個符號大小維護一個「提示」,說明其在編譯時的具體值(由於 TorchDynamo 是一個即時編譯器,因此它始終知道實際的輸入大小)。當我們對張量執行條件時,我們只需查閱提示即可找出要採取哪個分支。
這大大簡化了我們產生的符號形狀公式,但這表示我們有一個更複雜的系統來管理防護。例如,請考慮以下程式
def f(x, y):
    z = torch.cat([x, y])
    if z.size(0) > 2:
        return z.mul(2)
    else:
        return z.add(2)
我們將使用 TorchInductor 編譯的最終 IR 將是 torch.cat([x, y]).add(2) 或 torch.cat([x, y]).mul(2)(條件已扁平化),但要確定我們在哪個分支中,我們需要知道 z 的大小,這是一個中間值。由於 TorchDynamo 必須事先知道編譯後的追蹤是否有效(我們不支援像某些 JIT 編譯器那樣的 bailouts),我們必須能夠根據輸入 x.size(0) + y.size(0) 將 z.size(0) 簡化為一個表達式。這是通過為 PyTorch 中的所有運算子編寫元函數來完成的,這些元函數可以將大小資訊傳播到張量的輸出,而無需實際對節點執行計算。
整體架構¶
符號形狀工作流程
- 當我們開始在 Dynamo 中編譯框架時,我們會配置一個 ShapeEnv(附加到 FakeTensorMode),它會追蹤符號形狀狀態。 
- 我們為輸入的張量分配符號大小(靜態或動態是一個策略決定,有一些旋鈕可以調整)。 
- 我們通過運算子傳播符號大小,同時維護 (1) FX IR,以便我們可以忠實地匯出符號計算,以及 (2) 表示大小變數的 Sympy 表達式,以便我們可以對其進行推理。 
- 當我們根據符號大小設置條件時,無論是在 Dynamo 追蹤中還是在 Inductor 優化中,我們都會根據條件添加防護。這些可以從 Python 和 C++ 中誘導出來。 
- 這些防護可以誘導對符號變數的進一步簡化。例如,如果您斷言 - s0 == 4,我們現在可以用- 4替換所有出現的- s0。
- 當我們完成追蹤和優化後,我們會使用編譯後的代碼安裝所有這些防護;只有在所有防護都評估為真時,編譯後的代碼才能重複使用。 
重要檔案
- C++ SymInt API: - c10/core/SymInt.h、- SymFloat.h、- SymBool.h
- Python SymInt API: - torch/__init__.py(尋找- SymInt/SymFloat/SymBool)
- C++ 管道: - c10/core/SymNodeImpl.h、- torch/csrc/utils/python_symnode.h、- torch/csrc/jit/python/init.cpp
- Python 基礎架構: - torch/fx/experimental/symbolic_shapes.py
- 其他重要檔案: - torch/_subclasses/fake_tensor.py、- torch/_meta_registrations.py、decomps、PrimTorch 參考
簡要內部 API¶
瞭解 Python 類別層次結構
- SymInt/SymFloat/SymBool:這些是使用者可見的類別,模擬它們的 int/float/bool 對應物。如果您添加兩個 SymInt,我們會給您一個新的 SymInt,它會符號化地追蹤已發生的整數加法。 
- SymNode:這是內部結構(可通過例如 - symint.node訪問),它保存實際的符號追蹤資訊。SymNode 是類型擦除的;這使得表示混合類型操作更加方便。請注意,從技術上講,您不必從 SymInt 調用到 Python SymNode;例如,XLA 的 C++- SymNodeImpl將取代 SymNode。
- ShapeEnv:每個編譯上下文狀態,用於追蹤到目前為止我們累積的所有自由符號和防護。每個 SymNode 都會記錄其 ShapeEnv(但反之亦然;只有參與防護的 SymNodes 才會被使用)。 
C++ 非常相似
- c10::SymInt/SymFloat/SymBool:使用者可見的類別,模擬 int/float/bool。 
- c10::SymNode/SymNodeImpl:類似於 SymNode 
- C++ 中沒有 ShapeEnv;為了便於調試,整個符號推理裝置都在 Python 中。 
當您編寫可以使用 make_fx 追蹤的代碼時,它必須能夠處理流經它的 SymInt/SymFloat/SymBool。動態形狀手冊 提供了一些有關如何做到這一點的指導。
無後備 SymInts¶
為了解析控制流程,我們檢查符號整數的提示(即實際值)以確定要進入哪個分支。但是,在某些情況下,我們可能沒有提示:當大小變數來自數據依賴的操作(例如 .nonzero() 或 .item())時,就會出現所謂的無後備符號整數。對這些符號整數執行控制流程是非法的,因此我們必須在這些操作上進行圖形斷開。
如果天真地實現,這就太受限制了:如果您嘗試對無後備符號整數做任何事情,大多數 PyTorch 程序都會立即失敗。以下是最重要的增強功能,使這項功能真正發揮作用
- 在張量創建時,PyTorch 會預先計算有關張量的許多數據;例如,如果您使用 - empty_strided創建張量,我們會急切地對步幅進行排序,並確定張量是否是非重疊且密集的。排序會產生很多防護。但是,更常見的做法是直接使用更高級別的 API(例如- empty)創建張量,這保證會生成非重疊且密集的張量。我們修改了 PyTorch 以避免不必要地重新計算這些屬性。
- 即使需要進行非平凡的計算,有時也根本不會查詢屬性。使這些預先計算的屬性變得懶惰,可以讓我們避免除非真正需要,否則不必防範無後備符號整數。 
- 通常不知道整數張量中的數據是非負的。但是,我們提供了一個 API - constrain_range,使用者可以通過它指定大小在上下限之間有已知的限制。
在 PT2 的未來版本中(PT2.1 之後),我們將擴展我們的推理系統,以根據使用情況推斷無後備符號整數是類似於大小的。例如,如果您將 .item() 調用的結果傳遞給工廠函數(例如 torch.empty),我們將自動推斷結果是一個大小(因為如果不是,它將會失敗。)此假設將在運行時得到驗證,如果未滿足,則會引發錯誤。