• 教程 >
  • 最佳化用於部署的 Vision Transformer 模型
快捷方式

最佳化用於部署的 Vision Transformer 模型

創建於: 2021 年 3 月 15 日 | 最後更新於: 2024 年 1 月 19 日 | 最後驗證於: 2024 年 11 月 5 日

Jeff Tang, Geeta Chauhan

Vision Transformer 模型將最前沿的基於注意力的 Transformer 模型(在自然語言處理領域引入並取得了各種最先進 (SOTA) 的結果)應用於計算機視覺任務。Facebook 資料高效影像 Transformer 模型 DeiT 是一個在 ImageNet 上訓練用於影像分類的 Vision Transformer 模型。

在本教程中,我們將首先介紹 DeiT 是什麼以及如何使用它,然後逐步完成指令碼化、量化、最佳化以及在 iOS 和 Android 應用中使用模型的完整步驟。我們還將比較量化最佳化模型與非量化非最佳化模型的效能,並展示在這些步驟中應用量化和最佳化對模型帶來的好處。

什麼是 DeiT

自 2012 年深度學習興起以來,卷積神經網路 (CNN) 一直是影像分類的主要模型,但 CNN 通常需要數億張影像進行訓練才能達到 SOTA 結果。DeiT 是一種 Vision Transformer 模型,在訓練中需要更少的資料和計算資源,就能在執行影像分類任務時與領先的 CNN 競爭。這得益於 DeiT 的兩個關鍵組成部分:

  • 資料增強,模擬在更大規模資料集上進行訓練;

  • 原生蒸餾,允許 Transformer 網路從 CNN 的輸出中學習。

DeiT 表明 Transformer 可以成功應用於計算機視覺任務,即使資料和資源有限。有關 DeiT 的更多詳細資訊,請參閱其程式碼庫論文

使用 DeiT 進行影像分類

請遵循 DeiT 程式碼庫中的 README.md 檔案,獲取有關如何使用 DeiT 進行影像分類的詳細資訊,或者為了快速測試,首先安裝所需的軟體包:

pip install torch torchvision timm pandas requests

要在 Google Colab 中執行,請執行以下命令安裝依賴項:

!pip install timm pandas requests

然後執行以下指令碼:

from PIL import Image
import torch
import timm
import requests
import torchvision.transforms as transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

print(torch.__version__)
# should be 1.8.0


model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()

transform = transforms.Compose([
    transforms.Resize(256, interpolation=3),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])

img = Image.open(requests.get("https://raw.githubusercontent.com/pytorch/ios-demo-app/master/HelloWorld/HelloWorld/HelloWorld/image.png", stream=True).raw)
img = transform(img)[None,]
out = model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
2.7.0+cu126
Downloading: "https://github.com/facebookresearch/deit/zipball/main" to /var/lib/ci-user/.cache/torch/hub/main.zip
/usr/local/lib/python3.10/dist-packages/timm/models/registry.py:4: FutureWarning:

Importing from timm.models.registry is deprecated, please import via timm.models

/usr/local/lib/python3.10/dist-packages/timm/models/layers/__init__.py:48: FutureWarning:

Importing from timm.models.layers is deprecated, please import via timm.layers

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:63: UserWarning:

Overwriting deit_tiny_patch16_224 in registry with models.deit_tiny_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:78: UserWarning:

Overwriting deit_small_patch16_224 in registry with models.deit_small_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:93: UserWarning:

Overwriting deit_base_patch16_224 in registry with models.deit_base_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:108: UserWarning:

Overwriting deit_tiny_distilled_patch16_224 in registry with models.deit_tiny_distilled_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:123: UserWarning:

Overwriting deit_small_distilled_patch16_224 in registry with models.deit_small_distilled_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:138: UserWarning:

Overwriting deit_base_distilled_patch16_224 in registry with models.deit_base_distilled_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:153: UserWarning:

Overwriting deit_base_patch16_384 in registry with models.deit_base_patch16_384. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:168: UserWarning:

Overwriting deit_base_distilled_patch16_384 in registry with models.deit_base_distilled_patch16_384. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

Downloading: "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth" to /var/lib/ci-user/.cache/torch/hub/checkpoints/deit_base_patch16_224-b5f2ef4d.pth

  0%|          | 0.00/330M [00:00<?, ?B/s]
  5%|5         | 18.1M/330M [00:00<00:01, 189MB/s]
 13%|#3        | 44.4M/330M [00:00<00:01, 240MB/s]
 24%|##3       | 77.9M/330M [00:00<00:00, 290MB/s]
 32%|###1      | 106M/330M [00:00<00:00, 286MB/s]
 41%|####      | 135M/330M [00:00<00:00, 291MB/s]
 52%|#####1    | 171M/330M [00:00<00:00, 320MB/s]
 61%|######1   | 202M/330M [00:00<00:00, 321MB/s]
 71%|#######1  | 234M/330M [00:00<00:00, 328MB/s]
 81%|########  | 266M/330M [00:00<00:00, 328MB/s]
 91%|######### | 300M/330M [00:01<00:00, 339MB/s]
100%|##########| 330M/330M [00:01<00:00, 318MB/s]
269

輸出應該是 269,根據 ImageNet 類別索引與標籤檔案的對應關係,它對映到 timber wolf, grey wolf, gray wolf, Canis lupus

現在我們已經驗證可以使用 DeiT 模型對影像進行分類,接下來看看如何修改模型以便它可以在 iOS 和 Android 應用上執行。

指令碼化 DeiT

要在移動裝置上使用該模型,我們首先需要對模型進行指令碼化。有關快速概覽,請參閱指令碼化和最佳化秘籍。執行以下程式碼,將上一步中使用的 DeiT 模型轉換為可在移動裝置上執行的 TorchScript 格式。

model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()
scripted_model = torch.jit.script(model)
scripted_model.save("fbdeit_scripted.pt")
Using cache found in /var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main

生成了大小約為 346MB 的指令碼化模型檔案 fbdeit_scripted.pt

量化 DeiT

為了在基本保持推理準確性不變的情況下顯著減小訓練模型的尺寸,可以對模型應用量化。由於 DeiT 中使用了 Transformer 模型,我們可以輕鬆地對模型應用動態量化,因為動態量化對 LSTM 和 Transformer 模型的效果最好(詳情請參閱此處)。

現在執行以下程式碼:

# Use 'x86' for server inference (the old 'fbgemm' is still available but 'x86' is the recommended default) and ``qnnpack`` for mobile inference.
backend = "x86" # replaced with ``qnnpack`` causing much worse inference speed for quantized model on this notebook
model.qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend

quantized_model = torch.quantization.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
scripted_quantized_model = torch.jit.script(quantized_model)
scripted_quantized_model.save("fbdeit_scripted_quantized.pt")
/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/ao/quantization/observer.py:244: UserWarning:

Please use quant_min and quant_max to specify the range for observers.                     reduce_range will be deprecated in a future release of PyTorch.

這生成了模型的指令碼化和量化版本 fbdeit_quantized_scripted.pt,大小約為 89MB,相較於非量化模型的 346MB,減小了 74%!

你可以使用 scripted_quantized_model 生成相同的推理結果

out = scripted_quantized_model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
# The same output 269 should be printed
269

最佳化 DeiT

在移動裝置上使用量化和指令碼化模型之前的最後一步是進行最佳化

from torch.utils.mobile_optimizer import optimize_for_mobile
optimized_scripted_quantized_model = optimize_for_mobile(scripted_quantized_model)
optimized_scripted_quantized_model.save("fbdeit_optimized_scripted_quantized.pt")

生成的 fbdeit_optimized_scripted_quantized.pt 檔案的大小與量化、指令碼化但未最佳化的模型大致相同。推理結果保持不變。

out = optimized_scripted_quantized_model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
# Again, the same output 269 should be printed
269

使用 Lite Interpreter

為了瞭解 Lite Interpreter 能帶來多少模型尺寸減小和推理速度提升,讓我們建立模型的精簡(lite)版本。

optimized_scripted_quantized_model._save_for_lite_interpreter("fbdeit_optimized_scripted_quantized_lite.ptl")
ptl = torch.jit.load("fbdeit_optimized_scripted_quantized_lite.ptl")

儘管精簡模型的尺寸與非精簡版本相當,但在移動裝置上執行精簡版本時,預期會有推理速度提升。

比較推理速度

為了瞭解四種模型(原始模型、指令碼化模型、量化指令碼化模型、最佳化量化指令碼化模型)的推理速度差異,請執行以下程式碼:

with torch.autograd.profiler.profile(use_cuda=False) as prof1:
    out = model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof2:
    out = scripted_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof3:
    out = scripted_quantized_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof4:
    out = optimized_scripted_quantized_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof5:
    out = ptl(img)

print("original model: {:.2f}ms".format(prof1.self_cpu_time_total/1000))
print("scripted model: {:.2f}ms".format(prof2.self_cpu_time_total/1000))
print("scripted & quantized model: {:.2f}ms".format(prof3.self_cpu_time_total/1000))
print("scripted & quantized & optimized model: {:.2f}ms".format(prof4.self_cpu_time_total/1000))
print("lite model: {:.2f}ms".format(prof5.self_cpu_time_total/1000))
original model: 121.17ms
scripted model: 128.88ms
scripted & quantized model: 94.96ms
scripted & quantized & optimized model: 136.20ms
lite model: 112.27ms

在 Google Colab 上執行的結果如下:

original model: 1236.69ms
scripted model: 1226.72ms
scripted & quantized model: 593.19ms
scripted & quantized & optimized model: 598.01ms
lite model: 600.72ms

以下結果總結了每種模型所需的推理時間以及每種模型相對於原始模型的百分比減小。

import pandas as pd
import numpy as np

df = pd.DataFrame({'Model': ['original model','scripted model', 'scripted & quantized model', 'scripted & quantized & optimized model', 'lite model']})
df = pd.concat([df, pd.DataFrame([
    ["{:.2f}ms".format(prof1.self_cpu_time_total/1000), "0%"],
    ["{:.2f}ms".format(prof2.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof2.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
    ["{:.2f}ms".format(prof3.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof3.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
    ["{:.2f}ms".format(prof4.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof4.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
    ["{:.2f}ms".format(prof5.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof5.self_cpu_time_total)/prof1.self_cpu_time_total*100)]],
    columns=['Inference Time', 'Reduction'])], axis=1)

print(df)

"""
        Model                             Inference Time    Reduction
0   original model                             1236.69ms           0%
1   scripted model                             1226.72ms        0.81%
2   scripted & quantized model                  593.19ms       52.03%
3   scripted & quantized & optimized model      598.01ms       51.64%
4   lite model                                  600.72ms       51.43%
"""
                                    Model  ... Reduction
0                          original model  ...        0%
1                          scripted model  ...    -6.36%
2              scripted & quantized model  ...    21.63%
3  scripted & quantized & optimized model  ...   -12.41%
4                              lite model  ...     7.34%

[5 rows x 3 columns]

'\n        Model                             Inference Time    Reduction\n0\toriginal model                             1236.69ms           0%\n1\tscripted model                             1226.72ms        0.81%\n2\tscripted & quantized model                  593.19ms       52.03%\n3\tscripted & quantized & optimized model      598.01ms       51.64%\n4\tlite model                                  600.72ms       51.43%\n'

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得解答

檢視資源