快捷方式

torch.export 程式設計模型

本文旨在解釋 torch.export.export() 的行為和功能。它旨在幫助你理解 torch.export.export() 如何處理程式碼。

跟蹤基礎

torch.export.export() 透過在“示例”輸入上跟蹤模型的執行,並記錄沿跟蹤路徑觀察到的 PyTorch 操作和條件,來捕獲表示模型的圖。只要後續輸入滿足相同的條件,此圖就可以在不同的輸入上執行。

torch.export.export() 的基本輸出是一個包含相關元資料的 PyTorch 操作的單一圖。此輸出的具體格式在 torch.export IR 規範 中介紹。

嚴格跟蹤與非嚴格跟蹤

torch.export.export() 提供了兩種跟蹤模式。

非嚴格模式下,我們使用標準的 Python 直譯器跟蹤程式。你的程式碼將完全按照 eager 模式執行;唯一的區別是所有 Tensor 都被替換為 偽 Tensor它們具有形狀和其他形式的元資料,但沒有資料,並封裝在 Proxy 物件 中,這些物件將所有操作記錄到一個圖中。我們還捕獲了 Tensor 形狀條件 這些條件用於保證生成程式碼的正確性

嚴格模式下,我們首先使用 TorchDynamo(一個 Python 位元組碼分析引擎)跟蹤程式。TorchDynamo 實際上並不執行你的 Python 程式碼。相反,它象徵性地分析程式碼並根據結果構建圖。一方面,這種分析允許 torch.export.export() 提供額外的 Python 級別安全性保證(除了像非嚴格模式那樣捕獲 Tensor 形狀條件外)。另一方面,並非所有 Python 特性都受到此分析的支援。

雖然目前預設的跟蹤模式是嚴格模式,但我們強烈建議使用非嚴格模式,它很快將成為預設模式。對於大多數模型,Tensor 形狀條件足以保證健全性,而額外的 Python 級別安全性保證沒有影響;同時,在 TorchDynamo 中遇到不支援的 Python 特性會帶來不必要的風險。

在本文件的其餘部分,我們假設在非嚴格模式下進行跟蹤;特別是,我們假設所有 Python 特性都受到支援

值:靜態值與動態值

理解 torch.export.export() 行為的關鍵概念是靜態值和動態值之間的區別。

靜態值

靜態值是在匯出時固定,且在匯出程式的每次執行之間不能更改的值。在跟蹤期間遇到該值時,我們將其視為常量並將其硬編碼到圖中。

當執行操作(例如 x + y)且所有輸入均為靜態時,操作的輸出將直接硬編碼到圖中,且該操作不會出現在圖中(即它被“常量摺疊”)。

當值被硬編碼到圖中時,我們稱該圖已針對該值進行了特化。例如

import torch

class MyMod(torch.nn.Module):
    def forward(self, x, y):
        z = y + 7
        return x + z

m = torch.export.export(MyMod(), (torch.randn(1), 3))
print(m.graph_module.code)

"""
def forward(self, arg0_1, arg1_1):
    add = torch.ops.aten.add.Tensor(arg0_1, 10);  arg0_1 = None
    return (add,)

"""

在這裡,我們將 3 作為 y 的跟蹤值;它被視為靜態值並新增到 7 中,在圖中固化了靜態值 10

動態值

動態值是在每次執行時可以更改的值。它的行為就像“正常”函式引數一樣:你可以傳遞不同的輸入並期望函式執行正確的操作。

哪些值是靜態的,哪些是動態的?

值是靜態還是動態取決於其型別

  • 對於 Tensor

    • Tensor 資料被視為動態。

    • Tensor 形狀可以被系統視為靜態或動態。

      • 預設情況下,所有輸入 Tensor 的形狀被視為靜態。使用者可以透過為任何輸入 Tensor 指定動態形狀來覆蓋此行為。

      • 作為模組狀態一部分的 Tensor,即引數和緩衝區,始終具有靜態形狀。

    • 其他形式的 Tensor 元資料(例如 device, dtype)是靜態的。

  • Python 基本型別int, float, bool, str, None)是靜態的。

    • 某些基本型別有動態變體(SymInt, SymFloat, SymBool)。通常使用者不需要處理它們。

  • 對於 Python 標準容器list, tuple, dict, namedtuple

    • 結構(即 listtuple 值的長度,以及 dictnamedtuple 值的鍵序列)是靜態的。

    • 包含的元素遞迴應用這些規則(基本上是 PyTree 方案),葉子節點是 Tensor 或基本型別。

  • 其他(包括資料類)可以透過 PyTree 註冊(見下文),並遵循與標準容器相同的規則。

輸入型別

輸入將被視為靜態或動態,具體取決於它們的型別(如上所述)。

  • 靜態輸入將被硬編碼到圖中,並在執行時傳遞不同的值將導致錯誤。請記住,這些值主要是基本型別的值。

  • 動態輸入行為類似於“正常”函式輸入。請記住,這些值主要是 Tensor 型別的值。

預設情況下,程式可用的輸入型別包括

  • Tensor

  • Python 基本型別(int, float, bool, str, None

  • Python 標準容器(list, tuple, dict, namedtuple

自定義輸入型別

此外,你還可以定義自己的(自定義)類並將其用作輸入型別,但這需要你將此類註冊為 PyTree。

這是一個使用實用程式註冊用作輸入型別的資料類的示例。

@dataclass
class Input:
    f: torch.Tensor
    p: torch.Tensor

torch.export.register_dataclass(Input)

class M(torch.nn.Module):
    def forward(self, x: Input):
        return x.f + 1

torch.export.export(M(), (Input(f=torch.ones(10, 4), p=torch.zeros(10, 4)),))

可選輸入型別

對於程式中未傳入的可選輸入,torch.export.export() 將特化為它們的預設值。因此,匯出的程式將要求使用者顯式傳入所有引數,並會丟失預設行為。例如

class M(torch.nn.Module):
    def forward(self, x, y=None):
        if y is not None:
            return y * x
        return x + x

# Optional input is passed in
ep = torch.export.export(M(), (torch.randn(3, 3), torch.randn(3, 3)))
print(ep)
"""
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 3]", y: "f32[3, 3]"):
            # File: /data/users/angelayi/pytorch/moo.py:15 in forward, code: return y * x
            mul: "f32[3, 3]" = torch.ops.aten.mul.Tensor(y, x);  y = x = None
            return (mul,)
"""

# Optional input is not passed in
ep = torch.export.export(M(), (torch.randn(3, 3),))
print(ep)
"""
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 3]", y):
            # File: /data/users/angelayi/pytorch/moo.py:16 in forward, code: return x + x
            add: "f32[3, 3]" = torch.ops.aten.add.Tensor(x, x);  x = None
            return (add,)
"""

控制流:靜態與動態

PyTorch torch.export.export() 支援控制流。控制流的行為取決於你分支的值是靜態還是動態。

靜態控制流

Python 對靜態值的控制流得到透明支援。(請記住,靜態值包括靜態形狀,因此對靜態形狀的控制流也屬於這種情況。)

如上所述,我們“固化”靜態值,因此匯出的圖永遠不會看到任何對靜態值的控制流。

對於 if 語句,我們將繼續跟蹤匯出時採用的分支。對於 forwhile 語句,我們將透過展開迴圈來繼續跟蹤。

動態控制流:依賴形狀與依賴資料

當控制流中涉及的值是動態的時,它可能依賴於動態形狀或動態資料。考慮到編譯器跟蹤時使用的是形狀資訊而不是資料,這些情況下對程式設計模型的影響是不同的。

動態形狀依賴控制流

當控制流中涉及的值是動態形狀時,在大多數情況下,我們在跟蹤期間也會知道動態形狀的具體值:有關編譯器如何跟蹤此資訊的更多詳細資訊,請參閱下一節。

在這些情況下,我們稱控制流是形狀依賴的。我們使用動態形狀的具體值來評估條件TrueFalse 並繼續跟蹤(如上所述),此外還會發出與剛剛評估的條件對應的守衛。

否則,控制流被視為資料依賴的。我們無法評估條件為 TrueFalse,因此無法繼續跟蹤,必須在匯出時引發錯誤。請參閱下一節。

動態資料依賴控制流

資料依賴的動態值控制流受到支援,但你必須使用 PyTorch 的顯式運算子之一來繼續跟蹤。使用 Python 控制流語句處理動態值是不允許的,因為編譯器無法評估繼續跟蹤所需的條件,因此必須在匯出時引發錯誤。

我們提供了運算子來表達動態值的通用條件和迴圈,例如 torch.condtorch.map。請注意,只有當你確實需要資料依賴的控制流時才需要使用這些運算子。

這是一個關於資料依賴條件 x.sum() > 0if 語句示例,其中 x 是一個輸入 Tensor,使用 torch.cond 重寫。現在,兩個分支都被跟蹤,而不是必須決定跟蹤哪個分支。

class M_old(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x.sin()
        else:
            return x.cos()

class M_new(torch.nn.Module):
    def forward(self, x):
        return torch.cond(
            pred=x.sum() > 0,
            true_fn=lambda x: x.sin(),
            false_fn=lambda x: x.cos(),
            operands=(x,),
        )

資料依賴控制流的一個特殊情況是它涉及無支援的動態形狀:通常是某些中間 Tensor 的形狀,它依賴於輸入資料而不是輸入形狀(因此不依賴於形狀)。在這種情況下,你可以提供一個斷言來決定條件是 True 還是 False,而不是使用控制流運算子。給定這樣的斷言,我們可以繼續跟蹤,併發出一個守衛,如上所述。

我們提供運算子來表達對動態形狀的斷言,例如 torch._check。請注意,只有當對資料依賴的動態形狀存在控制流時才需要使用此運算子。

這是一個涉及資料依賴動態形狀的條件 nz.shape[0] > 0if 語句示例,其中 nz 是呼叫 torch.nonzero() 的結果,這是一個輸出形狀依賴於輸入資料(因此不依賴於形狀)的運算子。在這種情況下,你可以使用 torch._check 新增斷言來有效地決定跟蹤哪個分支,而不是重寫它。

class M_old(torch.nn.Module):
    def forward(self, x):
        nz = x.nonzero()
        if nz.shape[0] > 0:
            return x.sin()
        else:
            return x.cos()

class M_new(torch.nn.Module):
    def forward(self, x):
        nz = x.nonzero()
        torch._check(nz.shape[0] > 0)
        if nz.shape[0] > 0:
            return x.sin()
        else:
            return x.cos()

符號形狀基礎

在跟蹤期間,動態 Tensor 形狀及其條件被編碼為“符號表達式”。(相比之下,靜態 Tensor 形狀及其條件僅僅是 intbool 值。)

一個符號就像一個變數;它描述一個動態 Tensor 形狀。

隨著跟蹤的進行,中間 Tensor 的形狀可以用更通用的表示式來描述,通常涉及整數算術運算子。這是因為對於大多數 PyTorch 運算子,輸出 Tensor 的形狀可以描述為輸入 Tensor 形狀的函式。例如,torch.cat() 的輸出形狀是其輸入形狀的總和。

此外,當我們在程式中遇到控制流時,我們會建立布林表示式,通常涉及關係運算符,描述沿著跟蹤路徑的條件。這些表示式會被評估以決定跟蹤程式中的哪條路徑,並記錄在形狀環境中,以保證跟蹤路徑的正確性並評估隨後建立的表示式。

接下來我們將簡要介紹這些子系統。

PyTorch 運算子的偽實現

回想一下,在跟蹤期間,我們使用偽 Tensor 執行程式,這些 Tensor 沒有資料。通常我們無法使用偽 Tensor 呼叫 PyTorch 運算子的實際實現。因此,每個運算子都需要一個額外的偽(也稱為“元”)實現,該實現輸入和輸出偽 Tensor,並在形狀和偽 Tensor 攜帶的其他形式的元資料方面與實際實現的行為匹配。

例如,請注意 torch.index_select() 的偽實現如何使用輸入形狀計算輸出形狀(同時忽略輸入資料並返回空的輸出資料)。

def meta_index_select(self, dim, index):
    result_size = list(self.size())
    if self.dim() > 0:
        result_size[dim] = index.numel()
    return self.new_empty(result_size)

形狀傳播:有支援的動態形狀與無支援的動態形狀

形狀透過 PyTorch 運算子的偽實現進行傳播。

理解動態形狀傳播的關鍵概念是有支援的無支援的動態形狀之間的區別:我們知道前者的具體值,但不知道後者的具體值。

形狀的傳播,包括跟蹤有支援的和無支援的動態形狀,按以下方式進行

  • 表示輸入的 Tensor 的形狀可以是靜態或動態的。當為動態時,它們由符號描述;此外,由於我們知道使用者在匯出時提供的“真實”示例輸入所給出的具體值,因此此類符號是有支援的

  • 運算子的輸出形狀由其偽實現計算,可以是靜態或動態的。當為動態時,通常由符號表達式描述。此外

    • 如果輸出形狀僅依賴於輸入形狀,則當輸入形狀全部為靜態或有支援的動態時,輸出形狀也是靜態或有支援的動態。

    • 另一方面,如果輸出形狀依賴於輸入資料,則它必然是動態的,而且,因為我們無法知道它的具體值,它是無支援的

控制流:守衛和斷言

當遇到形狀條件時,它要麼僅涉及靜態形狀,在這種情況下它是一個 bool,要麼涉及動態形狀,在這種情況下它是一個符號布林表示式。對於後者

  • 當條件僅涉及有支援的動態形狀時,我們可以使用這些動態形狀的具體值來評估條件為 TrueFalse。然後我們可以在形狀環境中新增一個守衛,說明相應的符號布林表示式為 TrueFalse,並繼續跟蹤。

  • 否則,條件涉及無支援的動態形狀。通常,在沒有額外資訊的情況下,我們無法評估此類條件;因此我們無法繼續跟蹤,必須在匯出時引發錯誤。使用者需要使用顯式的 PyTorch 運算子才能繼續跟蹤。此資訊作為守衛新增到形狀環境中,並且可能有助於評估其他隨後遇到的條件為 TrueFalse

模型匯出後,對有支援動態形狀的任何守衛都可以理解為對輸入動態形狀的條件。這些條件會根據必須提供給匯出的動態形狀規範進行驗證,該規範描述了示例輸入以及所有未來輸入為使生成程式碼正確必須滿足的動態形狀條件。更精確地說,動態形狀規範在邏輯上必須蘊含生成的守衛,否則將在匯出時引發錯誤(並給出動態形狀規範的建議修復)。另一方面,當對有支援動態形狀沒有生成守衛時(特別是當所有形狀都是靜態時),無需為匯出提供動態形狀規範。通常,動態形狀規範會轉換為對生成程式碼輸入的執行時斷言。

最後,對無支援動態形狀的任何守衛都會轉換為“內聯”執行時斷言。這些斷言會新增到生成程式碼中建立這些無支援動態形狀的位置:通常是在資料依賴運算子呼叫之後。

允許的 PyTorch 運算子

允許使用所有 PyTorch 運算子。

自定義運算子

此外,你可以定義和使用自定義運算子。定義自定義運算子包括為其定義一個偽實現,就像任何其他 PyTorch 運算子一樣(請參閱上一節)。

這是一個自定義 sin 運算子的示例,它封裝了 NumPy,及其註冊的(簡單的)偽實現。

@torch.library.custom_op("mylib::sin", mutates_args=())
def sin(x: Tensor) -> Tensor:
    x_np = x.numpy()
    y_np = np.sin(x_np)
    return torch.from_numpy(y_np)

@torch.library.register_fake("mylib::sin")
def _(x: Tensor) -> Tensor:
    return torch.empty_like(x)

有時,你的自定義運算子的偽實現會涉及資料依賴的形狀。這是一個自定義 nonzero 運算子的偽實現可能看起來的樣子。

...

@torch.library.register_fake("mylib::custom_nonzero")
def _(x):
    nnz = torch.library.get_ctx().new_dynamic_size()
    shape = [nnz, x.dim()]
    return x.new_empty(shape, dtype=torch.int64)

模組狀態:讀取與更新

模組狀態包括引數、緩衝區和普通屬性。

  • 普通屬性可以是任何型別。

  • 另一方面,引數和緩衝區始終是 Tensor。

模組狀態可以是動態或靜態的,具體取決於如上所述的型別。例如,self.training 是一個 bool,這意味著它是靜態的;另一方面,任何引數或緩衝區都是動態的。

模組狀態中包含的任何 Tensor 的形狀不能是動態的,即這些形狀在匯出時固定,且在匯出程式的每次執行之間不能更改。

訪問規則

所有模組狀態都必須初始化。在匯出時訪問尚未初始化的模組狀態會導致錯誤發生。

總是允許讀取模組狀態.

更新模組狀態是可能的,但必須遵循以下規則

  • 靜態常規屬性(例如,原始型別)可以更新。讀取和更新可以自由交錯進行,正如預期一樣,任何讀取都將始終看到最新更新的值。由於這些屬性是靜態的,我們也會將值嵌入其中,因此生成的程式碼將不會包含實際“獲取”或“設定”這些屬性的指令。

  • 動態常規屬性(例如,Tensor 型別)無法更新。要更新它,必須在模組初始化期間將其註冊為 buffer。

  • Buffer 可以更新,更新可以是原地更新(例如,self.buffer[:] = ...)或者非原地更新(例如,self.buffer = ...)。

  • Parameter 無法更新。通常 parameter 只在訓練期間更新,而非推理期間。我們建議使用 torch.no_grad() 進行匯出,以避免在匯出時更新 parameter。

Functionalization 的影響

任何被讀取和/或更新的動態模組狀態會(相應地)作為生成程式碼的輸入和/或輸出被“提升”(lifted)。

匯出的程式會與生成程式碼一起儲存 parameter 和 buffer 的初始值以及其他 Tensor 屬性的常量值。

文件

訪問 PyTorch 的完整開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並解答你的問題

檢視資源