快捷方式

動態形狀

程式碼: symbolic_shapes.py

另請參閱: 動態形狀手冊

動機

深度學習編譯器通常只適用於靜態形狀,也就是說,它們生成的編譯程式只針對輸入形狀的一種特定配置有效,如果任何輸入形狀發生變化,則必須重新編譯。這個假設對於當今大多數常用的深度學習模型來說效果很好,但在某些情況下它是不夠的

  • 某些維度,如批大小或序列長度,可能會變化。例如,執行自適應批處理的推理服務將根據其批處理視窗中接收到的請求數量,以變化的批大小執行推理請求。我們可能還會考慮將變長序列填充到批次內的最大序列長度,該長度可能因批次而異。

  • 一些模型表現出資料依賴的輸出形狀,也就是說,它們的輸出和中間結果的大小可能取決於實際輸入資料,而輸入資料在不同執行中可能會變化。例如,檢測模型可能首先生成可變數量的潛在邊界框,然後執行更昂貴的影像識別模型來識別主體是否在邊界框內。邊界框的數量取決於資料。

  • 處理稀疏表示(如稀疏張量、不規則張量和圖神經網路)時,會遇到一種特別重要的資料依賴形狀情況。在所有這些情況下,需要處理的資料量取決於問題的稀疏結構,而稀疏結構通常會以資料依賴的方式變化。

在支援動態形狀時,我們選擇不支援動態秩(rank)的程式,例如輸入張量維度變化的程式,因為這種模式在實際深度學習程式中很少出現,並且可以避免對符號形狀列表進行歸納推理的需要。

公共 API 摘要

PyTorch 2.1 中的預設動態行為是

  • PT2 預設假定一切都是靜態的

  • 如果我們因為大小改變而重新編譯,我們將嘗試將該大小重新編譯為動態的(改變的大小在將來很可能會再次改變)。這種泛化可能會失敗(例如,因為使用者程式碼對相關大小進行了條件分支,或者 PT2 中缺少動態形狀支援)。如果你試圖理解為什麼 PT2 對某些程式碼進行了過度特殊化,請使用 TORCH_LOGS=dynamic 執行,並查詢指示何時新增守衛(guards)以及原因的“eval”條目。

  • 如果你提前知道某些東西將是動態的,你可以使用 torch._dynamo.mark_dynamic(tensor, dim) 跳過首次重新編譯。如果你提前知道該維度可以接受的最小和最大值,你可以指定 torch._dynamo.mark_dynamic(tensor, dim, min=min, max=max)

  • 如果你指定 torch.compile(dynamic=False),我們將關閉重新編譯時的自動動態形狀,並始終為每個不同的尺寸重新編譯。反之,如果你指定 torch.compile(dynamic=True),我們將嘗試使一切儘可能地動態化。這對於小型運算元(operators)非常有用;如果你在一個大型模型上嘗試這樣做,它很可能會(1)使 PT2 崩潰,並且(2)執行緩慢而沒有任何好處。

守衛模型

在考慮如何向 TorchDynamo 和 TorchInductor 新增動態形狀支援時,我們做了一個重大的設計決策:為了重用用 Python/C++ 編寫並針對 PyTorch API 的分解(decompositions)和其他現有程式碼,我們必須能夠追蹤(trace)動態形狀。與完全符號化系統(可能會捕獲條件分支的兩條路徑)不同,我們總是選擇一個分支並專門化我們的追蹤,假設只有當未來做出相同分支選擇時,我們才會使用這個追蹤。為此,我們為每個符號大小維護一個“提示”,說明其在編譯時(由於 TorchDynamo 是即時編譯器,它總是知道實際輸入大小)的具體值。當我們在張量上執行條件判斷時,我們只需查詢該提示即可知道應選擇哪個分支。

這極大地簡化了我們生成的符號形狀公式,但意味著我們有一個更復雜的守衛管理系統。例如,考慮以下程式

def f(x, y):
    z = torch.cat([x, y])
    if z.size(0) > 2:
        return z.mul(2)
    else:
        return z.add(2)

我們將使用 TorchInductor 編譯的最終 IR 將是 torch.cat([x, y]).add(2)torch.cat([x, y]).mul(2)(條件被消除),但要確定我們處於哪個分支,我們需要知道中間結果 z 的大小。因為 TorchDynamo 必須提前知道編譯後的追蹤是否有效(我們不支援像某些 JIT 編譯器那樣的退出),我們必須能夠將 z.size(0) 作為一個表示式,用輸入 x.size(0) + y.size(0) 來表示。這是透過為 PyTorch 中的所有運算元編寫元函式(meta functions)來實現的,這些元函式可以在不實際對節點進行計算的情況下將大小資訊傳播到張量的輸出。

整體架構

符號形狀工作流程

  1. 當我們在 Dynamo 中開始編譯一個幀時,我們分配一個 ShapeEnv(附屬於 FakeTensorMode),用於跟蹤符號形狀狀態。

  2. 我們在入口處為張量分配符號大小(靜態或動態是一個策略決定,有一些控制選項)。

  3. 我們透過運算元傳播符號大小,同時維護 (1) FX IR,以便忠實地匯出符號計算,以及 (2) 代表尺寸變數的 Sympy 表示式,以便我們能夠對其進行推理。

  4. 當我們在 Dynamo 追蹤或 Inductor 最佳化中對符號大小設定條件時,我們會根據條件新增守衛。這些守衛可以由 Python 和 C++ 程式碼引起。

  5. 這些守衛可以進一步簡化符號變數。例如,如果你斷言 s0 == 4,我們現在可以將所有出現的 s0 替換為 4

  6. 完成追蹤和最佳化後,我們將所有這些守衛與編譯後的程式碼一起安裝;只有當所有守衛都評估為真時,編譯後的程式碼才能被重用。

重要檔案

  • C++ SymInt API: c10/core/SymInt.h, SymFloat.h, SymBool.h

  • Python SymInt API: torch/__init__.py (查詢 SymInt/SymFloat/SymBool)

  • C++ 底層實現: c10/core/SymNodeImpl.h, torch/csrc/utils/python_symnode.h, torch/csrc/jit/python/init.cpp

  • Python 基礎設施: torch/fx/experimental/symbolic_shapes.py

  • 其他重要檔案: torch/_subclasses/fake_tensor.py, torch/_meta_registrations.py, decomps, PrimTorch refs

內部 API 摘要

理解 Python 類層次結構

  • SymInt/SymFloat/SymBool: 這些是使用者可見的類,模擬其對應的 int/float/bool 型別。如果你將兩個 SymInt 相加,我們將給你一個新的 SymInt,以符號方式跟蹤發生了整數加法。

  • SymNode: 這是內部結構(例如透過 symint.node 訪問),用於儲存實際的符號跟蹤資訊。SymNode 是型別擦除的;這使得表示混合型別操作更加方便。請注意,從技術上講,你不必從 SymInt 呼叫 Python SymNode;例如,XLA 的 C++ SymNodeImpl 將取代 SymNode。

  • ShapeEnv: 每個編譯上下文的狀態,用於跟蹤我們迄今為止積累的所有自由符號和守衛。每個 SymNode 都記錄其 ShapeEnv(反之則不然;SymNode 只有在參與守衛時才會被使用)。

C++ 的情況也類似

  • c10::SymInt/SymFloat/SymBool: 使用者可見的類,模擬 int/float/bool 型別。

  • c10::SymNode/SymNodeImpl: 類似於 SymNode

  • C++ 中沒有 ShapeEnv;為了便於除錯,整個符號推理機制都在 Python 中。

當你編寫可以使用 make_fx 進行追蹤的程式碼時,它必須能夠處理流經其中的 SymInt/SymFloat/SymBool。動態形狀手冊提供了一些指導,說明如何做到這一點。

DimDynamic 策略

符號推理

  • 取值範圍

  • Sympy 使用注意事項

  • 約束

  • DimDynamic/Constraint

無支援的 SymInt

為了解析控制流,我們檢查符號整數的提示(即實際值)來確定選擇哪個分支。然而,在某些情況下,我們可能沒有提示:所謂的無支援符號整數出現在大小變數由資料依賴的操作(如 .nonzero().item())產生時。對這些符號整數執行控制流是非法的,因此我們必須在這些操作上進行圖中斷(graph break)。

如果天真地實現,這將過於嚴格:如果你嘗試對無支援的符號整數進行任何操作,大多數 PyTorch 程式會立即失敗。以下是使這實際工作起來的最重要的增強功能

  • 在建立張量時,PyTorch 會預先計算關於張量的很多資料;例如,如果你使用 empty_strided 建立張量,我們會急切地對步長(strides)進行排序,並確定張量是否是非重疊且稠密的。排序會產生很多守衛。然而,更常見的是直接使用像 empty 這樣的高階 API 建立張量,該 API 保證生成非重疊且稠密的張量。我們修改了 PyTorch,以避免不必要地重新計算這些屬性。

  • 即使需要進行非平凡計算,有時某個屬性也從未被實際查詢過。將這些預計算屬性設為惰性(lazy)可以讓我們避免對無支援的符號整數新增守衛,除非確實需要。

  • 整數張量中的資料通常不保證是非負的。然而,我們提供了一個 API constrain_range,使用者可以透過它指定某個大小的上下界由已知限制確定。

在 PT2 的未來版本(PT2.1 之後),我們將擴充套件我們的推理系統,根據使用情況推斷無支援符號整數是類似尺寸的。例如,如果你將 .item() 呼叫的結果傳遞給像 torch.empty 這樣的工廠函式,我們將自動推斷結果是一個尺寸(因為如果不是,它就會失敗)。這個假設將在執行時得到驗證,如果不滿足則會引發錯誤。

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並解答疑問

檢視資源