Export a model with control flow to ONNX
Author: Xavier Dupré
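Overview
This tutorial demonstrates how to handle control flow logic when exporting a PyTorch model to ONNX. It highlights the challenges of exporting conditional statements directly and provides solutions to work around them.
Conditional logic cannot be exported to ONNX unless it is refactored to use torch.cond(). Let's start with a simple model that implements a test.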
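What you will learn
How to refactor the model so that it uses torch.cond() and can be exported.
How to export a model with control flow logic to ONNX.
How to optimize the exported model with the ONNX optimizer.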
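Define the models
Two models are defined:
ForwardWithControlFlowTest: a model whose forward method contains an if-else condition.
ModelWithControlFlowTest: a model that includes ForwardWithControlFlowTest as part of a simple MLP.
The models are tested with a random input tensor to confirm that they run as expected.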
import torch


class ForwardWithControlFlowTest(torch.nn.Module):
    def forward(self, x):
        # Data-dependent branch: Python evaluates the truthiness of a tensor.
        if x.sum():
            return x * 2
        return -x


class ModelWithControlFlowTest(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(3, 2),
            torch.nn.Linear(2, 1),
            ForwardWithControlFlowTest(),
        )

    def forward(self, x):
        out = self.mlp(x)
        return out


model = ModelWithControlFlowTest()
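A minimal eager-mode sketch that exercises both branches of ForwardWithControlFlowTest on hand-picked inputs (the values are illustrative assumptions, not part of the original tutorial): [1.0, 2.0] has a nonzero sum and takes the doubling branch, while [1.0, -1.0] sums to zero and takes the negation branch. Eager execution works fine; the problem only appears at export time.

```python
import torch


class ForwardWithControlFlowTest(torch.nn.Module):
    def forward(self, x):
        # Truthiness of a one-element tensor: True when the sum is nonzero.
        if x.sum():
            return x * 2
        return -x


branch_test = ForwardWithControlFlowTest()

# Nonzero sum -> first branch (x * 2).
doubled = branch_test(torch.tensor([1.0, 2.0]))
# Zero sum -> second branch (-x).
negated = branch_test(torch.tensor([1.0, -1.0]))

print(doubled)
print(negated)
```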
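Exporting the model: first attempt
Exporting this model with torch.export.export fails because the branch in the forward pass depends on the value of x.sum(), which is not known while the graph is being captured, so the exporter cannot decide which branch to record (the error below reports this as a data-dependent guard). This behavior is expected: conditional logic that is not written with torch.cond() is not supported.
A try-except block catches the expected failure during export. If the export unexpectedly succeeds, an AssertionError is raised.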
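x = torch.randn(3)
model(x)

try:
    torch.export.export(model, (x,), strict=False)
except Exception as e:
    # Expected: export cannot decide the data-dependent `if x.sum():`.
    print(e)
else:
    # Raised outside the try block so that it is not swallowed by `except Exception`.
    raise AssertionError("This export should fail unless PyTorch now supports this model.")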
def forward(self, arg0_1: "f32[2, 3]", arg1_1: "f32[2]", arg2_1: "f32[1, 2]", arg3_1: "f32[1]", arg4_1: "f32[3]"):
# File: /var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[2]" = torch.ops.aten.linear.default(arg4_1, arg0_1, arg1_1); arg4_1 = arg0_1 = arg1_1 = None
linear_1: "f32[1]" = torch.ops.aten.linear.default(linear, arg2_1, arg3_1); linear = arg2_1 = arg3_1 = None
# File: /var/lib/workspace/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py:56 in forward, code: if x.sum():
sum_1: "f32[]" = torch.ops.aten.sum.default(linear_1); linear_1 = None
ne: "b8[]" = torch.ops.aten.ne.Scalar(sum_1, 0); sum_1 = None
item: "Sym(Eq(u0, 1))" = torch.ops.aten.item.default(ne); ne = item = None
Could not guard on data-dependent expression Eq(u0, 1) (unhinted: Eq(u0, 1)). (Size-like symbols: none)
Caused by: (_export/non_strict_utils.py:683 in __torch_function__)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
The following call raised this error:
File "/var/lib/workspace/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py", line 56, in forward
if x.sum():
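The error says the exporter could not guard on a data-dependent expression. A toy illustration of why, in plain Python (SymbolicScalar and DataDependentError are hypothetical stand-ins, not torch.export internals): during export the value of x.sum() is only a symbol, so Python's `if` statement, which needs a concrete True or False, cannot be evaluated.

```python
class DataDependentError(RuntimeError):
    """Hypothetical stand-in for the guard error raised during export."""


class SymbolicScalar:
    """Toy placeholder for a value whose data is unknown while tracing."""

    def __init__(self, name):
        self.name = name

    def __bool__(self):
        # `if value:` calls __bool__, which must return a concrete bool.
        # A tracer that only knows the expression cannot answer.
        raise DataDependentError(
            f"Could not guard on data-dependent expression {self.name} != 0"
        )


def forward_with_if(x_sum):
    # Same structure as ForwardWithControlFlowTest.forward.
    if x_sum:
        return "x * 2"
    return "-x"


# Eager execution: the sum is a real number, so the branch is decidable.
print(forward_with_if(3.0))  # x * 2

# "Export": the sum is symbolic, so Python's `if` cannot be evaluated.
try:
    forward_with_if(SymbolicScalar("u0"))
except DataDependentError as err:
    print(err)
```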
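Exporting with torch.onnx.export() and JIT tracing
When the model is exported with torch.onnx.export() and dynamo=True, the exporter first tries torch.export.export in non-strict mode and then in strict mode. As the log shows, both attempts fail on the data-dependent condition, so the exporter falls back to JIT tracing (TorchScript). This fallback lets the model export, but because tracing records only the branch taken for the example input, the resulting ONNX graph may not faithfully represent the original model logic. Here, the graph contains only the multiplication by 2 and no conditional.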
onnx_program = torch.onnx.export(model, (x,), dynamo=True)
print(onnx_program.model)
/usr/local/lib/python3.10/dist-packages/onnxscript/converter.py:823: FutureWarning:
'onnxscript.values.Op.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
/usr/local/lib/python3.10/dist-packages/onnxscript/converter.py:823: FutureWarning:
'onnxscript.values.OnnxFunction.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`...
def forward(self, arg0_1: "f32[2, 3]", arg1_1: "f32[2]", arg2_1: "f32[1, 2]", arg3_1: "f32[1]", arg4_1: "f32[3]"):
# File: /var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[2]" = torch.ops.aten.linear.default(arg4_1, arg0_1, arg1_1); arg4_1 = arg0_1 = arg1_1 = None
linear_1: "f32[1]" = torch.ops.aten.linear.default(linear, arg2_1, arg3_1); linear = arg2_1 = arg3_1 = None
# File: /var/lib/workspace/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py:56 in forward, code: if x.sum():
sum_1: "f32[]" = torch.ops.aten.sum.default(linear_1); linear_1 = None
ne: "b8[]" = torch.ops.aten.ne.Scalar(sum_1, 0); sum_1 = None
item: "Sym(Eq(u0, 1))" = torch.ops.aten.item.default(ne); ne = item = None
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`... ❌
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export`...
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[3][1]cpu"):
l_x_ = L_x_
# File: /var/lib/workspace/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py:71 in forward, code: out = self.mlp(x)
l__self___mlp_0: "f32[2][1]cpu" = self.L__self___mlp_0(l_x_); l_x_ = None
l__self___mlp_1: "f32[1][1]cpu" = self.L__self___mlp_1(l__self___mlp_0); l__self___mlp_0 = None
# File: /var/lib/workspace/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py:56 in forward, code: if x.sum():
sum_1: "f32[][]cpu" = l__self___mlp_1.sum(); l__self___mlp_1 = sum_1 = None
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export`... ❌
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with Torch Script...
/var/lib/workspace/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py:56: TracerWarning:
Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with Torch Script... ✅
[torch.onnx] Run decomposition...
/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/export/_unlift.py:81: UserWarning:
Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer
/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/fx/graph.py:1772: UserWarning:
Node lifted_tensor_6 target lifted_tensor_6 lifted_tensor_6 of does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
<
ir_version=10,
opset_imports={'pkg.onnxscript.torch_lib.common': 1, '': 18},
producer_name='pytorch',
producer_version='2.7.0+cu126',
domain=None,
model_version=None,
>
graph(
name=main_graph,
inputs=(
%"input_1"<FLOAT,[3]>
),
outputs=(
%"mul"<FLOAT,[1]>
),
initializers=(
%"model.mlp.0.bias"<FLOAT,[2]>,
%"model.mlp.1.bias"<FLOAT,[1]>
),
) {
0 | # node_Constant_8
%"val_0"<FLOAT,[3,2]> ⬅️ ::Constant() {value=Tensor<FLOAT,[3,2]>(array([[ 0.44140652, 0.53036046],
[ 0.47920528, -0.1264995 ],
[-0.13525727, 0.11650391]], dtype=float32), name='val_0')}
1 | # node_MatMul_1
%"val_1"<FLOAT,[2]> ⬅️ ::MatMul(%"input_1", %"val_0")
2 | # node_Add_2
%"linear"<FLOAT,[2]> ⬅️ ::Add(%"val_1", %"model.mlp.0.bias")
3 | # node_Constant_9
%"val_2"<FLOAT,[2,1]> ⬅️ ::Constant() {value=Tensor<FLOAT,[2,1]>(array([[ 0.62334496],
[-0.5187534 ]], dtype=float32), name='val_2')}
4 | # node_MatMul_4
%"val_3"<FLOAT,[1]> ⬅️ ::MatMul(%"linear", %"val_2")
5 | # node_Add_5
%"linear_1"<FLOAT,[1]> ⬅️ ::Add(%"val_3", %"model.mlp.1.bias")
6 | # node_Constant_10
%"convert_element_type_default"<FLOAT,[]> ⬅️ ::Constant() {value=Tensor<FLOAT,[]>(array(2., dtype=float32), name='convert_element_type_default')}
7 | # node_Mul_7
%"mul"<FLOAT,[1]> ⬅️ ::Mul(%"linear_1", %"convert_element_type_default")
return %"mul"<FLOAT,[1]>
}
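The graph above contains a single Mul by 2 and no If node: tracing recorded only the branch taken for the example input. A sketch that reproduces the effect directly with torch.jit.trace (the inputs are chosen for illustration): the module is traced on an input with a nonzero sum, then called on an input whose sum is zero.

```python
import warnings

import torch


class ForwardWithControlFlowTest(torch.nn.Module):
    def forward(self, x):
        if x.sum():
            return x * 2
        return -x


eager = ForwardWithControlFlowTest()

with warnings.catch_warnings():
    # Tracing emits a TracerWarning about converting a tensor to a Python
    # boolean; it is silenced here because that conversion is the point shown.
    warnings.simplefilter("ignore")
    # The example input has a nonzero sum, so only `x * 2` is recorded.
    traced = torch.jit.trace(eager, (torch.tensor([1.0, 2.0]),))

# This input sums to zero, so the original module takes the `-x` branch.
probe = torch.tensor([1.0, -1.0])
print(eager(probe))   # negated by the eager module
print(traced(probe))  # doubled: the traced graph always multiplies by 2
```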
Suggested patch: refactoring with torch.cond()
To make the control flow exportable, the tutorial replaces the forward method of ForwardWithControlFlowTest with a version refactored to use torch.cond().
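Details of the refactoring
Two helper functions (identity2 and neg) represent the branches of the conditional logic:
* torch.cond() is called with the condition, the two branches, and the input arguments.
* The updated forward method is then dynamically assigned to the ForwardWithControlFlowTest instance inside the model. The list of submodules is printed to confirm the replacement.
Note that the condition becomes x.sum() > 0, whereas the original `if x.sum():` takes the first branch whenever the sum is nonzero, so the two versions behave differently when the sum is negative.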
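The dynamic assignment relies on ordinary Python attribute lookup: a function stored on an instance shadows the method defined on the class, and nn.Module.__call__ dispatches through self.forward. A small sketch of that behavior, where replacement_forward is a hypothetical stand-in for new_forward. Because the function is stored on the instance rather than bound from the class, it receives only x and no self.

```python
import torch


class ForwardWithControlFlowTest(torch.nn.Module):
    def forward(self, x):
        if x.sum():
            return x * 2
        return -x


def replacement_forward(x):
    # Hypothetical replacement; as an instance attribute it gets no `self`.
    return x + 100


patched = ForwardWithControlFlowTest()
untouched = ForwardWithControlFlowTest()

# Stored in patched.__dict__, so it shadows the class-level method for
# this instance only; nn.Module.__call__ then calls self.forward.
patched.forward = replacement_forward

x = torch.tensor([1.0, 2.0])
print(patched(x))    # uses replacement_forward
print(untouched(x))  # still uses the class-level forward
```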
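def new_forward(x):
    def identity2(x):
        return x * 2

    def neg(x):
        return -x

    # torch.cond(pred, true_fn, false_fn, operands) captures both branches in the graph.
    return torch.cond(x.sum() > 0, identity2, neg, (x,))


print("the list of submodules")
for name, mod in model.named_modules():
    print(name, type(mod))
    if isinstance(mod, ForwardWithControlFlowTest):
        # Instance-level override: only this module's forward is replaced.
        mod.forward = new_forward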
the list of submodules
<class '__main__.ModelWithControlFlowTest'>
mlp <class 'torch.nn.modules.container.Sequential'>
mlp.0 <class 'torch.nn.modules.linear.Linear'>
mlp.1 <class 'torch.nn.modules.linear.Linear'>
mlp.2 <class '__main__.ForwardWithControlFlowTest'>
Let's look at the FX graph.
print(torch.export.export(model, (x,), strict=False))
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_mlp_0_weight: "f32[2, 3]", p_mlp_0_bias: "f32[2]", p_mlp_1_weight: "f32[1, 2]", p_mlp_1_bias: "f32[1]", x: "f32[3]"):
# File: /var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[2]" = torch.ops.aten.linear.default(x, p_mlp_0_weight, p_mlp_0_bias); x = p_mlp_0_weight = p_mlp_0_bias = None
linear_1: "f32[1]" = torch.ops.aten.linear.default(linear, p_mlp_1_weight, p_mlp_1_bias); linear = p_mlp_1_weight = p_mlp_1_bias = None
# File: /var/lib/ci-user/.local/lib/python3.10/site-packages/torch/nn/modules/container.py:240 in forward, code: input = module(input)
sum_1: "f32[]" = torch.ops.aten.sum.default(linear_1)
gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None
# File: <eval_with_key>.30:9 in forward, code: cond = torch.ops.higher_order.cond(l_args_0_, cond_true_0, cond_false_0, [l_args_3_0_]); l_args_0_ = cond_true_0 = cond_false_0 = l_args_3_0_ = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [linear_1]); gt = true_graph_0 = false_graph_0 = linear_1 = None
getitem: "f32[1]" = cond[0]; cond = None
return (getitem,)
class true_graph_0(torch.nn.Module):
def forward(self, linear_1: "f32[1]"):
# File: <eval_with_key>.25:6 in forward, code: mul = l_args_3_0__1.mul(2); l_args_3_0__1 = None
mul: "f32[1]" = torch.ops.aten.mul.Tensor(linear_1, 2); linear_1 = None
return (mul,)
class false_graph_0(torch.nn.Module):
def forward(self, linear_1: "f32[1]"):
# File: <eval_with_key>.26:6 in forward, code: neg = l_args_3_0__1.neg(); l_args_3_0__1 = None
neg: "f32[1]" = torch.ops.aten.neg.default(linear_1); linear_1 = None
return (neg,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_mlp_0_weight'), target='mlp.0.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_mlp_0_bias'), target='mlp.0.bias', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_mlp_1_weight'), target='mlp.1.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_mlp_1_bias'), target='mlp.1.bias', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}
Let's export again.
onnx_program = torch.onnx.export(model, (x,), dynamo=True)
print(onnx_program.model)
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
<
ir_version=10,
opset_imports={'pkg.onnxscript.torch_lib.common': 1, '': 18, 'pkg.torch.__subgraph__': 1},
producer_name='pytorch',
producer_version='2.7.0+cu126',
domain=None,
model_version=None,
>
graph(
name=main_graph,
inputs=(
%"x"<FLOAT,[3]>
),
outputs=(
%"getitem"<FLOAT,[1]>
),
initializers=(
%"mlp.0.bias"<FLOAT,[2]>,
%"mlp.1.bias"<FLOAT,[1]>
),
) {
0 | # node_Constant_11
%"val_0"<FLOAT,[3,2]> ⬅️ ::Constant() {value=Tensor<FLOAT,[3,2]>(array([[ 0.44140652, 0.53036046],
[ 0.47920528, -0.1264995 ],
[-0.13525727, 0.11650391]], dtype=float32), name='val_0')}
1 | # node_MatMul_1
%"val_1"<FLOAT,[2]> ⬅️ ::MatMul(%"x", %"val_0")
2 | # node_Add_2
%"linear"<FLOAT,[2]> ⬅️ ::Add(%"val_1", %"mlp.0.bias")
3 | # node_Constant_12
%"val_2"<FLOAT,[2,1]> ⬅️ ::Constant() {value=Tensor<FLOAT,[2,1]>(array([[ 0.62334496],
[-0.5187534 ]], dtype=float32), name='val_2')}
4 | # node_MatMul_4
%"val_3"<FLOAT,[1]> ⬅️ ::MatMul(%"linear", %"val_2")
5 | # node_Add_5
%"linear_1"<FLOAT,[1]> ⬅️ ::Add(%"val_3", %"mlp.1.bias")
6 | # node_ReduceSum_6
%"sum_1"<FLOAT,[]> ⬅️ ::ReduceSum(%"linear_1") {keepdims=False, noop_with_empty_axes=0}
7 | # node_Constant_13
%"scalar_tensor_default"<FLOAT,[]> ⬅️ ::Constant() {value=Tensor<FLOAT,[]>(array(0., dtype=float32), name='scalar_tensor_default')}
8 | # node_Greater_9
%"gt"<BOOL,[]> ⬅️ ::Greater(%"sum_1", %"scalar_tensor_default")
9 | # node_If_10
%"getitem"<FLOAT,[1]> ⬅️ ::If(%"gt") {then_branch=
graph(
name=true_graph_0,
inputs=(
),
outputs=(
%"mul_true_graph_0"<FLOAT,[1]>
),
) {
0 | # node_Constant_1
%"scalar_tensor_default_2"<FLOAT,[]> ⬅️ ::Constant() {value=Tensor<FLOAT,[]>(array(2., dtype=float32), name='scalar_tensor_default_2')}
1 | # node_Mul_2
%"mul_true_graph_0"<FLOAT,[1]> ⬅️ ::Mul(%"linear_1", %"scalar_tensor_default_2")
return %"mul_true_graph_0"<FLOAT,[1]>
}, else_branch=
graph(
name=false_graph_0,
inputs=(
),
outputs=(
%"neg_false_graph_0"<FLOAT,[1]>
),
) {
0 | # node_Neg_0
%"neg_false_graph_0"<FLOAT,[1]> ⬅️ ::Neg(%"linear_1")
return %"neg_false_graph_0"<FLOAT,[1]>
}}
return %"getitem"<FLOAT,[1]>
}
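In the ONNX graph above, both branch subgraphs of the If node declare inputs=() yet read %"linear_1": an ONNX subgraph can reference values from the enclosing graph's scope, and If runs only the selected branch. A plain-Python analogy of that structure (a conceptual sketch, not an ONNX runtime; every name is hypothetical and only mirrors the printed node names).

```python
def run_if(cond, then_branch, else_branch, outer_scope):
    """Toy version of ONNX If: evaluate only the selected subgraph.

    The branches take no inputs of their own; they read values such as
    "linear_1" from the enclosing scope.
    """
    branch = then_branch if cond else else_branch
    return branch(outer_scope)


def true_graph_0(scope):
    # Mirrors: mul_true_graph_0 = Mul(linear_1, 2.0)
    return [v * 2.0 for v in scope["linear_1"]]


def false_graph_0(scope):
    # Mirrors: neg_false_graph_0 = Neg(linear_1)
    return [-v for v in scope["linear_1"]]


def main_graph(linear_1):
    scope = {"linear_1": linear_1}
    sum_1 = sum(linear_1)  # ReduceSum
    gt = sum_1 > 0.0       # Greater(sum_1, 0.0)
    return run_if(gt, true_graph_0, false_graph_0, scope)


print(main_graph([0.5]))   # [1.0]
print(main_graph([-0.5]))  # [0.5]
```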
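We can optimize the model and remove the model-local functions created to capture the control flow branches. In this small example, the printed graph is the same before and after optimization.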
onnx_program.optimize()
print(onnx_program.model)
<
ir_version=10,
opset_imports={'pkg.onnxscript.torch_lib.common': 1, '': 18, 'pkg.torch.__subgraph__': 1},
producer_name='pytorch',
producer_version='2.7.0+cu126',
domain=None,
model_version=None,
>
graph(
name=main_graph,
inputs=(
%"x"<FLOAT,[3]>
),
outputs=(
%"getitem"<FLOAT,[1]>
),
initializers=(
%"mlp.0.bias"<FLOAT,[2]>,
%"mlp.1.bias"<FLOAT,[1]>
),
) {
0 | # node_Constant_11
%"val_0"<FLOAT,[3,2]> ⬅️ ::Constant() {value=Tensor<FLOAT,[3,2]>(array([[ 0.44140652, 0.53036046],
[ 0.47920528, -0.1264995 ],
[-0.13525727, 0.11650391]], dtype=float32), name='val_0')}
1 | # node_MatMul_1
%"val_1"<FLOAT,[2]> ⬅️ ::MatMul(%"x", %"val_0")
2 | # node_Add_2
%"linear"<FLOAT,[2]> ⬅️ ::Add(%"val_1", %"mlp.0.bias")
3 | # node_Constant_12
%"val_2"<FLOAT,[2,1]> ⬅️ ::Constant() {value=Tensor<FLOAT,[2,1]>(array([[ 0.62334496],
[-0.5187534 ]], dtype=float32), name='val_2')}
4 | # node_MatMul_4
%"val_3"<FLOAT,[1]> ⬅️ ::MatMul(%"linear", %"val_2")
5 | # node_Add_5
%"linear_1"<FLOAT,[1]> ⬅️ ::Add(%"val_3", %"mlp.1.bias")
6 | # node_ReduceSum_6
%"sum_1"<FLOAT,[]> ⬅️ ::ReduceSum(%"linear_1") {keepdims=False, noop_with_empty_axes=0}
7 | # node_Constant_13
%"scalar_tensor_default"<FLOAT,[]> ⬅️ ::Constant() {value=Tensor<FLOAT,[]>(array(0., dtype=float32), name='scalar_tensor_default')}
8 | # node_Greater_9
%"gt"<BOOL,[]> ⬅️ ::Greater(%"sum_1", %"scalar_tensor_default")
9 | # node_If_10
%"getitem"<FLOAT,[1]> ⬅️ ::If(%"gt") {then_branch=
graph(
name=true_graph_0,
inputs=(
),
outputs=(
%"mul_true_graph_0"<FLOAT,[1]>
),
) {
0 | # node_Constant_1
%"scalar_tensor_default_2"<FLOAT,[]> ⬅️ ::Constant() {value=Tensor<FLOAT,[]>(array(2., dtype=float32), name='scalar_tensor_default_2')}
1 | # node_Mul_2
%"mul_true_graph_0"<FLOAT,[1]> ⬅️ ::Mul(%"linear_1", %"scalar_tensor_default_2")
return %"mul_true_graph_0"<FLOAT,[1]>
}, else_branch=
graph(
name=false_graph_0,
inputs=(
),
outputs=(
%"neg_false_graph_0"<FLOAT,[1]>
),
) {
0 | # node_Neg_0
%"neg_false_graph_0"<FLOAT,[1]> ⬅️ ::Neg(%"linear_1")
return %"neg_false_graph_0"<FLOAT,[1]>
}}
return %"getitem"<FLOAT,[1]>
}
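Conclusion
This tutorial demonstrated the challenges of exporting models with conditional logic to ONNX and presented a practical solution based on torch.cond(). torch.export fails on a data-dependent if statement, and the JIT tracing fallback produces a graph that keeps only one branch; refactoring the model's logic with torch.cond() ensures compatibility and yields a faithful ONNX representation with an If node.
Understanding these techniques helps avoid common pitfalls when handling control flow in PyTorch models and ensures smooth integration with ONNX workflows.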
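Further reading
The list below refers to tutorials that range from basic examples to advanced scenarios, not necessarily in the order they are listed. Feel free to jump directly to the topics that interest you, or work through all of them to learn everything about the ONNX exporter.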
Total running time of the script: (0 minutes 2.263 seconds)