快捷方式

理解基於 TorchDynamo 的 ONNX 匯出器記憶體使用

之前的基於 TorchScript 的 ONNX 匯出器會執行一次模型來跟蹤其執行過程,如果模型的記憶體需求超過可用的 GPU 記憶體,可能會導致 GPU 記憶體不足。這個問​​題已透過新的基於 TorchDynamo 的 ONNX 匯出器得到解決。

基於 TorchDynamo 的 ONNX 匯出器利用 torch.export.export() 函式來利用 FakeTensorMode,從而避免在匯出過程中執行實際的張量計算。與基於 TorchScript 的 ONNX 匯出器相比,這種方法顯著降低了記憶體使用量。

下面是一個示例,展示了基於 TorchScript 和基於 TorchDynamo 的 ONNX 匯出器之間的記憶體使用差異。在此示例中,我們使用了 MONAI 中的 HighResNet 模型。在繼續之前,請從 PyPI 安裝它

pip install monai

PyTorch 提供了一個用於捕獲和視覺化記憶體使用跟蹤的工具。我們將使用此工具記錄兩種匯出器在匯出過程中的記憶體使用情況並比較結果。您可以在理解 CUDA 記憶體使用上找到有關此工具的更多詳細資訊。

基於 TorchScript 的匯出器

可以執行以下程式碼生成一個快照檔案,該檔案記錄匯出過程中已分配 CUDA 記憶體的狀態。

import torch

from monai.networks.nets import (
    HighResNet,
)

torch.cuda.memory._record_memory_history()

model = HighResNet(
    spatial_dims=3, in_channels=1, out_channels=3, norm_type="batch"
).eval()

model = model.to("cuda")
data = torch.randn(30, 1, 48, 48, 48, dtype=torch.float32).to("cuda")

with torch.no_grad():
    onnx_program = torch.onnx.export(
        model,
        data,
        "torchscript_exporter_highresnet.onnx",
        dynamo=False,
    )

snapshot_name = "torchscript_exporter_example.pickle"
print(f"generate {snapshot_name}")

torch.cuda.memory._dump_snapshot(snapshot_name)
print("Export is done.")

開啟 pytorch.org/memory_viz 並將生成的 pickled 快照檔案拖放到視覺化工具中。記憶體使用情況如下所示

_images/torch_script_exporter_memory_usage.png

透過此圖,我們可以看到記憶體使用峰值高於 2.8GB。

基於 TorchDynamo 的匯出器

可以執行以下程式碼生成一個快照檔案,該檔案記錄匯出過程中已分配 CUDA 記憶體的狀態。

import torch

from monai.networks.nets import (
    HighResNet,
)

torch.cuda.memory._record_memory_history()

model = HighResNet(
    spatial_dims=3, in_channels=1, out_channels=3, norm_type="batch"
).eval()

model = model.to("cuda")
data = torch.randn(30, 1, 48, 48, 48, dtype=torch.float32).to("cuda")

with torch.no_grad():
    onnx_program = torch.onnx.export(
                        model,
                        data,
                        "test_faketensor.onnx",
                        dynamo=True,
                    )

snapshot_name = f"torchdynamo_exporter_example.pickle"
print(f"generate {snapshot_name}")

torch.cuda.memory._dump_snapshot(snapshot_name)
print(f"Export is done.")

開啟 pytorch.org/memory_viz 並將生成的 pickled 快照檔案拖放到視覺化工具中。記憶體使用情況如下所示

_images/torch_dynamo_exporter_memory_usage.png

透過此圖,我們可以看到記憶體使用峰值僅為約 45MB。與基於 TorchScript 的匯出器記憶體使用峰值相比,它降低了 98% 的記憶體使用量。

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源