torch.onnx¶
概覽¶
Open Neural Network eXchange (ONNX) 是一種用於表示機器學習模型的開放標準格式。 torch.onnx 模組從原生的 PyTorch torch.nn.Module 模型中捕獲計算圖,並將其轉換為 ONNX 圖。
匯出的模型可以被許多支援 ONNX 的執行時所消費,包括 Microsoft 的 ONNX Runtime。
您可以使用如下所列的兩種 ONNX 匯出器 API。兩者都可以透過函式 torch.onnx.export() 呼叫。下一個示例展示瞭如何匯出簡單模型。
import torch
class MyModel(torch.nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = torch.nn.Conv2d(1, 128, 5)
def forward(self, x):
return torch.relu(self.conv1(x))
input_tensor = torch.rand((1, 1, 128, 128), dtype=torch.float32)
model = MyModel()
torch.onnx.export(
model, # model to export
(input_tensor,), # inputs of the model,
"my_model.onnx", # filename of the ONNX model
input_names=["input"], # Rename inputs for the ONNX model
dynamo=True # True or False to select the exporter to use
)
接下來的章節介紹了匯出器的兩個版本。
基於 TorchDynamo 的 ONNX 匯出器¶
基於 TorchDynamo 的 ONNX 匯出器是適用於 PyTorch 2.1 及更新版本的最新的(Beta)匯出器
利用 TorchDynamo 引擎掛接到 Python 的幀評估 API,並將其位元組碼動態重寫為 FX 圖。生成的 FX 圖經過精修後,最終被轉換為 ONNX 圖。
這種方法的主要優勢在於,FX 圖是使用位元組碼分析捕獲的,這保留了模型的動態特性,而不是使用傳統的靜態跟蹤技術。
基於 TorchScript 的 ONNX 匯出器¶
基於 TorchScript 的 ONNX 匯出器自 PyTorch 1.2.0 版本起可用
利用 TorchScript 跟蹤模型(透過 torch.jit.trace())並捕獲靜態計算圖。
因此,生成的圖存在一些限制
它不記錄任何控制流,例如 if 語句或迴圈;
不處理訓練模式和評估模式之間的細微差別;
不能真正處理動態輸入
為了嘗試支援靜態跟蹤的限制,匯出器還支援 TorchScript 指令碼化(透過 torch.jit.script()),例如,這增加了對資料依賴的控制流的支援。然而,TorchScript 本身是 Python 語言的一個子集,因此並非所有 Python 特性都受支援,例如原地操作。