torch.export¶
警告
此特性是一個正在積極開發中的原型,將來*會*有重大變更。
概述¶
torch.export.export() 接受一個 torch.nn.Module 並生成一個代表函式中張量計算的跟蹤圖(Ahead-of-Time,AOT 方式),該跟蹤圖隨後可以以不同的輸出執行或序列化。
import torch
from torch.export import export
class Mod(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
a = torch.sin(x)
b = torch.cos(y)
return a + b
example_args = (torch.randn(10, 10), torch.randn(10, 10))
exported_program: torch.export.ExportedProgram = export(
Mod(), args=example_args
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[10, 10]", y: "f32[10, 10]"):
# code: a = torch.sin(x)
sin: "f32[10, 10]" = torch.ops.aten.sin.default(x)
# code: b = torch.cos(y)
cos: "f32[10, 10]" = torch.ops.aten.cos.default(y)
# code: return a + b
add: f32[10, 10] = torch.ops.aten.add.Tensor(sin, cos)
return (add,)
Graph signature:
ExportGraphSignature(
input_specs=[
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='x'),
target=None,
persistent=None
),
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='y'),
target=None,
persistent=None
)
],
output_specs=[
OutputSpec(
kind=<OutputKind.USER_OUTPUT: 1>,
arg=TensorArgument(name='add'),
target=None
)
]
)
Range constraints: {}
torch.export 生成一個具有以下不變數的清晰中間表示 (IR)。關於 IR 的更多規範可以在這裡找到。
穩健性:它保證是原始程式的穩健表示,並保持與原始程式相同的呼叫約定。
標準化:圖中沒有 Python 語義。原始程式中的子模組被內聯,形成一個完全扁平化的計算圖。
圖屬性:該圖是純函式式的,這意味著它不包含具有副作用的操作,例如變異或別名。它不會變異任何中間值、引數或緩衝區。
元資料:圖包含在跟蹤期間捕獲的元資料,例如來自使用者程式碼的堆疊跟蹤。
在底層,torch.export 利用了以下最新技術:
TorchDynamo (torch._dynamo) 是一個內部 API,它使用 CPython 的 Frame Evaluation API 來安全地跟蹤 PyTorch 圖。這大大改進了圖捕獲體驗,需要重寫的程式碼少得多才能完全跟蹤 PyTorch 程式碼。
AOT Autograd 提供一個函式化的 PyTorch 圖,並確保該圖被分解/降級到 ATen 運算子集。
Torch FX (torch.fx) 是圖的底層表示,允許靈活的基於 Python 的轉換。
現有框架¶
torch.compile() 也利用了與 torch.export 相同的 PT2 技術棧,但有一些不同:
JIT vs. AOT:
torch.compile()是一個 JIT(Just-In-Time,即時)編譯器,不旨在用於在部署之外生成編譯好的工件。部分 vs. 完全圖捕獲:當
torch.compile()遇到模型中無法跟蹤的部分時,它會“圖中斷”並回退到在 eager Python 執行時中執行程式。相比之下,torch.export的目標是獲得 PyTorch 模型的完整圖表示,因此當遇到無法跟蹤的內容時,它會出錯。由於torch.export生成的完整圖與任何 Python 特性或執行時分離,因此該圖可以被儲存、載入並在不同的環境和語言中執行。可用性權衡:由於
torch.compile()可以在遇到無法跟蹤的內容時回退到 Python 執行時,它更加靈活。torch.export則要求使用者提供更多資訊或重寫其程式碼以使其可跟蹤。
與 torch.fx.symbolic_trace() 相比,torch.export 使用 TorchDynamo 進行跟蹤,它在 Python 位元組碼層面操作,使其能夠跟蹤不受 Python 運算子過載支援限制的任意 Python 結構。此外,torch.export 精確跟蹤張量元資料,因此基於張量形狀等條件的控制流不會導致跟蹤失敗。通常情況下,torch.export 預計可以在更多使用者程式上工作,並生成更底層的圖(在 torch.ops.aten 運算子層面)。注意,使用者仍然可以將 torch.fx.symbolic_trace() 用作 torch.export 之前的預處理步驟。
與 torch.jit.script() 相比,torch.export 不捕獲 Python 控制流或資料結構,但它支援比 TorchScript 更多的 Python 語言特性(因為它更容易全面覆蓋 Python 位元組碼)。生成的圖更簡單,只有直線控制流(除了顯式控制流運算子)。
與 torch.jit.trace() 相比,torch.export 是穩健的:它能夠跟蹤對大小執行整數計算的程式碼,並記錄所有必要的副作用條件,以證明特定跟蹤對其他輸入有效。
匯出 PyTorch 模型¶
一個例子¶
主要入口點是 torch.export.export(),它接受一個可呼叫物件(torch.nn.Module、函式或方法)和示例輸入,並將計算圖捕獲到 torch.export.ExportedProgram 中。例如:
import torch
from torch.export import export
# Simple module for demonstration
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(
in_channels=3, out_channels=16, kernel_size=3, padding=1
)
self.relu = torch.nn.ReLU()
self.maxpool = torch.nn.MaxPool2d(kernel_size=3)
def forward(self, x: torch.Tensor, *, constant=None) -> torch.Tensor:
a = self.conv(x)
a.add_(constant)
return self.maxpool(self.relu(a))
example_args = (torch.randn(1, 3, 256, 256),)
example_kwargs = {"constant": torch.ones(1, 16, 256, 256)}
exported_program: torch.export.ExportedProgram = export(
M(), args=example_args, kwargs=example_kwargs
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[16, 3, 3, 3]", p_conv_bias: "f32[16]", x: "f32[1, 3, 256, 256]", constant: "f32[1, 16, 256, 256]"):
# code: a = self.conv(x)
conv2d: "f32[1, 16, 256, 256]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias, [1, 1], [1, 1])
# code: a.add_(constant)
add_: "f32[1, 16, 256, 256]" = torch.ops.aten.add_.Tensor(conv2d, constant)
# code: return self.maxpool(self.relu(a))
relu: "f32[1, 16, 256, 256]" = torch.ops.aten.relu.default(add_)
max_pool2d: "f32[1, 16, 85, 85]" = torch.ops.aten.max_pool2d.default(relu, [3, 3], [3, 3])
return (max_pool2d,)
Graph signature:
ExportGraphSignature(
input_specs=[
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_conv_weight'),
target='conv.weight',
persistent=None
),
InputSpec(
kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_conv_bias'),
target='conv.bias',
persistent=None
),
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='x'),
target=None,
persistent=None
),
InputSpec(
kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='constant'),
target=None,
persistent=None
)
],
output_specs=[
OutputSpec(
kind=<OutputKind.USER_OUTPUT: 1>,
arg=TensorArgument(name='max_pool2d'),
target=None
)
]
)
Range constraints: {}
檢查 ExportedProgram,我們可以注意到以下幾點:
torch.fx.Graph包含原始程式的計算圖,以及原始程式碼的記錄,便於除錯。圖中僅包含在此處找到的
torch.ops.aten運算子和自定義運算子,並且完全函式式,不包含任何原地運算子,例如torch.add_。引數(conv 的權重和偏差)被提升為圖的輸入,導致圖中沒有
get_attr節點,這在torch.fx.symbolic_trace()的結果中曾存在。torch.export.ExportGraphSignature建模了輸入和輸出簽名,並指定哪些輸入是引數。圖中每個節點生成的張量的最終形狀和資料型別都已註明。例如,
convolution節點將生成一個數據型別為torch.float32,形狀為 (1, 16, 256, 256) 的張量。
非嚴格匯出¶
在 PyTorch 2.3 中,我們引入了一種新的跟蹤模式,稱為非嚴格模式 (non-strict mode)。它仍在進行加固,因此如果您遇到任何問題,請將其提交到 Github 並帶有“oncall: export”標籤。
在*非嚴格模式*下,我們使用 Python 直譯器跟蹤程式。您的程式碼將完全按照 eager 模式下的方式執行;唯一的區別是所有 Tensor 物件將被 ProxyTensor 替換,後者將記錄所有操作到圖中。
在*嚴格模式*下(目前是預設設定),我們首先使用 TorchDynamo(一個位元組碼分析引擎)跟蹤程式。TorchDynamo 實際上不執行您的 Python 程式碼。相反,它對程式碼進行符號分析並根據結果構建圖。這種分析使得 torch.export 能夠提供更強的安全性保證,但並非所有 Python 程式碼都受支援。
一個可能需要使用非嚴格模式的例子是,如果您遇到了不易解決的 TorchDynamo 不支援的特性,並且您知道該 Python 程式碼並非計算所必需。例如:
import contextlib
import torch
class ContextManager():
def __init__(self):
self.count = 0
def __enter__(self):
self.count += 1
def __exit__(self, exc_type, exc_value, traceback):
self.count -= 1
class M(torch.nn.Module):
def forward(self, x):
with ContextManager():
return x.sin() + x.cos()
export(M(), (torch.ones(3, 3),), strict=False) # Non-strict traces successfully
export(M(), (torch.ones(3, 3),)) # Strict mode fails with torch._dynamo.exc.Unsupported: ContextManager
在此示例中,使用非嚴格模式(透過 strict=False 標誌)的首次呼叫成功跟蹤,而使用嚴格模式(預設)的第二次呼叫則失敗,因為 TorchDynamo 不支援上下文管理器。一種選擇是重寫程式碼(參見torch.export 的限制),但考慮到上下文管理器不影響模型中的張量計算,我們可以選擇非嚴格模式的結果。
用於訓練和推理的匯出¶
在 PyTorch 2.5 中,我們引入了一個名為 export_for_training() 的新 API。它仍在進行加固,因此如果您遇到任何問題,請將其提交到 Github 並帶有“oncall: export”標籤。
在此 API 中,我們生成最通用的 IR,其中包含所有 ATen 運算子(包括函式式和非函式式),可用於在 eager PyTorch Autograd 中進行訓練。此 API 旨在用於 eager 訓練用例,例如 PT2 量化,並且很快將成為 torch.export.export 的預設 IR。要進一步瞭解此更改背後的動機,請參閱 https://dev-discuss.pytorch.org/t/why-pytorch-does-not-need-a-new-standardized-operator-set/2206
當此 API 與 run_decompositions() 結合使用時,您應該能夠獲得具有任何期望分解行為的推理 IR。
以下是一些示例:
class ConvBatchnorm(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(1, 3, 1, 1)
self.bn = torch.nn.BatchNorm2d(3)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return (x,)
mod = ConvBatchnorm()
inp = torch.randn(1, 1, 3, 3)
ep_for_training = torch.export.export_for_training(mod, (inp,))
print(ep_for_training)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias)
add_: "i64[]" = torch.ops.aten.add_.Tensor(b_bn_num_batches_tracked, 1)
batch_norm: "f32[1, 3, 3, 3]" = torch.ops.aten.batch_norm.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05, True)
return (batch_norm,)
從上面的輸出可以看出,除了圖中的運算子之外,export_for_training() 生成的 ExportedProgram 與 export() 生成的基本相同。您可以看到我們以最通用的形式捕獲了 batch_norm。此運算子是非函式式的,在執行推理時將被降級為不同的運算子。
您還可以透過 run_decompositions() 並進行任意自定義,從此 IR 轉換為推理 IR。
# Lower to core aten inference IR, but keep conv2d
decomp_table = torch.export.default_decompositions()
del decomp_table[torch.ops.aten.conv2d.default]
ep_for_inference = ep_for_training.run_decompositions(decomp_table)
print(ep_for_inference)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias)
add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1)
_native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05)
getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]
return (getitem_3, getitem_4, add, getitem)
在這裡您可以看到,我們保留了 IR 中的 conv2d 運算子,同時分解了其餘部分。現在,該 IR 是一個函式式 IR,包含核心的 aten 運算子,除了 conv2d。
您還可以透過直接註冊您選擇的分解行為來實現更多自定義。
您還可以透過直接註冊自定義分解行為來實現更多自定義
# Lower to core aten inference IR, but customize conv2d
decomp_table = torch.export.default_decompositions()
def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1):
return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups)
decomp_table[torch.ops.aten.conv2d.default] = my_awesome_conv2d_function
ep_for_inference = ep_for_training.run_decompositions(decomp_table)
print(ep_for_inference)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1)
mul: "f32[1, 3, 3, 3]" = torch.ops.aten.mul.Tensor(convolution, 2)
add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1)
_native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(mul, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05)
getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4];
return (getitem_3, getitem_4, add, getitem)
表達動態性¶
預設情況下,torch.export 將假定所有輸入形狀都是靜態的,並針對這些維度特化匯出的程式。然而,某些維度,例如批處理維度,可以是動態的並且在每次執行時不同。這些維度必須透過使用 torch.export.Dim() API 建立它們,並透過 dynamic_shapes 引數將其傳遞給 torch.export.export() 來指定。例如:
import torch
from torch.export import Dim, export
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.branch1 = torch.nn.Sequential(
torch.nn.Linear(64, 32), torch.nn.ReLU()
)
self.branch2 = torch.nn.Sequential(
torch.nn.Linear(128, 64), torch.nn.ReLU()
)
self.buffer = torch.ones(32)
def forward(self, x1, x2):
out1 = self.branch1(x1)
out2 = self.branch2(x2)
return (out1 + self.buffer, out2)
example_args = (torch.randn(32, 64), torch.randn(32, 128))
# Create a dynamic batch size
batch = Dim("batch")
# Specify that the first dimension of each input is that batch size
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}
exported_program: torch.export.ExportedProgram = export(
M(), args=example_args, dynamic_shapes=dynamic_shapes
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_branch1_0_weight: "f32[32, 64]", p_branch1_0_bias: "f32[32]", p_branch2_0_weight: "f32[64, 128]", p_branch2_0_bias: "f32[64]", c_buffer: "f32[32]", x1: "f32[s0, 64]", x2: "f32[s0, 128]"):
# code: out1 = self.branch1(x1)
linear: "f32[s0, 32]" = torch.ops.aten.linear.default(x1, p_branch1_0_weight, p_branch1_0_bias)
relu: "f32[s0, 32]" = torch.ops.aten.relu.default(linear)
# code: out2 = self.branch2(x2)
linear_1: "f32[s0, 64]" = torch.ops.aten.linear.default(x2, p_branch2_0_weight, p_branch2_0_bias)
relu_1: "f32[s0, 64]" = torch.ops.aten.relu.default(linear_1)
# code: return (out1 + self.buffer, out2)
add: "f32[s0, 32]" = torch.ops.aten.add.Tensor(relu, c_buffer)
return (add, relu_1)
Range constraints: {s0: VR[0, int_oo]}
一些額外需要注意的地方:
透過
torch.export.Dim()API 和dynamic_shapes引數,我們將每個輸入的第一個維度指定為動態的。檢視輸入x1和x2,它們的符號形狀是 (s0, 64) 和 (s0, 128),而不是我們作為示例輸入傳遞的 (32, 64) 和 (32, 128) 形狀張量。s0是一個符號,表示此維度可以是一系列值。exported_program.range_constraints描述了圖中出現的每個符號的範圍。在本例中,我們看到s0的範圍是 [0, int_oo]。出於難以在此解釋的技術原因,它們被假定不為 0 或 1。這並非一個錯誤,也不一定意味著匯出的程式對於維度 0 或 1 將無法工作。有關此主題的深入討論,請參閱 The 0/1 Specialization Problem。
我們還可以指定輸入形狀之間更具表達性的關係,例如一對形狀可能相差一,一個形狀可能是另一個的兩倍,或者一個形狀是偶數。例如:
class M(torch.nn.Module):
def forward(self, x, y):
return x + y[1:]
x, y = torch.randn(5), torch.randn(6)
dimx = torch.export.Dim("dimx", min=3, max=6)
dimy = dimx + 1
exported_program = torch.export.export(
M(), (x, y), dynamic_shapes=({0: dimx}, {0: dimy}),
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[s0]", y: "f32[s0 + 1]"):
# code: return x + y[1:]
slice_1: "f32[s0]" = torch.ops.aten.slice.Tensor(y, 0, 1, 9223372036854775807)
add: "f32[s0]" = torch.ops.aten.add.Tensor(x, slice_1)
return (add,)
Range constraints: {s0: VR[3, 6], s0 + 1: VR[4, 7]}
一些需要注意的地方:
透過為第一個輸入指定
{0: dimx},我們看到第一個輸入的最終形狀現在是動態的,為[s0]。現在透過為第二個輸入指定{0: dimy},我們看到第二個輸入的最終形狀也是動態的。然而,因為我們表達了dimy = dimx + 1,所以y的形狀沒有包含新的符號,而是與x中使用的符號s0表示。我們可以看到dimy = dimx + 1的關係透過s0 + 1顯示。檢視範圍約束,我們看到
s0的範圍是 [3, 6],這是最初指定的;我們可以看到s0 + 1的求解範圍是 [4, 7]。
序列化¶
要儲存 ExportedProgram,使用者可以使用 torch.export.save() 和 torch.export.load() API。一個約定是使用 .pt2 副檔名儲存 ExportedProgram。
一個例子:
import torch
import io
class MyModule(torch.nn.Module):
def forward(self, x):
return x + 10
exported_program = torch.export.export(MyModule(), torch.randn(5))
torch.export.save(exported_program, 'exported_program.pt2')
saved_exported_program = torch.export.load('exported_program.pt2')
特化¶
理解 torch.export 行為的一個關鍵概念是靜態值和動態值之間的區別。
一個動態值是每次執行時可以改變的值。它們的行為類似於 Python 函式的普通引數——您可以為引數傳遞不同的值,並期望您的函式做正確的事情。張量*資料*被視為動態的。
一個靜態值是在匯出時固定的值,並且不能在匯出程式的執行之間改變。當在跟蹤期間遇到該值時,匯出器會將其視為常量並將其硬編碼到圖中。
當執行操作(例如 x + y)且所有輸入都是靜態的時,操作的輸出將直接硬編碼到圖中,並且操作本身將不會出現在圖中(即它將被常量摺疊)。
當一個值被硬編碼到圖中時,我們稱該圖已針對該值進行了*特化*。
以下值為靜態:
輸入張量形狀¶
預設情況下,torch.export 將跟蹤程式時特化輸入張量的形狀,除非透過 torch.export 的 dynamic_shapes 引數將某個維度指定為動態。這意味著如果存在依賴於形狀的控制流,torch.export 將特化到使用給定示例輸入時所採取的分支。例如:
import torch
from torch.export import export
class Mod(torch.nn.Module):
def forward(self, x):
if x.shape[0] > 5:
return x + 1
else:
return x - 1
example_inputs = (torch.rand(10, 2),)
exported_program = export(Mod(), example_inputs)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[10, 2]"):
# code: return x + 1
add: "f32[10, 2]" = torch.ops.aten.add.Tensor(x, 1)
return (add,)
條件表示式 (x.shape[0] > 5) 不會出現在 ExportedProgram 中,因為示例輸入的靜態形狀是 (10, 2)。由於 torch.export 會特化到輸入的靜態形狀,因此 else 分支 (x - 1) 將永遠不會被執行到。為了在跟蹤圖中保留基於張量形狀的動態分支行為,需要使用 torch.export.Dim() 指定輸入張量 (x.shape[0]) 的維度是動態的,並且需要重寫原始碼。
請注意,作為模組狀態一部分的張量(例如引數和緩衝區)總是具有靜態形狀。
Python 原語¶
torch.export 還會特化 Python 原語,例如 int, float, bool 和 str。然而,它們確實有動態變體,例如 SymInt, SymFloat 和 SymBool。
例如:
import torch
from torch.export import export
class Mod(torch.nn.Module):
def forward(self, x: torch.Tensor, const: int, times: int):
for i in range(times):
x = x + const
return x
example_inputs = (torch.rand(2, 2), 1, 3)
exported_program = export(Mod(), example_inputs)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[2, 2]", const, times):
# code: x = x + const
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(x, 1)
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 1)
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 1)
return (add_2,)
因為整數是特殊化的,torch.ops.aten.add.Tensor 操作都使用硬編碼常量 1 計算,而不是 const。如果在執行時,使用者為 const 傳入一個與匯出時使用的值 1 不同的值(例如 2),這將導致錯誤。此外,for 迴圈中使用的 times 迭代器也透過 3 次重複的 torch.ops.aten.add.Tensor 呼叫在圖中被“內聯”了,並且輸入 times 從未使用。
Python 容器¶
Python 容器(List、Dict、NamedTuple 等)被認為是具有靜態結構的。
`torch.export` 的限制¶
圖中斷¶
`torch.export` 是一種一次性捕獲 PyTorch 程式計算圖的過程,由於幾乎不可能支援跟蹤所有 PyTorch 和 Python 特性,它最終可能會遇到程式中無法跟蹤的部分。對於 torch.compile,不支援的操作會導致“圖中斷”,並且不支援的操作將使用預設的 Python 求值方式執行。相比之下,torch.export 將要求使用者提供額外資訊或重寫部分程式碼使其可跟蹤。由於跟蹤是基於在 Python 位元組碼級別進行求值的 TorchDynamo,與以前的跟蹤框架相比,所需的重寫將顯著減少。
遇到圖中斷時,ExportDB 是一個很好的資源,可以瞭解支援和不支援的程式型別,以及重寫程式使其可跟蹤的方法。
繞過這些圖中斷的一種選擇是使用 非嚴格匯出(non-strict export)
資料/形狀相關的控制流¶
當形狀未被特殊化時,在資料相關的控制流(如 if x.shape[0] > 2)上也可能遇到圖中斷,因為跟蹤編譯器不可能在不為組合爆炸式的路徑數量生成程式碼的情況下處理這種情況。在這種情況下,使用者需要使用特殊的控制流運算子重寫其程式碼。目前,我們支援 torch.cond 來表達類似 if-else 的控制流(更多內容即將推出!)。
運算子缺失 Fake/Meta/Abstract 核心¶
在跟蹤時,所有運算子都需要一個 FakeTensor 核心(也稱為 meta 核心、abstract impl)。這用於推斷此運算子的輸入/輸出形狀。
更多詳情請參見 torch.library.register_fake()。
不幸的是,如果您的模型使用了尚無 FakeTensor 核心實現 的 ATen 運算子,請提交一個 issue。
閱讀更多¶
PyTorch 開發者深入探討
API 參考¶
- torch.export.export(mod, args, kwargs=None, *, dynamic_shapes=None, strict=True, preserve_module_call_signature=())[source][source]¶
`export()` 接受任何 `nn.Module` 以及示例輸入,並以前向編譯 (AOT) 的方式生成一個表示函式中僅包含張量計算的跟蹤圖,該圖隨後可以用於不同的輸入進行執行或序列化。跟蹤圖 (1) 在功能性 ATen 運算子集合中生成規範化的運算子(以及任何使用者指定的自定義運算子),(2) 消除了所有 Python 控制流和資料結構(某些情況除外),並且 (3) 記錄了證明這種規範化和控制流消除對於未來的輸入是健全所需的一組形狀約束。
健全性保證
在跟蹤過程中,`export()` 會記錄使用者程式和底層 PyTorch 運算子核心做出的與形狀相關的假設。只有當這些假設成立時,輸出的
ExportedProgram才被認為是有效的。跟蹤對輸入張量的形狀(而非值)做出假設。為了使
export()成功,這些假設必須在圖捕獲時進行驗證。具體而言對輸入張量靜態形狀的假設無需額外工作即可自動驗證。
對輸入張量動態形狀的假設需要透過使用
Dim()API 構建動態維度,並透過dynamic_shapes引數將它們與示例輸入關聯起來進行顯式指定。
如果任何假設無法驗證,將引發致命錯誤。發生這種情況時,錯誤訊息將包含驗證假設所需的規範建議修復。例如,
export()可能會建議對動態維度dim0_x的定義進行以下修復,該維度可能出現在與輸入x關聯的形狀中,之前被定義為Dim("dim0_x")dim = Dim("dim0_x", max=5)
此示例意味著生成的程式碼要求輸入
x的維度 0 小於或等於 5 才能有效。您可以檢查動態維度定義的建議修復,然後將其逐字複製到您的程式碼中,而無需更改對export()的呼叫中的dynamic_shapes引數。- 引數
mod (Module) – 將跟蹤此模組的 forward 方法。
dynamic_shapes (Optional[Union[dict[str, Any], tuple[Any], list[Any]]]) –
一個可選引數,其型別應為:1) 一個字典,將
f的引數名對映到其動態形狀規範;2) 一個元組,按原始順序指定每個輸入的動態形狀規範。如果您正在為關鍵字引數指定動態性,則需要按照原始函式簽名中定義的順序傳遞它們。張量引數的動態形狀可以指定為 (1) 一個字典,將動態維度索引對映到
Dim()型別,此字典中無需包含靜態維度索引,但如果包含,則應對映到 None;或 (2) 一個Dim()型別或 None 的元組/列表,其中Dim()型別對應於動態維度,靜態維度由 None 表示。作為字典或張量元組/列表的引數,則透過使用包含規範的對映或序列進行遞迴指定。strict (bool) – 啟用時(預設),匯出函式將透過 TorchDynamo 跟蹤程式,這將確保生成圖的健全性。否則,匯出的程式將不會驗證圖中隱含的假設,可能導致原始模型與匯出模型之間行為不一致。當用戶需要解決跟蹤器中的錯誤,或者只是想逐步在其模型中啟用安全性時,此選項很有用。請注意,這不會影響生成的 IR 規範有所不同,無論此處傳遞什麼值,模型都將以相同的方式序列化。警告:此選項是實驗性的,請自行承擔風險使用。
- 返回值
包含跟蹤的可呼叫物件的
ExportedProgram。- 返回型別
可接受的輸入/輸出型別
可接受的輸入(對於
args和kwargs)和輸出型別包括基本型別,即
torch.Tensor、int、float、bool和str。資料類 (Dataclasses),但必須先呼叫
register_dataclass()進行註冊。包含上述所有型別的巢狀資料結構,包括
dict、list、tuple、namedtuple和OrderedDict。
- torch.export.save(ep, f, *, extra_files=None, opset_version=None, pickle_protocol=2)[source][source]¶
警告
正在積極開發中,儲存的檔案在 PyTorch 的新版本中可能無法使用。
將一個
ExportedProgram儲存到檔案類物件中。然後可以使用 Python APItorch.export.load載入它。- 引數
ep (ExportedProgram) – 要儲存的匯出程式。
f (str | os.PathLike[str] | IO[bytes]) – 實現 write 和 flush) 或包含檔名的字串。
extra_files (Optional[Dict[str, Any]]) – 檔名到內容的對映,這些內容將作為 f 的一部分儲存。
opset_version (Optional[Dict[str, int]]) – 將 opset 名稱對映到此 opset 版本的一個對映
pickle_protocol (int) – 可以指定以覆蓋預設協議
示例
import torch import io class MyModule(torch.nn.Module): def forward(self, x): return x + 10 ep = torch.export.export(MyModule(), (torch.randn(5),)) # Save to file torch.export.save(ep, 'exported_program.pt2') # Save to io.BytesIO buffer buffer = io.BytesIO() torch.export.save(ep, buffer) # Save with extra files extra_files = {'foo.txt': b'bar'.decode('utf-8')} torch.export.save(ep, 'exported_program.pt2', extra_files=extra_files)
- torch.export.load(f, *, extra_files=None, expected_opset_version=None)[source][source]¶
警告
正在積極開發中,儲存的檔案在 PyTorch 的新版本中可能無法使用。
載入之前使用
torch.export.save儲存的ExportedProgram。- 引數
- 返回值
一個
ExportedProgram物件- 返回型別
示例
import torch import io # Load ExportedProgram from file ep = torch.export.load('exported_program.pt2') # Load ExportedProgram from io.BytesIO object with open('exported_program.pt2', 'rb') as f: buffer = io.BytesIO(f.read()) buffer.seek(0) ep = torch.export.load(buffer) # Load with extra files. extra_files = {'foo.txt': ''} # values will be replaced with data ep = torch.export.load('exported_program.pt2', extra_files=extra_files) print(extra_files['foo.txt']) print(ep(torch.randn(5)))
- torch.export.register_dataclass(cls, *, serialized_type_name=None)[source][source]¶
將一個數據類註冊為
torch.export.export()的有效輸入/輸出型別。- 引數
示例
import torch from dataclasses import dataclass @dataclass class InputDataClass: feature: torch.Tensor bias: int @dataclass class OutputDataClass: res: torch.Tensor torch.export.register_dataclass(InputDataClass) torch.export.register_dataclass(OutputDataClass) class Mod(torch.nn.Module): def forward(self, x: InputDataClass) -> OutputDataClass: res = x.feature + x.bias return OutputDataClass(res=res) ep = torch.export.export(Mod(), (InputDataClass(torch.ones(2, 2), 1), )) print(ep)
- torch.export.dynamic_shapes.Dim(name, *, min=None, max=None)[source][source]¶
Dim()構建一個型別,類似於帶範圍的命名符號整數。它可以用來描述動態張量維度的多個可能值。請注意,同一張量的不同動態維度,或不同張量的動態維度,可以使用相同的型別來描述。
- torch.export.exported_program.default_decompositions()[source][source]¶
這是預設的分解表,其中包含將所有 ATEN 運算子分解為核心 aten opset 的規則。將此 API 與
run_decompositions()一起使用。- 返回型別
- torch.export.dims(*names, min=None, max=None)[source][source]¶
用於建立多個
Dim()型別的實用工具。- 返回值
一個包含
Dim()型別的元組。- 返回型別
tuple[torch.export.dynamic_shapes._Dim, …]
- class torch.export.dynamic_shapes.ShapesCollection[source][source]¶
`dynamic_shapes` 的構建器。用於將動態形狀規範分配給出現在輸入中的張量。
這在
args()是巢狀輸入結構時特別有用,此時索引輸入張量比在dynamic_shapes()規範中複製args()的結構更容易。示例
args = ({"x": tensor_x, "others": [tensor_y, tensor_z]}) dim = torch.export.Dim(...) dynamic_shapes = torch.export.ShapesCollection() dynamic_shapes[tensor_x] = (dim, dim + 1, 8) dynamic_shapes[tensor_y] = {0: dim * 2} # This is equivalent to the following (now auto-generated): # dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [{0: dim * 2}, None]} torch.export(..., args, dynamic_shapes=dynamic_shapes)
- dynamic_shapes(m, args, kwargs=None)[source][source]¶
根據
args()和kwargs()生成dynamic_shapes()的 pytree 結構。
- torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes(msg, dynamic_shapes)[source][source]¶
使用
dynamic_shapes()匯出時,如果規範與從模型跟蹤中推斷出的約束不匹配,則匯出可能會因 ConstraintViolation 錯誤而失敗。錯誤訊息可能會提供建議修復——可以對dynamic_shapes()進行的更改,以便成功匯出。ConstraintViolation 錯誤訊息示例
Suggested fixes: dim = Dim('dim', min=3, max=6) # this just refines the dim's range dim = 4 # this specializes to a constant dy = dx + 1 # dy was specified as an independent dim, but is actually tied to dx with this relation
這是一個輔助函式,它接受 ConstraintViolation 錯誤訊息和原始的
dynamic_shapes()規範,並返回一個包含建議修復的新dynamic_shapes()規範。示例用法
try: ep = export(mod, args, dynamic_shapes=dynamic_shapes) except torch._dynamo.exc.UserError as exc: new_shapes = refine_dynamic_shapes_from_suggested_fixes( exc.msg, dynamic_shapes ) ep = export(mod, args, dynamic_shapes=new_shapes)
- class torch.export.ExportedProgram(root, graph, graph_signature, state_dict, range_constraints, module_call_graph, example_inputs=None, constants=None, *, verifiers=None)[source][source]¶
`export()` 生成的程式包。它包含一個表示張量計算的
torch.fx.Graph、一個包含所有 lifted 引數和 buffers 的張量值的 `state_dict`,以及各種元資料。您可以像呼叫
export()跟蹤的原始可呼叫物件一樣,以相同的呼叫約定呼叫ExportedProgram。要在圖上執行變換,請使用 `.module` 屬性訪問一個
torch.fx.GraphModule。然後可以使用 FX 變換 來重寫圖。之後,您可以簡單地再次使用export()構建一個正確的ExportedProgram。- named_buffers()[source][source]¶
返回一個用於遍歷原始模組緩衝區的迭代器,生成緩衝區的名稱及其緩衝區本身。
警告
此API是實驗性的,並且 *不* 向後相容。
- 返回型別
- named_parameters()[source][source]¶
返回一個用於遍歷原始模組引數的迭代器,生成引數的名稱及其引數本身。
警告
此API是實驗性的,並且 *不* 向後相容。
- 返回型別
- run_decompositions(decomp_table=None, decompose_custom_triton_ops=False)[source][source]¶
對匯出的程式執行一組分解操作,並返回一個新的匯出程式。預設情況下,我們將執行核心 ATen 分解以獲取 Core ATen Operator Set 中的運算子。
目前,我們不會分解聯合圖。
- 引數
decomp_table (Optional[dict[torch._ops.OperatorBase, Callable]]) – 一個可選引數,用於指定 Aten 操作的分解行為。(1) 如果為None,我們將分解為核心aten分解。(2) 如果為空,我們不會分解任何運算子
- 返回型別
一些示例
如果您不想分解任何東西
ep = torch.export.export(model, ...) ep = ep.run_decompositions(decomp_table={})
如果您想獲取核心aten運算子集,但排除某些運算子,您可以執行以下操作
ep = torch.export.export(model, ...) decomp_table = torch.export.default_decompositions() decomp_table[your_op] = your_custom_decomp ep = ep.run_decompositions(decomp_table=decomp_table)
- class torch.export.ExportBackwardSignature(gradients_to_parameters: dict[str, str], gradients_to_user_inputs: dict[str, str], loss_output: str)[source][source]¶
- class torch.export.ExportGraphSignature(input_specs, output_specs)[source][source]¶
ExportGraphSignature建模了 Export Graph 的輸入/輸出簽名,它是一個具有更強不變數保證的 fx.Graph。Export Graph 是函式式的,並且不透過
getattr節點訪問圖中的引數或緩衝區等“狀態”。相反,export()保證引數、緩衝區和常量張量作為輸入從圖中提取出來。類似地,對緩衝區的任何變動也不包含在圖中,相反,變動後的緩衝區的更新值被建模為 Export Graph 的附加輸出。所有輸入和輸出的順序是
Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs] Outputs = [*mutated_inputs, *flattened_user_outputs]
例如,如果匯出以下模組
class CustomModule(nn.Module): def __init__(self) -> None: super(CustomModule, self).__init__() # Define a parameter self.my_parameter = nn.Parameter(torch.tensor(2.0)) # Define two buffers self.register_buffer('my_buffer1', torch.tensor(3.0)) self.register_buffer('my_buffer2', torch.tensor(4.0)) def forward(self, x1, x2): # Use the parameter, buffers, and both inputs in the forward method output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2 # Mutate one of the buffers (e.g., increment it by 1) self.my_buffer2.add_(1.0) # In-place addition return output
結果圖為
graph(): %arg0_1 := placeholder[target=arg0_1] %arg1_1 := placeholder[target=arg1_1] %arg2_1 := placeholder[target=arg2_1] %arg3_1 := placeholder[target=arg3_1] %arg4_1 := placeholder[target=arg4_1] %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {}) %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {}) %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {}) %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {}) %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {}) return (add_tensor_2, add_tensor_1)
結果 ExportGraphSignature 為
ExportGraphSignature( input_specs=[ InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None) ], output_specs=[ OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None) ] )
- class torch.export.ModuleCallSignature(inputs: list[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument]], outputs: list[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument]], in_spec: torch.utils._pytree.TreeSpec, out_spec: torch.utils._pytree.TreeSpec, forward_arg_names: Optional[list[str]] = None)[source][source]¶
- class torch.export.ModuleCallEntry(fqn: str, signature: Optional[torch.export.exported_program.ModuleCallSignature] = None)[source][source]¶
- class torch.export.decomp_utils.CustomDecompTable[source][source]¶
這是一個自定義字典,專門用於在匯出中處理 decomp_table。我們需要它的原因是,在新的體系中,您只能透過 *刪除* decomp table 中的運算子來保留它。這對於自定義運算子來說是個問題,因為我們不知道自定義運算子何時實際載入到排程器(dispatcher)中。因此,我們需要記錄自定義運算子的操作,直到我們真正需要具象化它(也就是執行分解過程時)。
- 我們維護的不變數是
所有 aten 分解在初始化時載入
當用戶從表中讀取時,我們會具象化所有運算子,以使排程器更有可能載入自定義運算子。
如果是寫入操作,我們不一定具象化
我們在匯出期間最後一次載入,就在呼叫 run_decompositions() 之前。
- class torch.export.graph_signature.InputSpec(kind: torch.export.graph_signature.InputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument], target: Optional[str], persistent: Optional[bool] = None)[source][source]¶
- class torch.export.graph_signature.OutputSpec(kind: torch.export.graph_signature.OutputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.SymFloatArgument, torch.export.graph_signature.SymBoolArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument], target: Optional[str])[source][source]¶
- class torch.export.graph_signature.ExportGraphSignature(input_specs, output_specs)[source][source]¶
ExportGraphSignature建模了 Export Graph 的輸入/輸出簽名,它是一個具有更強不變數保證的 fx.Graph。Export Graph 是函式式的,並且不透過
getattr節點訪問圖中的引數或緩衝區等“狀態”。相反,export()保證引數、緩衝區和常量張量作為輸入從圖中提取出來。類似地,對緩衝區的任何變動也不包含在圖中,相反,變動後的緩衝區的更新值被建模為 Export Graph 的附加輸出。所有輸入和輸出的順序是
Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs] Outputs = [*mutated_inputs, *flattened_user_outputs]
例如,如果匯出以下模組
class CustomModule(nn.Module): def __init__(self) -> None: super(CustomModule, self).__init__() # Define a parameter self.my_parameter = nn.Parameter(torch.tensor(2.0)) # Define two buffers self.register_buffer('my_buffer1', torch.tensor(3.0)) self.register_buffer('my_buffer2', torch.tensor(4.0)) def forward(self, x1, x2): # Use the parameter, buffers, and both inputs in the forward method output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2 # Mutate one of the buffers (e.g., increment it by 1) self.my_buffer2.add_(1.0) # In-place addition return output
結果圖為
graph(): %arg0_1 := placeholder[target=arg0_1] %arg1_1 := placeholder[target=arg1_1] %arg2_1 := placeholder[target=arg2_1] %arg3_1 := placeholder[target=arg3_1] %arg4_1 := placeholder[target=arg4_1] %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {}) %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {}) %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {}) %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {}) %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {}) return (add_tensor_2, add_tensor_1)
結果 ExportGraphSignature 為
ExportGraphSignature( input_specs=[ InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None) ], output_specs=[ OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None) ] )
- class torch.export.graph_signature.CustomObjArgument(name: str, class_fqn: str, fake_val: Optional[torch._library.fake_class_registry.FakeScriptObject] = None)[source][source]¶
- class torch.export.unflatten.InterpreterModule(graph, ty=None)[source][source]¶
一個使用
torch.fx.Interpreter來執行的模組,而不是使用 GraphModule 通常的程式碼生成方式。這提供了更好的堆疊跟蹤資訊,並使得除錯執行更容易。
- class torch.export.unflatten.InterpreterModuleDispatcher(attrs, call_modules)[source][source]¶
一個包含一系列 InterpreterModules 的模組,這些 InterpreterModules 對應於該模組的一系列呼叫。對模組的每次呼叫都會分派給下一個 InterpreterModule,並在最後一個之後迴圈回第一個。
- torch.export.unflatten.unflatten(module, flat_args_adapter=None)[source][source]¶
展平 ExportedProgram,產生一個具有與原始 eager 模組相同模組層級的模組。當您嘗試將
torch.export用於期望模組層級而不是torch.export通常產生的平坦圖的另一個系統時,這可能很有用。注意
展平後的模組的 args/kwargs 不一定與 eager 模組匹配,因此進行模組交換(例如
self.submod = new_mod)不一定奏效。如果您需要替換一個模組,您需要設定torch.export.export()的preserve_module_call_signature引數。- 引數
module (ExportedProgram) – 要展平的 ExportedProgram。
flat_args_adapter (Optional[FlatArgsAdapter]) – 如果輸入 TreeSpec 與匯出模組的不匹配,則適配平坦引數。
- 返回值
UnflattenedModule的例項,它具有與原始 eager 模組在匯出前相同的模組層級。- 返回型別
UnflattenedModule
- torch.export.passes.move_to_device_pass(ep, location)[source][source]¶
將匯出的程式移動到給定的裝置。
- 引數
ep (ExportedProgram) – 要移動的匯出程式。
location (Union[torch.device, str, Dict[str, str]]) – 要將匯出程式移動到的裝置。如果是字串,則解釋為裝置名稱。如果是字典,則解釋為從現有裝置到目標裝置的對映。
- 返回值
已移動的匯出程式。
- 返回型別