快捷方式

Torchscript 支援

注意

Colab 上嘗試或前往末尾下載完整示例程式碼。

本示例演示了 torchvision transforms 對 Tensor 影像的 torchscript 支援。

from pathlib import Path

import matplotlib.pyplot as plt

import torch
import torch.nn as nn

import torchvision.transforms as v1
from torchvision.io import decode_image

plt.rcParams["savefig.bbox"] = 'tight'
torch.manual_seed(1)

# If you're trying to run that on Colab, you can download the assets and the
# helpers from https://github.com/pytorch/vision/tree/main/gallery/
import sys
sys.path += ["../transforms"]
from helpers import plot
ASSETS_PATH = Path('../assets')

大多數 transforms 支援 torchscript。對於組合 transforms,我們使用 torch.nn.Sequential 而不是 Compose

dog1 = decode_image(str(ASSETS_PATH / 'dog1.jpg'))
dog2 = decode_image(str(ASSETS_PATH / 'dog2.jpg'))

transforms = torch.nn.Sequential(
    v1.RandomCrop(224),
    v1.RandomHorizontalFlip(p=0.3),
)

scripted_transforms = torch.jit.script(transforms)

plot([dog1, scripted_transforms(dog1), dog2, scripted_transforms(dog2)])
plot scripted tensor transforms

警告

上面我們使用了 torchvision.transforms 名稱空間中的 transforms,即“v1” transforms。torchvision.transforms.v2 名稱空間中的 v2 transforms 是在程式碼中使用 transforms 的推薦方式。

v2 transforms 也支援 torchscript,但如果您在 v2 **類** transform 上呼叫 torch.jit.script(),實際上會得到其(已指令碼化的)v1 等效項。由於 v1 和 v2 之間的實現差異,這可能導致指令碼化執行和 eager 執行之間產生略微不同的結果。

如果您確實需要 v2 transforms 的 torchscript 支援,**我們建議對 torchvision.transforms.v2.functional 名稱空間中的 functionals 進行指令碼化**,以避免意外情況。

下面,我們將演示如何結合影像 transformations 和模型前向傳播,同時使用 torch.jit.script 來獲得一個單獨的指令碼化模組。

我們來定義一個 Predictor 模組,它對輸入 tensor 進行變換,然後在其上應用一個 ImageNet 模型。

from torchvision.models import resnet18, ResNet18_Weights


class Predictor(nn.Module):

    def __init__(self):
        super().__init__()
        weights = ResNet18_Weights.DEFAULT
        self.resnet18 = resnet18(weights=weights, progress=False).eval()
        self.transforms = weights.transforms(antialias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            x = self.transforms(x)
            y_pred = self.resnet18(x)
            return y_pred.argmax(dim=1)

現在,讓我們定義 Predictor 的指令碼化和非指令碼化例項,並將其應用於多個相同大小的 tensor 影像。

device = "cuda" if torch.cuda.is_available() else "cpu"

predictor = Predictor().to(device)
scripted_predictor = torch.jit.script(predictor).to(device)

batch = torch.stack([dog1, dog2]).to(device)

res = predictor(batch)
res_scripted = scripted_predictor(batch)
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth

我們可以驗證指令碼化模型和非指令碼化模型的預測結果是一致的。

import json

with open(Path('../assets') / 'imagenet_class_index.json') as labels_file:
    labels = json.load(labels_file)

for i, (pred, pred_scripted) in enumerate(zip(res, res_scripted)):
    assert pred == pred_scripted
    print(f"Prediction for Dog {i + 1}: {labels[str(pred.item())]}")
Prediction for Dog 1: ['n02113023', 'Pembroke']
Prediction for Dog 2: ['n02106662', 'German_shepherd']

由於模型已指令碼化,可以輕鬆地將其儲存到磁碟並重復使用。

import tempfile

with tempfile.NamedTemporaryFile() as f:
    scripted_predictor.save(f.name)

    dumped_scripted_predictor = torch.jit.load(f.name)
    res_scripted_dumped = dumped_scripted_predictor(batch)
assert (res_scripted_dumped == res_scripted).all()

指令碼總執行時間: (0 分 1.453 秒)

由 Sphinx-Gallery 生成的圖集

文件

查閱 PyTorch 的完整開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源