快捷方式

基於 TorchScript 的 ONNX 匯出器

注意

要使用 TorchDynamo 而非 TorchScript 匯出 ONNX 模型,請參閱 瞭解更多關於基於 TorchDynamo 的 ONNX 匯出器

示例:從 PyTorch 到 ONNX 的 AlexNet

這是一個簡單的指令碼,它將預訓練的 AlexNet 匯出到名為 alexnet.onnx 的 ONNX 檔案。呼叫 torch.onnx.export 會執行模型一次以追蹤其執行,然後將追蹤到的模型匯出到指定檔案。

import torch
import torchvision

dummy_input = torch.randn(10, 3, 224, 224, device="cuda")
model = torchvision.models.alexnet(pretrained=True).cuda()

# Providing input and output names sets the display names for values
# within the model's graph. Setting these does not change the semantics
# of the graph; it is only for readability.
#
# The inputs to the network consist of the flat list of inputs (i.e.
# the values you would pass to the forward() method) followed by the
# flat list of parameters. You can partially specify names, i.e. provide
# a list here shorter than the number of inputs to the model, and we will
# only set that subset of names, starting from the beginning.
input_names = [ "actual_input_1" ] + [ "learned_%d" % i for i in range(16) ]
output_names = [ "output1" ]

torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True, input_names=input_names, output_names=output_names)

生成的 alexnet.onnx 檔案包含一個二進位制 protocol buffer,其中包含您匯出的模型(在本例中為 AlexNet)的網路結構和引數。引數 verbose=True 會使匯出器打印出模型的可讀表示。

# These are the inputs and parameters to the network, which have taken on
# the names we specified earlier.
graph(%actual_input_1 : Float(10, 3, 224, 224)
      %learned_0 : Float(64, 3, 11, 11)
      %learned_1 : Float(64)
      %learned_2 : Float(192, 64, 5, 5)
      %learned_3 : Float(192)
      # ---- omitted for brevity ----
      %learned_14 : Float(1000, 4096)
      %learned_15 : Float(1000)) {
  # Every statement consists of some output tensors (and their types),
  # the operator to be run (with its attributes, e.g., kernels, strides,
  # etc.), its input tensors (%actual_input_1, %learned_0, %learned_1)
  %17 : Float(10, 64, 55, 55) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[11, 11], pads=[2, 2, 2, 2], strides=[4, 4]](%actual_input_1, %learned_0, %learned_1), scope: AlexNet/Sequential[features]/Conv2d[0]
  %18 : Float(10, 64, 55, 55) = onnx::Relu(%17), scope: AlexNet/Sequential[features]/ReLU[1]
  %19 : Float(10, 64, 27, 27) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%18), scope: AlexNet/Sequential[features]/MaxPool2d[2]
  # ---- omitted for brevity ----
  %29 : Float(10, 256, 6, 6) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%28), scope: AlexNet/Sequential[features]/MaxPool2d[12]
  # Dynamic means that the shape is not known. This may be because of a
  # limitation of our implementation (which we would like to fix in a
  # future release) or shapes which are truly dynamic.
  %30 : Dynamic = onnx::Shape(%29), scope: AlexNet
  %31 : Dynamic = onnx::Slice[axes=[0], ends=[1], starts=[0]](%30), scope: AlexNet
  %32 : Long() = onnx::Squeeze[axes=[0]](%31), scope: AlexNet
  %33 : Long() = onnx::Constant[value={9216}](), scope: AlexNet
  # ---- omitted for brevity ----
  %output1 : Float(10, 1000) = onnx::Gemm[alpha=1, beta=1, broadcast=1, transB=1](%45, %learned_14, %learned_15), scope: AlexNet/Sequential[classifier]/Linear[6]
  return (%output1);
}

您還可以使用 ONNX 庫來驗證輸出,您可以使用 pip 進行安裝。

pip install onnx

然後,您可以執行

import onnx

# Load the ONNX model
model = onnx.load("alexnet.onnx")

# Check that the model is well formed
onnx.checker.check_model(model)

# Print a human readable representation of the graph
print(onnx.helper.printable_graph(model.graph))

您還可以使用許多支援 ONNX 的 執行時 之一來執行匯出的模型。例如,安裝 ONNX Runtime 後,您可以載入並執行模型。

import onnxruntime as ort
import numpy as np

ort_session = ort.InferenceSession("alexnet.onnx")

outputs = ort_session.run(
    None,
    {"actual_input_1": np.random.randn(10, 3, 224, 224).astype(np.float32)},
)
print(outputs[0])

這是一個更詳細的 教程,講解如何匯出模型並使用 ONNX Runtime 執行它

追蹤 (Tracing) 與指令碼化 (Scripting)

在內部,torch.onnx.export() 需要一個 torch.jit.ScriptModule,而不是 torch.nn.Module。如果傳入的模型還不是 ScriptModuleexport() 將使用 追蹤 將其轉換為一個。

  • 追蹤:如果使用一個還不是 ScriptModule 的 Module 呼叫 torch.onnx.export(),它會首先執行等同於 torch.jit.trace() 的操作,該操作會使用給定的 args 並記錄該執行期間發生的所有操作。這意味著如果您的模型是動態的,例如,行為取決於輸入資料,匯出的模型將 捕捉這種動態行為。我們建議檢查匯出的模型並確保運算元看起來合理。追蹤將展開迴圈和條件語句,匯出一個與追蹤執行完全相同的靜態圖。如果您想匯出帶有動態控制流的模型,則需要使用 指令碼化

  • 指令碼化:透過指令碼化編譯模型保留了動態控制流,並且對不同大小的輸入有效。要使用指令碼化,請執行以下操作:

    • 使用 torch.jit.script() 來生成一個 ScriptModule

    • 使用 ScriptModule 作為模型來呼叫 torch.onnx.export()args 仍然是必需的,但它們只會在內部用於生成示例輸出,以便捕獲輸出的型別和形狀。不會執行追蹤。

請參閱 TorchScript 入門TorchScript 以獲取更多詳細資訊,包括如何組合追蹤和指令碼化以適應不同模型的特定需求。

避免陷阱

避免使用 NumPy 和內建 Python 型別

PyTorch 模型可以使用 NumPy 或 Python 型別和函式編寫,但在 追蹤 期間,任何 NumPy 或 Python 型別的變數(而不是 torch.Tensor)都會被轉換為常量,如果這些值應根據輸入而改變,這將產生錯誤的結果。

例如,與其在 numpy.ndarrays 上使用 numpy 函式

# Bad! Will be replaced with constants during tracing.
x, y = np.random.rand(1, 2), np.random.rand(1, 2)
np.concatenate((x, y), axis=1)

不如在 torch.Tensors 上使用 torch 運算元

# Good! Tensor operations will be captured during tracing.
x, y = torch.randn(1, 2), torch.randn(1, 2)
torch.cat((x, y), dim=1)

與其使用 torch.Tensor.item()(它將 Tensor 轉換為 Python 內建數值)

# Bad! y.item() will be replaced with a constant during tracing.
def forward(self, x, y):
    return x.reshape(y.item(), -1)

不如使用 torch 對單元素張量隱式型別轉換的支援

# Good! y will be preserved as a variable during tracing.
def forward(self, x, y):
    return x.reshape(y, -1)

避免使用 Tensor.data

使用 Tensor.data 欄位可能會產生不正確的追蹤,從而生成不正確的 ONNX 圖。請改用 torch.Tensor.detach()。(徹底移除 Tensor.data 的工作正在進行中)。

在追蹤模式下使用 tensor.shape 時避免原地操作

在追蹤模式下,從 tensor.shape 獲取的形狀被追蹤為張量,並共享相同的記憶體。這可能會導致最終輸出值不匹配。作為一種解決方法,請避免在這些場景中使用原地操作。例如,在模型中

class Model(torch.nn.Module):
  def forward(self, states):
      batch_size, seq_length = states.shape[:2]
      real_seq_length = seq_length
      real_seq_length += 2
      return real_seq_length + seq_length

real_seq_lengthseq_length 在追蹤模式下共享相同的記憶體。這可以透過重寫原地操作來避免

real_seq_length = real_seq_length + 2

限制

型別

  • 只有 torch.Tensors、可以輕易轉換為 torch.Tensors 的數值型別(例如 float, int)以及這些型別的元組和列表被支援作為模型輸入或輸出。在 追蹤 模式下接受 Dict 和 str 輸入和輸出,但

    • 任何依賴於 dict 或 str 輸入值的計算 將被替換為 單次追蹤執行期間看到 的常量值

    • 任何作為 dict 的輸出都將被靜默替換為其值的 扁平序列(鍵將被移除)。例如,{"foo": 1, "bar": 2} 變為 (1, 2)

    • 任何作為 str 的輸出都將被靜默移除。

  • 由於 ONNX 對巢狀序列的支援有限,涉及元組和列表的某些操作在 指令碼化 模式下不受支援。特別是,將元組附加到列表不受支援。在追蹤模式下,巢狀序列將在追蹤期間自動展平。

運算元實現差異

由於運算元實現上的差異,在不同執行時上執行匯出的模型可能會產生彼此不同或與 PyTorch 不同的結果。通常這些差異在數值上很小,因此只有當您的應用程式對這些微小差異敏感時才需要關注。

不支援的張量索引模式

無法匯出的張量索引模式列在下方。如果您在匯出模型時遇到問題,而模型不包含以下任何不支援的模式,請仔細檢查您是否使用最新的 opset_version 進行匯出。

讀取 / 獲取

在對張量進行索引以進行讀取時,不支援以下模式

# Tensor indices that includes negative values.
data[torch.tensor([[1, 2], [2, -3]]), torch.tensor([-2, 3])]
# Workarounds: use positive index values.

寫入 / 設定

在對張量進行索引以進行寫入時,不支援以下模式

# Multiple tensor indices if any has rank >= 2
data[torch.tensor([[1, 2], [2, 3]]), torch.tensor([2, 3])] = new_data
# Workarounds: use single tensor index with rank >= 2,
#              or multiple consecutive tensor indices with rank == 1.

# Multiple tensor indices that are not consecutive
data[torch.tensor([2, 3]), :, torch.tensor([1, 2])] = new_data
# Workarounds: transpose `data` such that tensor indices are consecutive.

# Tensor indices that includes negative values.
data[torch.tensor([1, -2]), torch.tensor([-2, 3])] = new_data
# Workarounds: use positive index values.

# Implicit broadcasting required for new_data.
data[torch.tensor([[0, 2], [1, 1]]), 1:3] = new_data
# Workarounds: expand new_data explicitly.
# Example:
#   data shape: [3, 4, 5]
#   new_data shape: [5]
#   expected new_data shape after broadcasting: [2, 2, 2, 5]

新增運算元支援

當匯出包含不支援運算元的模型時,您將看到類似以下內容的錯誤訊息

RuntimeError: ONNX export failed: Couldn't export operator foo

發生這種情況時,您可以採取以下幾種措施

  1. 更改模型以不使用該運算元。

  2. 建立一個符號函式來轉換該運算元,並將其註冊為自定義符號函式。

  3. 貢獻給 PyTorch,將相同的符號函式新增到 torch.onnx 本身。

如果您決定實現一個符號函式(我們希望您能將其貢獻回 PyTorch!),以下是入門方法

ONNX 匯出器內部機制

“符號函式”是一個函式,它將一個 PyTorch 運算元分解為一系列 ONNX 運算元的組合。

在匯出過程中,匯出器會按照拓撲順序訪問 TorchScript 圖中的每個節點(其中包含一個 PyTorch 運算元)。訪問節點時,匯出器會查詢為該運算元註冊的符號函式。符號函式是用 Python 實現的。一個名為 foo 的運算元的符號函式看起來像這樣

def foo(
  g,
  input_0: torch._C.Value,
  input_1: torch._C.Value) -> Union[None, torch._C.Value, List[torch._C.Value]]:
  """
  Adds the ONNX operations representing this PyTorch function by updating the
  graph g with `g.op()` calls.

  Args:
    g (Graph): graph to write the ONNX representation into.
    input_0 (Value): value representing the variables which contain
        the first input for this operator.
    input_1 (Value): value representing the variables which contain
        the second input for this operator.

  Returns:
    A Value or List of Values specifying the ONNX nodes that compute something
    equivalent to the original PyTorch operator with the given inputs.

    None if it cannot be converted to ONNX.
  """
  ...

torch._C 型別是 C++ 中 ir.h 檔案中定義的型別的 Python 包裝器。

新增符號函式的過程取決於運算元的型別。

ATen 運算元

ATen 是 PyTorch 內建的張量庫。如果運算元是 ATen 運算元(在 TorchScript 圖中以字首 aten:: 顯示),請確保它尚未被支援。

支援的運算元列表

訪問自動生成的 支援的 TorchScript 運算元列表,瞭解每個 opset_version 中支援哪些運算元。

新增對 aten 或量化運算元的支援

如果運算元不在上面的列表中

  • torch/onnx/symbolic_opset<version>.py 中定義符號函式,例如 torch/onnx/symbolic_opset9.py。確保函式名稱與 ATen 函式的名稱相同,ATen 函式可能在 torch/_C/_VariableFunctions.pyitorch/nn/functional.pyi 中宣告(這些檔案在構建時生成,因此在構建 PyTorch 之前不會出現在您的程式碼庫中)。

  • 預設情況下,第一個引數是 ONNX 圖。其他引數名稱必須 完全匹配 .pyi 檔案中的名稱,因為分派是使用關鍵字引數完成的。

  • 在符號函式中,如果運算元位於 ONNX 標準運算元集 中,我們只需要在圖中建立一個節點來表示該 ONNX 運算元即可。如果不在,我們可以組合幾個具有與 ATen 運算元等效語義的標準運算元。

這是一個處理缺少 ELU 運算元符號函式的示例。

如果我們執行以下程式碼

print(
    torch.jit.trace(
        torch.nn.ELU(), # module
        torch.ones(1)   # example input
    ).graph
)

我們將看到類似以下內容

graph(%self : __torch__.torch.nn.modules.activation.___torch_mangle_0.ELU,
      %input : Float(1, strides=[1], requires_grad=0, device=cpu)):
  %4 : float = prim::Constant[value=1.]()
  %5 : int = prim::Constant[value=1]()
  %6 : int = prim::Constant[value=1]()
  %7 : Float(1, strides=[1], requires_grad=0, device=cpu) = aten::elu(%input, %4, %5, %6)
  return (%7)

由於我們在圖中看到了 aten::elu,我們知道這是一個 ATen 運算元。

我們檢視 ONNX 運算元列表,並確認 Elu 已在 ONNX 中標準化。

我們在 torch/nn/functional.pyi 中找到了 elu 的簽名

def elu(input: Tensor, alpha: float = ..., inplace: bool = ...) -> Tensor: ...

我們將以下行新增到 symbolic_opset9.py

def elu(g, input: torch.Value, alpha: torch.Value, inplace: bool = False):
    return g.op("Elu", input, alpha_f=alpha)

現在 PyTorch 能夠匯出包含 aten::elu 運算元的模型了!

請參閱 torch/onnx/symbolic_opset*.py 檔案以獲取更多示例。

torch.autograd.Functions

如果運算元是 torch.autograd.Function 的子類,則有三種方法可以匯出它。

靜態符號方法

您可以向您的函式類新增一個名為 symbolic 的靜態方法。它應該返回代表該函式在 ONNX 中行為的 ONNX 運算元。例如

class MyRelu(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def symbolic(g: torch.Graph, input: torch.Value) -> torch.Value:
        return g.op("Clip", input, g.op("Constant", value_t=torch.tensor(0, dtype=torch.float)))

內聯自動微分函式

在其後續 torch.autograd.Function 沒有提供靜態符號方法,或沒有提供將 prim::PythonOp 註冊為自定義符號函式的功能時,torch.onnx.export() 嘗試內聯與該 torch.autograd.Function 對應的圖,從而將該函式分解為其內部使用的各個運算元。只要這些單個運算元得到支援,匯出就應該成功。例如

class MyLogExp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(input)
        h = input.exp()
        return h.log().log()

該模型沒有提供靜態符號方法,但它匯出如下

graph(%input : Float(1, strides=[1], requires_grad=0, device=cpu)):
    %1 : float = onnx::Exp[](%input)
    %2 : float = onnx::Log[](%1)
    %3 : float = onnx::Log[](%2)
    return (%3)

如果您需要避免內聯 torch.autograd.Function,則應將 operator_export_type 設定為 ONNX_FALLTHROUGHONNX_ATEN_FALLBACK 來匯出模型。

自定義運算元

您可以使用自定義運算元匯出模型,這些運算元可以包含許多標準 ONNX 運算元的組合,或者由自定義的 C++ 後端驅動。

ONNX-script 函式

如果一個運算元不是標準的 ONNX 運算元,但可以由多個現有 ONNX 運算元組合而成,則可以利用 ONNX-script 來建立外部 ONNX 函式以支援該運算元。您可以按照此示例匯出它

import onnxscript
# There are three opset version needed to be aligned
# This is (1) the opset version in ONNX function
from onnxscript.onnx_opset import opset15 as op
opset_version = 15

x = torch.randn(1, 2, 3, 4, requires_grad=True)
model = torch.nn.SELU()

custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1)

@onnxscript.script(custom_opset)
def Selu(X):
    alpha = 1.67326  # auto wrapped as Constants
    gamma = 1.0507
    alphaX = op.CastLike(alpha, X)
    gammaX = op.CastLike(gamma, X)
    neg = gammaX * (alphaX * op.Exp(X) - alphaX)
    pos = gammaX * X
    zero = op.CastLike(0, X)
    return op.Where(X <= zero, neg, pos)

# setType API provides shape/type to ONNX shape/type inference
def custom_selu(g: jit_utils.GraphContext, X):
    return g.onnxscript_op(Selu, X).setType(X.type())

# Register custom symbolic function
# There are three opset version needed to be aligned
# This is (2) the opset version in registry
torch.onnx.register_custom_op_symbolic(
    symbolic_name="aten::selu",
    symbolic_fn=custom_selu,
    opset_version=opset_version,
)

# There are three opset version needed to be aligned
# This is (2) the opset version in exporter
torch.onnx.export(
    model,
    x,
    "model.onnx",
    opset_version=opset_version,
    # only needed if you want to specify an opset version > 1.
    custom_opsets={"onnx-script": 2}
)

上面的示例將其作為“onnx-script”運算元集中的自定義運算元匯出。匯出自定義運算元時,可以使用匯出時的 custom_opsets 字典指定自定義域版本。如果未指定,自定義運算元集版本預設為 1。

注意:請務必對齊上述示例中提到的 opset 版本,並確保它們在匯出步驟中被使用。關於如何編寫 onnx-script 函式的示例用法是 onnx-script 活躍開發中的一個 beta 版本。請遵循最新的 ONNX-script

C++ 運算元

如果模型使用了在 使用自定義 C++ 運算元擴充套件 TorchScript 中描述的自定義 C++ 運算元,您可以按照此示例匯出它

from torch.onnx import symbolic_helper


# Define custom symbolic function
@symbolic_helper.parse_args("v", "v", "f", "i")
def symbolic_foo_forward(g, input1, input2, attr1, attr2):
    return g.op("custom_domain::Foo", input1, input2, attr1_f=attr1, attr2_i=attr2)


# Register custom symbolic function
torch.onnx.register_custom_op_symbolic("custom_ops::foo_forward", symbolic_foo_forward, 9)


class FooModel(torch.nn.Module):
    def __init__(self, attr1, attr2):
        super().__init__()
        self.attr1 = attr1
        self.attr2 = attr2

    def forward(self, input1, input2):
        # Calling custom op
        return torch.ops.custom_ops.foo_forward(input1, input2, self.attr1, self.attr2)


model = FooModel(attr1, attr2)
torch.onnx.export(
    model,
    (example_input1, example_input1),
    "model.onnx",
    # only needed if you want to specify an opset version > 1.
    custom_opsets={"custom_domain": 2}
)

上面的示例將其作為“custom_domain”運算元集中的自定義運算元匯出。匯出自定義運算元時,可以使用匯出時的 custom_opsets 字典指定自定義域版本。如果未指定,自定義運算元集版本預設為 1。

使用該模型的執行時需要支援自定義運算元。請參閱 Caffe2 自定義運算元ONNX Runtime 自定義運算元,或您選擇的執行時的文件。

一次性發現所有不可轉換的 ATen 運算元

當匯出因不可轉換的 ATen 運算元而失敗時,實際上可能不止一個此類運算元,但錯誤訊息只提到了第一個。要一次性發現所有不可轉換的運算元,您可以

# prepare model, args, opset_version
...

torch_script_graph, unconvertible_ops = torch.onnx.utils.unconvertible_ops(
    model, args, opset_version=opset_version
)

print(set(unconvertible_ops))

該集合是近似的,因為某些運算元可能在轉換過程中被移除,無需轉換。其他一些運算元可能只有部分支援,在特定輸入下會轉換失敗,但這應該能讓您大致瞭解哪些運算元不受支援。如需支援運算元,請隨時在 GitHub 上提交 Issue。

常見問題

問:我匯出了我的 LSTM 模型,但其輸入大小似乎是固定的?

追蹤器記錄了示例輸入的形狀。如果模型應該接受動態形狀的輸入,請在呼叫 torch.onnx.export() 時設定 dynamic_axes

問:如何匯出包含迴圈的模型?

問:如何匯出包含原始型別輸入(例如 int, float)的模型?

PyTorch 1.9 中添加了對原始數值型別輸入的支援。但是,匯出器不支援包含 str 輸入的模型。

問:ONNX 是否支援隱式標量資料型別轉換?

ONNX 標準本身不支援,但匯出器會嘗試處理這部分。標量將作為常量張量匯出。匯出器會為標量確定正確的資料型別。在少數無法確定資料型別的情況下,您需要手動指定,例如使用 dtype=torch.float32。如果看到任何錯誤,請建立 GitHub Issue

問:Tensor 列表可以匯出為 ONNX 嗎?

是的,對於 opset_version >= 11,因為 ONNX 在 opset 11 中引入了 Sequence 型別。

Python API

函式

torch.onnx.export(model, args=(), f=None, *, kwargs=None, export_params=True, verbose=None, input_names=None, output_names=None, opset_version=None, dynamic_axes=None, keep_initializers_as_inputs=False, dynamo=False, external_data=True, dynamic_shapes=None, custom_translation_table=None, report=False, optimize=True, verify=False, profile=False, dump_exported_program=False, artifacts_dir='.', fallback=False, training=<TrainingMode.EVAL: 0>, operator_export_type=<OperatorExportTypes.ONNX: 0>, do_constant_folding=True, custom_opsets=None, export_modules_as_functions=False, autograd_inlining=True)[source][source]

將模型匯出為 ONNX 格式。

引數
  • model (torch.nn.Module | torch.export.ExportedProgram | torch.jit.ScriptModule | torch.jit.ScriptFunction) – 要匯出的模型。

  • args (tuple[Any, ...]) – 示例位置輸入。任何非 Tensor 引數都將硬編碼到匯出的模型中;任何 Tensor 引數都將成為匯出模型的輸入,按照它們在元組中出現的順序排列。

  • f (str | os.PathLike | None) – 輸出 ONNX 模型檔案的路徑。例如,“model.onnx”。

  • kwargs (dict[str, Any] | None) – 可選的示例關鍵字輸入。

  • export_params (bool) – 如果為 false,引數(權重)將不會被匯出。

  • verbose (bool | None) – 是否啟用詳細日誌記錄。

  • input_names (Sequence[str] | None) – 要按順序分配給圖中輸入節點的名稱。

  • output_names (Sequence[str] | None) – 要按順序分配給圖中輸出節點的名稱。

  • opset_version (int | None) – 要面向的預設 (ai.onnx) opset 版本。必須 >= 7。

  • dynamic_axes (Mapping[str, Mapping[int, str]] | Mapping[str, Sequence[int]] | None) –

    預設情況下,匯出模型的輸入和輸出張量的形狀將與 args 中給出的形狀完全匹配。要將張量的軸指定為動態的(即僅在執行時已知),請將 dynamic_axes 設定為具有以下模式的字典:

    • 鍵 (str):輸入或輸出名稱。每個名稱也必須在 input_names

      output_names 中提供。.

    • 值 (dict 或 list):如果是字典,鍵是軸索引,值是軸名稱。如果是

      列表,每個元素都是一個軸索引。

    例如

    class SumModule(torch.nn.Module):
        def forward(self, x):
            return torch.sum(x, dim=1)
    
    
    torch.onnx.export(
        SumModule(),
        (torch.ones(2, 2),),
        "onnx.pb",
        input_names=["x"],
        output_names=["sum"],
    )
    

    生成

    input {
      name: "x"
      ...
          shape {
            dim {
              dim_value: 2  # axis 0
            }
            dim {
              dim_value: 2  # axis 1
    ...
    output {
      name: "sum"
      ...
          shape {
            dim {
              dim_value: 2  # axis 0
    ...
    

    同時

    torch.onnx.export(
        SumModule(),
        (torch.ones(2, 2),),
        "onnx.pb",
        input_names=["x"],
        output_names=["sum"],
        dynamic_axes={
            # dict value: manually named axes
            "x": {0: "my_custom_axis_name"},
            # list value: automatic names
            "sum": [0],
        },
    )
    

    生成

    input {
      name: "x"
      ...
          shape {
            dim {
              dim_param: "my_custom_axis_name"  # axis 0
            }
            dim {
              dim_value: 2  # axis 1
    ...
    output {
      name: "sum"
      ...
          shape {
            dim {
              dim_param: "sum_dynamic_axes_1"  # axis 0
    ...
    

  • keep_initializers_as_inputs (bool) –

    如果為 True,匯出圖中的所有初始化器(通常對應於模型權重)也將作為輸入新增到圖中。如果為 False,則初始化器不會作為輸入新增到圖中,只有使用者輸入會作為輸入新增。

    如果您打算在執行時提供模型權重,請將其設定為 True。如果權重是靜態的,請將其設定為 False,以便後端/執行時可以進行更好的最佳化(例如常量摺疊)。

  • dynamo (bool) – 是否使用 torch.export ExportedProgram 匯出模型,而不是使用 TorchScript。

  • external_data (bool) – 是否將模型權重儲存為外部資料檔案。對於權重過大超出 ONNX 檔案大小限制(2GB)的模型,這是必需的。當為 False 時,權重將與模型架構一起儲存在 ONNX 檔案中。

  • dynamic_shapes (dict[str, Any] | tuple[Any, ...] | list[Any] | None) – 模型輸入的動態形狀字典或元組。有關更多詳細資訊,請參閱 torch.export.export()。僅當 dynamo 為 True 時使用(並且優先)。請注意,dynamic_shapes 設計用於在 dynamo=True 時匯出模型,而 dynamic_axes 用於在 dynamo=False 時。

  • custom_translation_table (dict[Callable, Callable | Sequence[Callable]] | None) – 模型中運算元的自定義分解字典。字典應將 fx 節點中的可呼叫目標作為鍵(例如 torch.ops.aten.stft.default),值應是一個使用 ONNX Script 構建該圖的函式。此選項僅當 dynamo 為 True 時有效。

  • report (bool) – 是否為匯出過程生成 Markdown 報告。此選項僅當 dynamo 為 True 時有效。

  • optimize (bool) – 是否最佳化匯出的模型。此選項僅當 dynamo 為 True 時有效。預設為 True。

  • verify (bool) – 是否使用 ONNX Runtime 驗證匯出的模型。此選項僅當 dynamo 為 True 時有效。

  • profile (bool) – 是否對匯出過程進行效能分析。此選項僅當 dynamo 為 True 時有效。

  • dump_exported_program (bool) – 是否將 torch.export.ExportedProgram 匯出到檔案。這對於除錯匯出器很有用。此選項僅當 dynamo 為 True 時有效。

  • artifacts_dir (str | os.PathLike) – 儲存除錯工件(如報告和序列化匯出的程式)的目錄。此選項僅當 dynamo 為 True 時有效。

  • fallback (bool) – 如果 dynamo 匯出器失敗,是否回退到 TorchScript 匯出器。此選項僅當 dynamo 為 True 時有效。啟用回退時,即使提供了 dynamic_shapes,也建議設定 dynamic_axes。

  • training (_C_onnx.TrainingMode) – 已棄用選項。請改為在匯出模型之前設定模型的訓練模式。

  • operator_export_type (_C_onnx.OperatorExportTypes) – 已棄用選項。僅支援 ONNX。

  • do_constant_folding (bool) – 已棄用選項。

  • custom_opsets (Mapping[str, int] | None) –

    已棄用。一個字典

    • 鍵 (str):opset 域名稱

    • 值 (int):opset 版本

    如果 model 引用了自定義 opset 但未在此字典中提及,則 opset 版本將設定為 1。僅應透過此引數指定自定義 opset 域名稱和版本。

  • export_modules_as_functions (bool | Collection[type[torch.nn.Module]]) –

    已棄用選項。

    標誌,用於啟用將所有 nn.Module 的 forward 呼叫匯出為 ONNX 中的區域性函式。或者是一個集合,用於指示要匯出為 ONNX 中的區域性函式的特定模組型別。此功能需要 opset_version >= 15,否則匯出將失敗。這是因為 opset_version < 15 表示 IR 版本 < 8,這意味著不支援區域性函式。模組變數將作為函式屬性匯出。函式屬性有兩種類別。

    1. 註釋屬性:透過PEP 526 風格進行型別註釋的類變數將作為屬性匯出。註釋屬性不用於 ONNX 區域性函式的子圖中,因為它們不是由 PyTorch JIT tracing 建立的,但消費者可以使用它們來確定是否用特定的融合核替換該函式。

    2. 推斷屬性:模組內運算元使用的變數。屬性名稱將帶有字首 “inferred::”。這與從 python 模組註釋中檢索的預定義屬性區分開來。推斷屬性用於 ONNX 區域性函式的子圖中。

    • False (預設):將 nn.Module 的 forward 呼叫匯出為細粒度節點。

    • True:將所有 nn.Module 的 forward 呼叫匯出為區域性函式節點。

    • nn.Module 型別集合:將 nn.Module 的 forward 呼叫匯出為區域性函式節點,

      僅當 nn.Module 的型別在此集合中時。

  • autograd_inlining (bool) – 已棄用。標誌,用於控制是否內聯 autograd 函式。有關更多詳細資訊,請參閱 https://github.com/pytorch/pytorch/pull/74765

返回值

如果 dynamo 為 True,則返回 torch.onnx.ONNXProgram,否則返回 None。

返回型別

ONNXProgram | None

版本 2.6 更改:training 已棄用。請改為在匯出模型之前設定模型的訓練模式。operator_export_type 已棄用。僅支援 ONNX。do_constant_folding 已棄用。它始終啟用。export_modules_as_functions 已棄用。autograd_inlining 已棄用。

版本 2.7 更改:optimize 現在預設為 True。

torch.onnx.register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version)[source][source]

為自定義運算元註冊一個符號函式。

當用戶為 custom/contrib 運算元註冊符號時,強烈建議透過 setType API 為該運算元新增形狀推斷,否則在某些極端情況下匯出的圖可能具有不正確的形狀推斷。setType 的一個示例是 test_operators.py 中的 test_aten_embedding_2

有關用法示例,請參閱模組文件中的“自定義運算元”。

引數
  • symbolic_name (str) – 自定義運算元的名稱,格式為“<域>::<運算元>”。

  • symbolic_fn (Callable) – 一個函式,接受 ONNX 圖和當前運算元的輸入引數,並返回要新增到圖中的新運算元節點。

  • opset_version (int) – 要註冊的 ONNX opset 版本。

torch.onnx.unregister_custom_op_symbolic(symbolic_name, opset_version)[source][source]

登出 symbolic_name

有關用法示例,請參閱模組文件中的“自定義運算元”。

引數
  • symbolic_name (str) – 自定義運算元的名稱,格式為“<域>::<運算元>”。

  • opset_version (int) – 要登出的 ONNX opset 版本。

torch.onnx.select_model_mode_for_export(model, mode)[source][source]

一個上下文管理器,用於暫時將 model 的訓練模式設定為 mode,並在退出 with 程式碼塊時將其重置。

自版本 2.7 起棄用:請在匯出模型之前設定訓練模式。

引數
  • model – 型別和含義與 export()model 引數相同。

  • mode (TrainingMode) – 型別和含義與 export()training 引數相同。

torch.onnx.is_in_onnx_export()[source][source]

返回當前是否正在進行 ONNX 匯出。

返回型別

bool

JitScalarType

在 torch 中定義的標量型別。

文件

檢視 PyTorch 的完整開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並解答您的問題

檢視資源