快捷方式

torch.onnx.verification

ONNX 驗證模組提供了一套工具,用於驗證 ONNX 模型的正確性。

torch.onnx.verification.verify_onnx_program(onnx_program, args=None, kwargs=None, compare_intermediates=False)[原始碼]

透過將 ONNX 模型的值與來自 ExportedProgram 的預期值進行比較來驗證 ONNX 模型。

引數
  • onnx_program (_onnx_program.ONNXProgram) – 要驗證的 ONNX program。

  • args (tuple[Any, ...] | None) – 模型的輸入引數。

  • kwargs (dict[str, Any] | None) – 模型的關鍵字引數。

  • compare_intermediates (bool) – 是否驗證中間值。這將花費更長時間,因此預設停用。

返回

包含每個值的驗證資訊的 VerificationInfo 物件。

返回型別

list[VerificationInfo]

class torch.onnx.verification.VerificationInfo(name, max_abs_diff, max_rel_diff, abs_diff_hist, rel_diff_hist, expected_dtype, actual_dtype)

ONNX program 中某個值的驗證資訊。

此類包含最大絕對差值、最大相對差值以及預期值與實際值之間絕對差值和相對差值的直方圖。它還包括預期和實際資料型別。

直方圖表示為張量的元組,其中第一個張量是直方圖計數,第二個張量是 bin 邊界。

變數
  • name (str) – 值的名稱(輸出或中間值)。

  • max_abs_diff (float) – 預期值與實際值之間的最大絕對差值。

  • max_rel_diff (float) – 預期值與實際值之間的最大相對差值。

  • abs_diff_hist (tuple[torch.Tensor, torch.Tensor]) – 表示絕對差值直方圖的張量元組。第一個張量是直方圖計數,第二個張量是 bin 邊界。

  • rel_diff_hist (tuple[torch.Tensor, torch.Tensor]) – 表示相對差值直方圖的張量元組。第一個張量是直方圖計數,第二個張量是 bin 邊界。

  • expected_dtype (torch.dtype) – 預期值的資料型別。

  • actual_dtype (torch.dtype) – 實際值的資料型別。

classmethod from_tensors(name, expected, actual)[原始碼][原始碼]

從兩個張量建立一個 VerificationInfo 物件。

引數
返回

VerificationInfo 物件。

返回型別

VerificationInfo

torch.onnx.verification.verify(model, input_args, input_kwargs=None, do_constant_folding=True, dynamic_axes=None, input_names=None, output_names=None, training=<TrainingMode.EVAL: 0>, opset_version=None, keep_initializers_as_inputs=True, verbose=False, fixed_batch_size=False, use_external_data=False, additional_test_inputs=None, options=None)[原始碼][原始碼]

驗證模型匯出到 ONNX 是否與原始 PyTorch 模型一致。

自版本 2.7 起已棄用: 考慮使用 torch.onnx.export(..., dynamo=True) 並使用返回的 ONNXProgram 來測試 ONNX 模型。

引數
觸發
  • AssertionError – 如果 ONNX 模型和 PyTorch 模型的輸出在指定精度內不相等。

  • ValueError – 如果提供的引數無效。

已棄用

以下類和函式已棄用。

class torch.onnx.verification.check_export_model_diff[原始碼][原始碼]
class torch.onnx.verification.GraphInfo[原始碼][原始碼]
class torch.onnx.verification.GraphInfoPrettyPrinter[原始碼][原始碼]
class torch.onnx.verification.OnnxBackend[原始碼][原始碼]
class torch.onnx.verification.OnnxTestCaseRepro[原始碼][原始碼]
class torch.onnx.verification.VerificationOptions[原始碼][原始碼]
torch.onnx.verification.find_mismatch()[原始碼][原始碼]
torch.onnx.verification.verify_aten_graph()[原始碼][原始碼]

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源