捷徑

torch.export

警告

此功能是一個正在積極開發中的原型,未來將會有重大變更。

概觀

torch.export.export() 接受一個任意的 Python 可呼叫物件(一個 torch.nn.Module、一個函數或一個方法),並以預先編譯 (AOT) 的方式產生一個僅代表函數的張量計算的追蹤圖,該圖隨後可以使用不同的輸出執行或序列化。

import torch
from torch.export import export

class Mod(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        a = torch.sin(x)
        b = torch.cos(y)
        return a + b

example_args = (torch.randn(10, 10), torch.randn(10, 10))

exported_program: torch.export.ExportedProgram = export(
    Mod(), args=example_args
)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[10, 10], arg1_1: f32[10, 10]):
            # code: a = torch.sin(x)
            sin: f32[10, 10] = torch.ops.aten.sin.default(arg0_1);

            # code: b = torch.cos(y)
            cos: f32[10, 10] = torch.ops.aten.cos.default(arg1_1);

            # code: return a + b
            add: f32[10, 10] = torch.ops.aten.add.Tensor(sin, cos);
            return (add,)

    Graph signature: ExportGraphSignature(
        parameters=[],
        buffers=[],
        user_inputs=['arg0_1', 'arg1_1'],
        user_outputs=['add'],
        inputs_to_parameters={},
        inputs_to_buffers={},
        buffers_to_mutate={},
        backward_signature=None,
        assertion_dep_token=None,
    )
    Range constraints: {}

torch.export 產生一個具有以下不變量的乾淨中間表示 (IR)。有關 IR 的更多規格可以在 這裡 找到。

  • 健全性:它保證是原始程式的一個健全表示,並保持與原始程式相同的呼叫慣例。

  • 標準化:圖中沒有 Python 語義。來自原始程式的子模組被內聯以形成一個完全扁平化的計算圖。

  • 圖屬性:圖是純函數式的,這意味著它不包含具有副作用的操作,例如突變或別名。它不會改變任何中間值、參數或緩衝區。

  • 中繼資料:圖包含在追蹤期間捕獲的中繼資料,例如來自使用者程式碼的堆疊追蹤。

在底層,torch.export 利用了以下最新技術

  • TorchDynamo (torch._dynamo) 是一個內部 API,它使用一個稱為框架評估 API 的 CPython 功能來安全地追蹤 PyTorch 圖。這提供了一個顯著改進的圖捕獲體驗,需要更少的重寫即可完全追蹤 PyTorch 程式碼。

  • AOT Autograd 提供了一個函數化的 PyTorch 圖,並確保圖被分解/降低到 ATen 運算子集。

  • Torch FX (torch.fx) 是圖的底層表示,允許靈活的基於 Python 的轉換。

現有框架

torch.compile() 也使用與 torch.export 相同的 PT2 堆疊,但略有不同

  • JIT 與 AOTtorch.compile() 是一個 JIT 編譯器,而它並非用於在部署之外產生已編譯的構件。

  • 部分與完整圖捕獲:當 torch.compile() 遇到模型中無法追蹤的部分時,它將會「圖中斷」,並退回到在渴望的 Python 執行期中執行程式。相比之下,torch.export 旨在獲得 PyTorch 模型的完整圖表示,因此當遇到無法追蹤的內容時,它將會出錯。由於 torch.export 產生了一個與任何 Python 功能或執行期分離的完整圖,因此該圖可以在不同的環境和語言中儲存、載入和執行。

  • 可用性取捨:由於 torch.compile() 可以在遇到任何無法追蹤的內容時退回到 Python 執行期,因此它更加靈活。torch.export 則需要使用者提供更多資訊或重寫他們的程式碼以使其可追蹤。

torch.fx.symbolic_trace() 相比,torch.export 使用在 Python 位元組碼層級運作的 TorchDynamo 進行追蹤,使其能夠追蹤不受 Python 運算子多載支援限制的任意 Python 構造。此外,torch.export 會精確追蹤張量中繼資料,因此基於張量形狀等條件的條件語句不會追蹤失敗。一般來說,torch.export 預計可以在更多使用者程式上運作,並產生更低層級的圖(在 torch.ops.aten 運算子層級)。請注意,使用者仍然可以在 torch.export 之前使用 torch.fx.symbolic_trace() 作為預處理步驟。

torch.jit.script() 相比,torch.export 不會捕捉 Python 控制流程或資料結構,但它支援比 TorchScript 更多的 Python 語言特性(因為它更容易全面涵蓋 Python 位元組碼)。 生成的圖形更簡單,並且只有直線控制流程(除了顯式的控制流程運算子)。

torch.jit.trace() 相比,torch.export 是可靠的:它能夠追蹤對大小執行整數計算的程式碼,並記錄顯示特定追蹤對其他輸入有效的必要邊界條件。

匯出 PyTorch 模型

範例

主要的進入點是透過 torch.export.export(),它接受一個可呼叫物件(torch.nn.Module、函數或方法)和範例輸入,並將計算圖捕捉到 torch.export.ExportedProgram 中。 範例

import torch
from torch.export import export

# Simple module for demonstration
class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=3, out_channels=16, kernel_size=3, padding=1
        )
        self.relu = torch.nn.ReLU()
        self.maxpool = torch.nn.MaxPool2d(kernel_size=3)

    def forward(self, x: torch.Tensor, *, constant=None) -> torch.Tensor:
        a = self.conv(x)
        a.add_(constant)
        return self.maxpool(self.relu(a))

example_args = (torch.randn(1, 3, 256, 256),)
example_kwargs = {"constant": torch.ones(1, 16, 256, 256)}

exported_program: torch.export.ExportedProgram = export(
    M(), args=example_args, kwargs=example_kwargs
)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[16, 3, 3, 3], arg1_1: f32[16], arg2_1: f32[1, 3, 256, 256], arg3_1: f32[1, 16, 256, 256]):

            # code: a = self.conv(x)
            convolution: f32[1, 16, 256, 256] = torch.ops.aten.convolution.default(
                arg2_1, arg0_1, arg1_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1
            );

            # code: a.add_(constant)
            add: f32[1, 16, 256, 256] = torch.ops.aten.add.Tensor(convolution, arg3_1);

            # code: return self.maxpool(self.relu(a))
            relu: f32[1, 16, 256, 256] = torch.ops.aten.relu.default(add);
            max_pool2d_with_indices = torch.ops.aten.max_pool2d_with_indices.default(
                relu, [3, 3], [3, 3]
            );
            getitem: f32[1, 16, 85, 85] = max_pool2d_with_indices[0];
            return (getitem,)

    Graph signature: ExportGraphSignature(
        parameters=['L__self___conv.weight', 'L__self___conv.bias'],
        buffers=[],
        user_inputs=['arg2_1', 'arg3_1'],
        user_outputs=['getitem'],
        inputs_to_parameters={
            'arg0_1': 'L__self___conv.weight',
            'arg1_1': 'L__self___conv.bias',
        },
        inputs_to_buffers={},
        buffers_to_mutate={},
        backward_signature=None,
        assertion_dep_token=None,
    )
    Range constraints: {}

檢查 ExportedProgram,我們可以注意到以下幾點

  • torch.fx.Graph 包含原始程式的計算圖,以及原始程式碼的記錄,以便於除錯。

  • 該圖僅包含 這裡 找到的 torch.ops.aten 運算子和自訂運算子,並且功能齊全,沒有任何原地運算子,例如 torch.add_

  • 參數(卷積的權重和偏差)被提升為圖的輸入,導致圖中沒有 get_attr 節點,這些節點先前存在於 torch.fx.symbolic_trace() 的結果中。

  • torch.export.ExportGraphSignature 為輸入和輸出簽章建模,並指定哪些輸入是參數。

  • 圖中每個節點產生的張量的結果形狀和 dtype 都會被記錄下來。 例如,convolution 節點將產生一個 dtype 為 torch.float32 且形狀為 (1, 16, 256, 256) 的張量。

非嚴格匯出

在 PyTorch 2.3 中,我們引入了一種稱為**非嚴格模式**的新追蹤模式。 它仍在強化中,因此如果您遇到任何問題,請使用「oncall: export」標籤將其提交到 Github。

在*非嚴格模式*下,我們使用 Python 解譯器追蹤程式。 您的程式碼將完全按照在 Eager 模式下的執行方式執行;唯一的區別是所有 Tensor 物件都將被 ProxyTensors 取代,ProxyTensors 會將其所有操作記錄到圖形中。

在*嚴格*模式(目前為預設模式)下,我們首先使用位元組碼分析引擎 TorchDynamo 追蹤程式。 TorchDynamo 實際上並未執行您的 Python 程式碼。 相反,它會對其進行符號分析,並根據結果構建圖形。 此分析允許 torch.export 對安全性提供更強的保證,但並非所有 Python 程式碼都受支援。

例如,如果您遇到可能不容易解決的未支援 TorchDynamo 功能,並且您知道計算並非完全需要 Python 程式碼,則可能需要使用非嚴格模式。 例如

import contextlib
import torch

class ContextManager():
    def __init__(self):
        self.count = 0
    def __enter__(self):
        self.count += 1
    def __exit__(self, exc_type, exc_value, traceback):
        self.count -= 1

class M(torch.nn.Module):
    def forward(self, x):
        with ContextManager():
            return x.sin() + x.cos()

export(M(), (torch.ones(3, 3),), strict=False)  # Non-strict traces successfully
export(M(), (torch.ones(3, 3),))  # Strict mode fails with torch._dynamo.exc.Unsupported: ContextManager

在此範例中,第一次使用非嚴格模式的呼叫(透過 strict=False 旗標)成功追蹤,而第二次使用嚴格模式的呼叫(預設)導致失敗,其中 TorchDynamo 無法支援上下文管理器。 一種選擇是重寫程式碼(請參閱 torch.export 的限制),但由於上下文管理器不會影響模型中的張量計算,因此我們可以使用非嚴格模式的結果。

表達動態性

根據預設,torch.export 將追蹤程式,假設所有輸入形狀都是**靜態**的,並將匯出的程式專用於這些維度。 然而,某些維度(例如批次維度)可以是動態的,並且在每次執行時都會發生變化。 必須使用 torch.export.Dim() API 建立這些維度,並透過 dynamic_shapes 參數將其傳遞到 torch.export.export() 中來指定這些維度。 範例

import torch
from torch.export import Dim, export

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.branch1 = torch.nn.Sequential(
            torch.nn.Linear(64, 32), torch.nn.ReLU()
        )
        self.branch2 = torch.nn.Sequential(
            torch.nn.Linear(128, 64), torch.nn.ReLU()
        )
        self.buffer = torch.ones(32)

    def forward(self, x1, x2):
        out1 = self.branch1(x1)
        out2 = self.branch2(x2)
        return (out1 + self.buffer, out2)

example_args = (torch.randn(32, 64), torch.randn(32, 128))

# Create a dynamic batch size
batch = Dim("batch")
# Specify that the first dimension of each input is that batch size
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}

exported_program: torch.export.ExportedProgram = export(
    M(), args=example_args, dynamic_shapes=dynamic_shapes
)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[32, 64], arg1_1: f32[32], arg2_1: f32[64, 128], arg3_1: f32[64], arg4_1: f32[32], arg5_1: f32[s0, 64], arg6_1: f32[s0, 128]):

            # code: out1 = self.branch1(x1)
            permute: f32[64, 32] = torch.ops.aten.permute.default(arg0_1, [1, 0]);
            addmm: f32[s0, 32] = torch.ops.aten.addmm.default(arg1_1, arg5_1, permute);
            relu: f32[s0, 32] = torch.ops.aten.relu.default(addmm);

            # code: out2 = self.branch2(x2)
            permute_1: f32[128, 64] = torch.ops.aten.permute.default(arg2_1, [1, 0]);
            addmm_1: f32[s0, 64] = torch.ops.aten.addmm.default(arg3_1, arg6_1, permute_1);
            relu_1: f32[s0, 64] = torch.ops.aten.relu.default(addmm_1);  addmm_1 = None

            # code: return (out1 + self.buffer, out2)
            add: f32[s0, 32] = torch.ops.aten.add.Tensor(relu, arg4_1);
            return (add, relu_1)

    Graph signature: ExportGraphSignature(
        parameters=[
            'branch1.0.weight',
            'branch1.0.bias',
            'branch2.0.weight',
            'branch2.0.bias',
        ],
        buffers=['L__self___buffer'],
        user_inputs=['arg5_1', 'arg6_1'],
        user_outputs=['add', 'relu_1'],
        inputs_to_parameters={
            'arg0_1': 'branch1.0.weight',
            'arg1_1': 'branch1.0.bias',
            'arg2_1': 'branch2.0.weight',
            'arg3_1': 'branch2.0.bias',
        },
        inputs_to_buffers={'arg4_1': 'L__self___buffer'},
        buffers_to_mutate={},
        backward_signature=None,
        assertion_dep_token=None,
    )
    Range constraints: {s0: RangeConstraint(min_val=2, max_val=9223372036854775806)}

還有一些需要注意的事項

  • 透過 torch.export.Dim() API 和 dynamic_shapes 參數,我們將每個輸入的第一個維度指定為動態的。 查看輸入 arg5_1arg6_1,它們的符號形狀為 (s0, 64) 和 (s0, 128),而不是我們作為範例輸入傳遞的 (32, 64) 和 (32, 128) 形狀的張量。 s0 是一個符號,表示此維度可以是一系列值。

  • exported_program.range_constraints 描述圖中出現的每個符號的範圍。 在這種情況下,我們看到 s0 的範圍為 [2, inf]。 由於技術原因,這裡難以解釋,它們被假定為不是 0 或 1。 這不是錯誤,也不一定意味著匯出的程式不適用於維度 0 或 1。 有關此主題的深入討論,請參閱 0/1 特殊化問題

我們還可以指定輸入形狀之間更具表現力的關係,例如一對形狀可能相差 1、一個形狀可能是另一個形狀的兩倍,或者一個形狀是偶數。 範例

class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y[1:]

x, y = torch.randn(5), torch.randn(6)
dimx = torch.export.Dim("dimx", min=3, max=6)
dimy = dimx + 1

exported_program = torch.export.export(
    M(), (x, y), dynamic_shapes=({0: dimx}, {0: dimy}),
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
    def forward(self, arg0_1: "f32[s0]", arg1_1: "f32[s0 + 1]"):
        # code: return x + y[1:]
        slice_1: "f32[s0]" = torch.ops.aten.slice.Tensor(arg1_1, 0, 1, 9223372036854775807);  arg1_1 = None
        add: "f32[s0]" = torch.ops.aten.add.Tensor(arg0_1, slice_1);  arg0_1 = slice_1 = None
        return (add,)

Graph signature: ExportGraphSignature(
    input_specs=[
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None),
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)
    ],
    output_specs=[
        OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)]
)
Range constraints: {s0: ValueRanges(lower=3, upper=6, is_bool=False), s0 + 1: ValueRanges(lower=4, upper=7, is_bool=False)}

需要注意的一些事項

  • 透過為第一個輸入指定 {0: dimx},我們可以看到第一個輸入的結果形狀現在是動態的,即 [s0]。 現在,透過為第二個輸入指定 {0: dimy},我們可以看到第二個輸入的結果形狀也是動態的。 然而,由於我們表達了 dimy = dimx + 1,而不是 arg1_1 的形狀包含一個新符號,我們看到它現在使用與 arg0_1 中使用的相同符號 s0 來表示。 我們可以看到 dimy = dimx + 1 的關係是透過 s0 + 1 來顯示的。

  • 查看範圍約束,我們看到 s0 的範圍為 [3, 6],這是最初指定的,我們可以看到 s0 + 1 的求解範圍為 [4, 7]。

序列化

要儲存 ExportedProgram,使用者可以使用 torch.export.save()torch.export.load() API。 一種慣例是使用 .pt2 檔案副檔名儲存 ExportedProgram

範例

import torch
import io

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10

exported_program = torch.export.export(MyModule(), torch.randn(5))

torch.export.save(exported_program, 'exported_program.pt2')
saved_exported_program = torch.export.load('exported_program.pt2')

特殊化

理解 torch.export 行為的一個關鍵概念是*靜態*值和*動態*值之間的區別。

*動態*值是指在每次執行時都可能發生變化的值。 這些值的行為類似於 Python 函數的普通參數 — 您可以為一個參數傳遞不同的值,並期望您的函數能夠正常工作。 張量*資料*被視為動態的。

*靜態*值是指在匯出時固定且在匯出的程式執行之間無法更改的值。 當在追蹤期間遇到該值時,匯出器會將其視為常數並將其硬編碼到圖形中。

當執行運算時(例如 x + y)並且所有輸入都是靜態的時,則運算的輸出將直接硬編碼到圖形中,並且運算將不會顯示(即,它將被常數摺疊)。

當一個值被硬編碼到圖形中時,我們說該圖形已*專用*於該值。

以下值是靜態的

輸入張量形狀

預設情況下,torch.export 會追蹤程式,並專注於輸入張量的形狀,除非透過 torch.exportdynamic_shapes 參數將維度指定為動態。這表示如果存在與形狀相關的控制流程,torch.export 將專注於使用給定範例輸入所採用的分支。例如

import torch
from torch.export import export

class Mod(torch.nn.Module):
    def forward(self, x):
        if x.shape[0] > 5:
            return x + 1
        else:
            return x - 1

example_inputs = (torch.rand(10, 2),)
exported_program = export(Mod(), example_inputs)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[10, 2]):
            add: f32[10, 2] = torch.ops.aten.add.Tensor(arg0_1, 1);
            return (add,)

(x.shape[0] > 5) 的條件不會出現在 ExportedProgram 中,因為範例輸入具有 (10, 2) 的靜態形狀。由於 torch.export 專注於輸入的靜態形狀,因此永遠不會到達 else 分支 (x - 1)。若要根據追蹤圖表中張量的形狀保留動態分支行為,需要使用 torch.export.dynamic_dim() 指定輸入張量 (x.shape[0]) 的維度為動態,並且需要 重寫 原始碼。

請注意,屬於模組狀態的張量(例如參數和緩衝區)始終具有靜態形狀。

Python 基元

torch.export 也專注於 Python 基元,例如 intfloatboolstr。然而,它們確實具有動態變體,例如 SymIntSymFloatSymBool

例如

import torch
from torch.export import export

class Mod(torch.nn.Module):
    def forward(self, x: torch.Tensor, const: int, times: int):
        for i in range(times):
            x = x + const
        return x

example_inputs = (torch.rand(2, 2), 1, 3)
exported_program = export(Mod(), example_inputs)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[2, 2], arg1_1, arg2_1):
            add: f32[2, 2] = torch.ops.aten.add.Tensor(arg0_1, 1);
            add_1: f32[2, 2] = torch.ops.aten.add.Tensor(add, 1);
            add_2: f32[2, 2] = torch.ops.aten.add.Tensor(add_1, 1);
            return (add_2,)

由於整數是專門化的,因此 torch.ops.aten.add.Tensor 運算都是使用硬編碼常數 1 計算的,而不是 arg1_1。如果使用者在執行時傳遞與匯出時使用的值不同的 arg1_1 值(例如 2,而不是 1),則會導致錯誤。此外,for 迴圈中使用的 times 迭代器也會透過 3 個重複的 torch.ops.aten.add.Tensor 呼叫「內嵌」在圖表中,並且永遠不會使用輸入 arg2_1

Python 容器

Python 容器(ListDictNamedTuple 等)被認為具有靜態結構。

torch.export 的限制

圖表斷裂

由於 torch.export 是從 PyTorch 程式擷取計算圖表的單次過程,它最終可能會遇到程式中無法追蹤的部分,因為幾乎不可能支援追蹤所有 PyTorch 和 Python 功能。在 torch.compile 的情況下,不支援的運算將導致「圖表斷裂」,並且不支援的運算將使用預設的 Python 評估來執行。相比之下,torch.export 將要求使用者提供額外資訊或重寫部分程式碼,使其可追蹤。由於追蹤基於在 Python 位元組碼級別進行評估的 TorchDynamo,因此與以前的追蹤框架相比,所需的重寫將顯著減少。

當遇到圖表斷裂時,ExportDB 是一個很好的資源,可用於了解支援和不支援的程式類型,以及重寫程式使其可追蹤的方法。

克服處理這些圖表斷裂的一個選擇是使用 非嚴格匯出

資料/形狀相關的控制流程

當未專門化形狀時,也可能在資料相關的控制流程(if x.shape[0] > 2)上遇到圖表斷裂,因為追蹤編譯器不可能在不為組合爆炸數量的路徑產生程式碼的情況下處理。在這種情況下,使用者需要使用特殊的控制流程運算子來重寫他們的程式碼。目前,我們支援 torch.cond 來表達類似 if-else 的控制流程(即將推出更多!)。

運算子缺少 Fake/Meta/Abstract 核心

追蹤時,所有運算子都需要 FakeTensor 核心(又稱元核心、抽象實作)。這用於推斷此運算子的輸入/輸出形狀。

如需更多詳細資訊,請參閱 torch.library.register_fake()

如果您不幸遇到模型使用尚未實作 FakeTensor 核心的 ATen 運算子的情況,請提交問題。

API 參考

torch.export.export(mod, args, kwargs=None, *, dynamic_shapes=None, strict=True, preserve_module_call_signature=())[source]

export() 採用任意 Python 可呼叫物件(nn.Module、函數或方法)以及範例輸入,並以預先 (AOT) 方式產生僅表示函數的張量計算的追蹤圖表,該圖表隨後可以使用不同的輸入執行或序列化。追蹤圖表 (1) 在函數式 ATen 運算子集(以及任何使用者指定的自訂運算子)中產生標準化運算子,(2) 已消除所有 Python 控制流程和資料結構(某些例外情況除外),以及 (3) 記錄顯示此標準化和控制流程消除對於未來輸入有效的形狀約束集。

健全性保證

追蹤時,export() 會注意到使用者程式和底層 PyTorch 運算子核心做出的與形狀相關的假設。僅當這些假設成立時,輸出 ExportedProgram 才被視為有效。

追蹤對輸入張量的形狀(而非值)做出假設。為了使 export() 成功,必須在圖表擷取時驗證這些假設。具體來說

  • 自動驗證對輸入張量的靜態形狀的假設,而無需額外工作。

  • 對輸入張量的動態形狀的假設需要透過使用 Dim() API 建構動態維度並透過 dynamic_shapes 參數將它們與範例輸入相關聯來明確指定。

如果無法驗證任何假設,則會引發致命錯誤。發生這種情況時,錯誤訊息將包含對驗證假設所需的規範的建議修復。例如,export() 可能會建議對動態維度 dim0_x 的定義進行以下修復,例如出現在與輸入 x 相關聯的形狀中,該定義先前定義為 Dim("dim0_x")

dim = Dim("dim0_x", max=5)

此範例表示產生的程式碼要求輸入 x 的維度 0 小於或等於 5 才能有效。您可以檢查對動態維度定義的建議修復,然後將它們逐字複製到您的程式碼中,而無需更改 export() 呼叫的 dynamic_shapes 參數。

參數
  • mod (Module) – 我們將追蹤此模組的 forward 方法。

  • args (Tuple[Any, ...]) – 範例位置輸入。

  • kwargs (Optional[Dict[str, Any]]) – 可選的範例關鍵字輸入。

  • dynamic_shapes (Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]]) –

    一個可選參數,其類型應為以下三種之一:1) 從 f 的參數名稱到其動態形狀規範的字典,2) 一個元組,按原始順序指定每個輸入的動態形狀規範。如果您要指定關鍵字參數的動態性,則需要按照原始函數簽章中定義的順序傳遞它們。

    張量參數的動態形狀可以指定為以下兩種方式之一:(1) 從動態維度索引到 Dim() 類型的字典,其中不需要在此字典中包含靜態維度索引,但當它們存在時,它們應該映射到 None;或者 (2) Dim() 類型或 None 的元組/列表,其中 Dim() 類型對應於動態維度,靜態維度由 None 表示。作為張量字典或元組/列表的參數,可以使用包含規範的映射或序列遞迴指定。

  • strict (bool) – 啟用時(預設值),匯出函數將通過 TorchDynamo 追蹤程式,這將確保結果圖的健全性。否則,匯出的程式將不會驗證圖中隱含的假設,並且可能會導致原始模型與匯出模型之間的行為差異。當用戶需要解決追蹤器中的錯誤,或者只是想逐步啟用模型中的安全性時,這非常有用。請注意,這不會影響結果 IR 規範的不同,並且無論在此處傳遞什麼值,模型都將以相同的方式序列化。警告:此選項為實驗性選項,使用風險自負。

回傳值

一個包含已追蹤可呼叫項目的 ExportedProgram

回傳類型

ExportedProgram

可接受的輸入/輸出類型

可接受的輸入(對於 argskwargs)和輸出類型包括

  • 基本類型,例如 torch.Tensorintfloatboolstr

  • 數據類別,但必須先呼叫 register_dataclass() 進行註冊。

  • 包含上述所有類型的 dictlisttuplenamedtupleOrderedDict 的(嵌套)數據結構。

torch.export.dynamic_shapes.dynamic_dim(t, index, debug_name=None)[原始碼]

警告

(此功能已棄用。請改用 Dim()。)

dynamic_dim() 構造一個 _Constraint 物件,用於描述張量 t 的維度 index 的動態性。_Constraint 物件應傳遞給 constraints 參數 export()

參數
  • t (torch.Tensor) – 具有動態維度大小的範例輸入張量

  • index (int) – 動態維度的索引

回傳值

一個描述形狀動態性的 _Constraint 物件。它可以傳遞給 export(),以便 export() 不會假設指定張量的大小是靜態的,即將其保持為符號大小的動態性,而不是根據範例追蹤輸入的大小進行特化。

具體來說,dynamic_dim() 可以用於表達以下類型的動態性。

  • 維度的大小是動態且無界的

    t0 = torch.rand(2, 3)
    t1 = torch.rand(3, 4)
    
    # First dimension of t0 can be dynamic size rather than always being static size 2
    constraints = [dynamic_dim(t0, 0)]
    ep = export(fn, (t0, t1), constraints=constraints)
    
  • 維度的大小是動態的,但有一個下限

    t0 = torch.rand(10, 3)
    t1 = torch.rand(3, 4)
    
    # First dimension of t0 can be dynamic size with a lower bound of 5 (inclusive)
    # Second dimension of t1 can be dynamic size with a lower bound of 2 (exclusive)
    constraints = [
        dynamic_dim(t0, 0) >= 5,
        dynamic_dim(t1, 1) > 2,
    ]
    ep = export(fn, (t0, t1), constraints=constraints)
    
  • 維度的大小是動態的,但有一個上限

    t0 = torch.rand(10, 3)
    t1 = torch.rand(3, 4)
    
    # First dimension of t0 can be dynamic size with a upper bound of 16 (inclusive)
    # Second dimension of t1 can be dynamic size with a upper bound of 8 (exclusive)
    constraints = [
        dynamic_dim(t0, 0) <= 16,
        dynamic_dim(t1, 1) < 8,
    ]
    ep = export(fn, (t0, t1), constraints=constraints)
    
  • 維度的大小是動態的,並且始終等於另一個動態維度的大小

    t0 = torch.rand(10, 3)
    t1 = torch.rand(3, 4)
    
    # Sizes of second dimension of t0 and first dimension are always equal
    constraints = [
        dynamic_dim(t0, 1) == dynamic_dim(t1, 0),
    ]
    ep = export(fn, (t0, t1), constraints=constraints)
    
  • 混合搭配以上所有類型,只要它們不表示衝突的需求

torch.export.save(ep, f, *, extra_files=None, opset_version=None)[原始碼]

警告

正在積極開發中,儲存的檔案可能無法在較新版本的 PyTorch 中使用。

ExportedProgram 儲存到檔案物件。然後可以使用 Python API torch.export.load 載入它。

參數
  • ep (ExportedProgram) – 要儲存的匯出程式。

  • f (Union[str, os.PathLike, io.BytesIO) – 一個檔案物件(必須實現寫入和清空)或一個包含檔案名的字串。

  • extra_files (Optional[Dict[str, Any]]) – 從檔案名到內容的映射,這些內容將作為 f 的一部分儲存。

  • opset_version (Optional[Dict[str, int]]) – 操作集名稱到此操作集版本的映射

範例

import torch
import io

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10

ep = torch.export.export(MyModule(), (torch.randn(5),))

# Save to file
torch.export.save(ep, 'exported_program.pt2')

# Save to io.BytesIO buffer
buffer = io.BytesIO()
torch.export.save(ep, buffer)

# Save with extra files
extra_files = {'foo.txt': b'bar'.decode('utf-8')}
torch.export.save(ep, 'exported_program.pt2', extra_files=extra_files)
torch.export.load(f, *, extra_files=None, expected_opset_version=None)[原始碼]

警告

正在積極開發中,儲存的檔案可能無法在較新版本的 PyTorch 中使用。

載入先前使用 torch.export.save 儲存的 ExportedProgram

參數
  • ep (ExportedProgram) – 要儲存的匯出程式。

  • f (Union[str, os.PathLike, io.BytesIO) – 一個檔案物件(必須實現寫入和清空)或一個包含檔案名的字串。

  • extra_files (Optional[Dict[str, Any]]) – 此映射中提供的額外檔案名將被載入,並且它們的內容將儲存在提供的映射中。

  • expected_opset_version (Optional[Dict[str, int]]) – 操作集名稱到預期操作集版本的映射

回傳值

一個 ExportedProgram 物件

回傳類型

ExportedProgram

範例

import torch
import io

# Load ExportedProgram from file
ep = torch.export.load('exported_program.pt2')

# Load ExportedProgram from io.BytesIO object
with open('exported_program.pt2', 'rb') as f:
    buffer = io.BytesIO(f.read())
buffer.seek(0)
ep = torch.export.load(buffer)

# Load with extra files.
extra_files = {'foo.txt': ''}  # values will be replaced with data
ep = torch.export.load('exported_program.pt2', extra_files=extra_files)
print(extra_files['foo.txt'])
print(ep(torch.randn(5)))
torch.export.register_dataclass(cls, *, serialized_type_name=None)[原始碼]

註冊一個數據類別作為 torch.export.export() 的有效輸入/輸出類型。

參數
  • cls (Type[Any]) – 要註冊的數據類別類型

  • serialized_type_name (Optional[str]) – 數據類別的序列化名稱。這是

  • 這個 (如果要序列化包含以下內容的 pytree TreeSpec,則為必需項) –

  • dataclass。

範例

@dataclass
class InputDataClass:
    feature: torch.Tensor
    bias: int

class OutputDataClass:
    res: torch.Tensor

torch.export.register_dataclass(InputDataClass)
torch.export.register_dataclass(OutputDataClass)

def fn(o: InputDataClass) -> torch.Tensor:
    res = res=o.feature + o.bias
    return OutputDataClass(res=res)

ep = torch.export.export(fn, (InputDataClass(torch.ones(2, 2), 1), ))
print(ep)
torch.export.dynamic_shapes.Dim(name, *, min=None, max=None)[原始碼]

Dim() 建構一個類似於具有範圍的具名符號整數的類型。它可以用來描述動態張量維度的多個可能值。請注意,相同張量或不同張量的不同動態維度可以使用相同的類型來描述。

參數
  • name (str) – 用於除錯的可讀名稱。

  • min (Optional[int]) – 給定符號的最小可能值(含)。

  • max (Optional[int]) – 給定符號的最大可能值(含)。

回傳值

一種可用於張量的動態形狀規範的類型。

torch.export.dims(*names, min=None, max=None)[原始碼]

用於建立多個 Dim() 類型的工具。

class torch.export.dynamic_shapes.ShapesCollection[原始碼]

動態形狀的建構器。用於將動態形狀規範分配給出現在輸入中的張量。

範例:

args = ({“x”: tensor_x, “others”: [tensor_y, tensor_z]})

dim = torch.export.Dim(…) dynamic_shapes = torch.export.ShapesCollection() dynamic_shapes[tensor_x] = (dim, dim + 1, 8) dynamic_shapes[tensor_y] = {0: dim * 2} # 這等同於以下內容(現在自動產生): # dynamic_shapes = {“x”: (dim, dim + 1, 8), “others”: [{0: dim * 2}, None]}

torch.export(…, args, dynamic_shapes=dynamic_shapes)

dynamic_shapes(m, args, kwargs=None)[原始碼]

產生動態形狀。

torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes(msg, dynamic_shapes)[原始碼]

用於處理導出的動態形狀建議修復,和/或自動動態形狀。根據 ConstraintViolation 錯誤訊息和原始動態形狀,改進給定的動態形狀規範。

對於大多數情況,行為都很簡單,例如針對專門化或改進 Dim 範圍的建議修復,或針對建議衍生關係的修復,新的動態形狀規範將會照此更新。

例如,建議修復

dim = Dim(‘dim’, min=3, max=6) -> 這只會改進 dim 的範圍 dim = 4 -> 這會專門化為常數 dy = dx + 1 -> dy 被指定為一個獨立的 dim,但實際上透過此關係與 dx 綁定

然而,與衍生 dim 相關的建議修復可能會更複雜。例如,如果為根 dim 提供了建議修復,則新的衍生 dim 值將根據根進行評估。

例如,dx = Dim(‘dx’) dy = dx + 2 dynamic_shapes = {“x”: (dx,), “y”: (dy,)}

建議修復

dx = 4 # 專門化將導致 dy 也專門化為 6 dx = Dim(‘dx’, max=6) # dy 現在的最大值為 8

衍生 dim 建議修復也可以用於表示可除性約束。這涉及建立未與特定輸入形狀綁定的新根 dim。在這種情況下,根 dim 不會直接出現在新的規範中,而是作為其中一個 dim 的根。

例如,建議修復

_dx = Dim(‘_dx’, max=1024) # 這不會出現在回傳結果中,但 dx 會出現 dx = 4*_dx # dx 現在可被 4 整除,最大值為 4096

回傳類型

Union[Dict[str, Any], Tuple[Any], List[Any]]

torch.export.Constraint

Union[_Constraint, _DerivedConstraint] 的別名

class torch.export.ExportedProgram(root, graph, graph_signature, state_dict, range_constraints, module_call_graph, example_inputs=None, verifier=None, tensor_constants=None, constants=None)[原始碼]

來自 export() 的程式套件。它包含一個表示張量計算的 torch.fx.Graph,一個包含所有提升參數和緩衝區的張量值的 state_dict,以及各種中繼資料。

您可以像使用相同的呼叫約定追蹤 export() 的原始可呼叫物件一樣呼叫 ExportedProgram。

若要對圖形執行轉換,請使用 .module 屬性來存取 torch.fx.GraphModule。然後,您可以使用 FX 轉換 來重寫圖形。之後,您可以簡單地再次使用 export() 來建構正確的 ExportedProgram。

module()[原始碼]

回傳一個包含所有內嵌參數/緩衝區的自包含 GraphModule。

回傳類型

模組

buffers()[原始碼]

回傳一個迭代器,用於迭代原始模組緩衝區。

警告

此 API 為實驗性質,不保證向後相容。

回傳類型

Iterator[Tensor]

named_buffers()[原始碼]

回傳一個迭代器,用於迭代原始模組緩衝區,並產生緩衝區的名稱和緩衝區本身。

警告

此 API 為實驗性質,不保證向後相容。

回傳類型

Iterator[Tuple[str, Tensor]]

parameters()[原始碼]

回傳一個迭代器,用於迭代原始模組的參數。

警告

此 API 為實驗性質,不保證向後相容。

回傳類型

Iterator[Parameter]

named_parameters()[原始碼]

傳回一個迭代器,用於迭代原始模組參數,同時產生參數的名稱和參數本身。

警告

此 API 為實驗性質,不保證向後相容。

回傳類型

迭代器[元組[字串, 參數]]

run_decompositions(decomp_table=None)[原始碼]

對匯出的程式執行一組分解,並傳回一個新的匯出程式。預設情況下,我們將執行核心 ATen 分解,以在 核心 ATen 運算子集 中取得運算子。

目前,我們不分解聯合圖。

回傳類型

ExportedProgram

類別 torch.export.ExportBackwardSignature(gradients_to_parameters: Dict[字串, 字串], gradients_to_user_inputs: Dict[字串, 字串], loss_output: 字串)[原始碼]
類別 torch.export.ExportGraphSignature(input_specs, output_specs)[原始碼]

ExportGraphSignature 模擬匯出圖的輸入/輸出簽章,這是一個具有更強不變量保證的 fx.Graph。

匯出圖是函數式的,不會透過 getattr 節點存取圖中的「狀態」,例如參數或緩衝區。相反地,export() 保證參數、緩衝區和常數張量會作為輸入從圖中提取出來。同樣地,對緩衝區的任何變更也不會包含在圖中,而是將變更後的緩衝區的更新值模擬為匯出圖的額外輸出。

所有輸入和輸出的順序為

Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs]
Outputs = [*mutated_inputs, *flattened_user_outputs]

例如,如果匯出以下模組

class CustomModule(nn.Module):
    def __init__(self):
        super(CustomModule, self).__init__()

        # Define a parameter
        self.my_parameter = nn.Parameter(torch.tensor(2.0))

        # Define two buffers
        self.register_buffer('my_buffer1', torch.tensor(3.0))
        self.register_buffer('my_buffer2', torch.tensor(4.0))

    def forward(self, x1, x2):
        # Use the parameter, buffers, and both inputs in the forward method
        output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2

        # Mutate one of the buffers (e.g., increment it by 1)
        self.my_buffer2.add_(1.0) # In-place addition

        return output

產生的圖將是

graph():
    %arg0_1 := placeholder[target=arg0_1]
    %arg1_1 := placeholder[target=arg1_1]
    %arg2_1 := placeholder[target=arg2_1]
    %arg3_1 := placeholder[target=arg3_1]
    %arg4_1 := placeholder[target=arg4_1]
    %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {})
    %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {})
    %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {})
    %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {})
    %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {})
    return (add_tensor_2, add_tensor_1)

產生的 ExportGraphSignature 將是

ExportGraphSignature(
    input_specs=[
        InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'),
        InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'),
        InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'),
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None),
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None)
    ],
    output_specs=[
        OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'),
        OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None)
    ]
)
類別 torch.export.ModuleCallSignature(inputs: List[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument]], outputs: List[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument]], in_spec: torch.utils._pytree.TreeSpec, out_spec: torch.utils._pytree.TreeSpec)[原始碼]
類別 torch.export.ModuleCallEntry(fqn: 字串, signature: Union[torch.export.exported_program.ModuleCallSignature, NoneType] = None)[原始碼]
類別 torch.export.graph_signature.InputKind(value)[原始碼]

列舉。

類別 torch.export.graph_signature.InputSpec(kind: torch.export.graph_signature.InputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument], target: Union[字串, NoneType], persistent: Union[布林值, NoneType] = None)[原始碼]
類別 torch.export.graph_signature.OutputKind(value)[原始碼]

列舉。

類別 torch.export.graph_signature.OutputSpec(kind: torch.export.graph_signature.OutputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument], target: Union[字串, NoneType])[原始碼]
class torch.export.graph_signature.ExportGraphSignature(input_specs, output_specs)[source]

ExportGraphSignature 為匯出圖(Export Graph)的輸入/輸出簽章建模,匯出圖是一個具有更強不變量保證的 fx.Graph。

匯出圖是函數式的,不會透過 getattr 節點存取圖形中的「狀態」,例如參數或緩衝區。相反地,export() 保證參數、緩衝區和常數張量會從圖形中提取出來作為輸入。同樣地,對緩衝區的任何變更也不會包含在圖形中,而是將變更後緩衝區的更新值建模為匯出圖的額外輸出。

所有輸入和輸出的順序為

Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs]
Outputs = [*mutated_inputs, *flattened_user_outputs]

例如,如果匯出以下模組

class CustomModule(nn.Module):
    def __init__(self):
        super(CustomModule, self).__init__()

        # Define a parameter
        self.my_parameter = nn.Parameter(torch.tensor(2.0))

        # Define two buffers
        self.register_buffer('my_buffer1', torch.tensor(3.0))
        self.register_buffer('my_buffer2', torch.tensor(4.0))

    def forward(self, x1, x2):
        # Use the parameter, buffers, and both inputs in the forward method
        output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2

        # Mutate one of the buffers (e.g., increment it by 1)
        self.my_buffer2.add_(1.0) # In-place addition

        return output

產生的圖將是

graph():
    %arg0_1 := placeholder[target=arg0_1]
    %arg1_1 := placeholder[target=arg1_1]
    %arg2_1 := placeholder[target=arg2_1]
    %arg3_1 := placeholder[target=arg3_1]
    %arg4_1 := placeholder[target=arg4_1]
    %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {})
    %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {})
    %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {})
    %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {})
    %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {})
    return (add_tensor_2, add_tensor_1)

產生的 ExportGraphSignature 將是

ExportGraphSignature(
    input_specs=[
        InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'),
        InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'),
        InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'),
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None),
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None)
    ],
    output_specs=[
        OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'),
        OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None)
    ]
)
replace_all_uses(old, new)[source]

將簽章中所有使用舊名稱的地方替換為新名稱。

get_replace_hook()[source]
class torch.export.graph_signature.CustomObjArgument(name: str, class_fqn: str)[source]
class torch.export.unflatten.FlatArgsAdapter[source]

使用 input_spec 調整輸入參數,使其與 target_spec 對齊。

abstract adapt(target_spec, input_spec, input_args)[source]

注意:此適配器可能會變更給定的 input_args_with_path

回傳類型

List[Any]

class torch.export.unflatten.InterpreterModule(graph)[source]

一個使用 torch.fx.Interpreter 執行,而不是使用 GraphModule 通常使用的程式碼生成的模組。這提供了更好的堆疊追蹤資訊,並使除錯執行更容易。

torch.export.unflatten.unflatten(module, flat_args_adapter=None)[source]

展開 ExportedProgram,產生一個與原始 eager 模組具有相同模組層次結構的模組。如果您嘗試將 torch.export 與另一個需要模組層次結構而不是 torch.export 通常產生的平面圖的系統一起使用時,這會很有用。

注意

展開的模組的 args/kwargs 不一定與 eager 模組相符,因此進行模組交換(例如,self.submod = new_mod)不一定會成功。如果您需要交換模組,則需要設定 torch.export.export()preserve_module_call_signature 參數。

參數
  • **module** (ExportedProgram) – 要展開的 ExportedProgram。

  • **flat_args_adapter** (Optional[FlatArgsAdapter]) – 如果輸入 TreeSpec 與匯出的模組不符,則調整平面參數。

回傳值

一個 UnflattenedModule 的實例,它與匯出前的原始 eager 模組具有相同的模組層次結構。

回傳類型

UnflattenedModule

文件

取得 PyTorch 的完整開發人員文件

檢視文件

教學課程

取得適用於初學者和進階開發人員的深入教學課程

檢視教學課程

資源

尋找開發資源並獲得問題解答

檢視資源