TorchVision 推出了一款新的向後相容 API,用於構建支援多權重的模型。新 API 允許在同一模型變體上載入不同的預訓練權重,跟蹤重要的元資料(如分類標籤),幷包含使用模型所需的預處理轉換。在這篇博文中,我們計劃回顧原型 API,展示其功能,並強調與現有 API 的主要區別。

我們希望在最終確定 API 之前聽取您的想法。為了收集您的反饋,我們建立了一個 Github Issue,您可以在其中釋出您的想法、問題和評論。
當前 API 的侷限性
TorchVision 目前提供預訓練模型,這些模型可以作為遷移學習的起點,或者直接用於計算機視覺應用。例項化預訓練模型並進行預測的典型方法是
import torch
from PIL import Image
from torchvision import models as M
from torchvision.transforms import transforms as T
img = Image.open("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
# Step 1: Initialize model
model = M.resnet50(pretrained=True)
model.eval()
# Step 2: Define and initialize the inference transforms
preprocess = T.Compose([
T.Resize([256, ]),
T.CenterCrop(224),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)
prediction = model(batch).squeeze(0).softmax(0)
# Step 4: Use the model and print the predicted category
class_id = prediction.argmax().item()
score = prediction[class_id].item()
with open("imagenet_classes.txt", "r") as f:
categories = [s.strip() for s in f.readlines()]
category_name = categories[class_id]
print(f"{category_name}: {100 * score}%")
上述方法存在一些侷限性
- 無法支援多個預訓練權重:由於
pretrained變數是布林值,我們只能提供一組權重。當我們顯著 提高現有模型的準確性 並希望將這些改進提供給社群時,這帶來了嚴重的侷限性。它也阻止我們提供同一模型變體在不同資料集上的預訓練權重。 - 缺少推理/預處理轉換:使用者在使用模型之前被迫定義必要的轉換。推理轉換通常與訓練過程和用於估計權重的日期集相關聯。這些轉換中的任何微小差異(例如插值值、調整大小/裁剪大小等)都可能導致準確性大幅下降或模型無法使用。
- 缺少元資料:與權重相關的關鍵資訊對使用者不可用。例如,需要查閱外部來源和文件才能找到 類別標籤、訓練配方、準確性指標等。
新 API 解決了上述侷限性,並減少了標準任務所需的樣板程式碼。
原型 API 概述
讓我們看看如何使用新 API 獲得與上述完全相同的結果
from PIL import Image
from torchvision.prototype import models as PM
img = Image.open("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
# Step 1: Initialize model
weights = PM.ResNet50_Weights.IMAGENET1K_V1
model = PM.resnet50(weights=weights)
model.eval()
# Step 2: Initialize the inference transforms
preprocess = weights.transforms()
# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)
prediction = model(batch).squeeze(0).softmax(0)
# Step 4: Use the model and print the predicted category
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score}*%*")
正如我們所看到的,新 API 消除了上述侷限性。讓我們詳細探討新功能。
多權重支援
新 API 的核心是能夠為同一模型變體定義多個不同的權重。每個模型構建方法(例如 resnet50)都有一個相關的列舉類(例如 ResNet50_Weights),其中包含與可用預訓練權重數量一樣多的條目。此外,每個列舉類都有一個 DEFAULT 別名,指向特定模型可用的最佳權重。這使得希望始終使用最佳可用權重的使用者無需修改程式碼即可實現。
這是一個使用不同權重初始化模型的示例
from torchvision.prototype.models import resnet50, ResNet50_Weights
# Legacy weights with accuracy 76.130%
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
# New weights with accuracy 80.858%
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
# Best available weights (currently alias for IMAGENET1K_V2)
model = resnet50(weights=ResNet50_Weights.DEFAULT)
# No weights - random initialization
model = resnet50(weights=None)
關聯的元資料和預處理轉換
每個模型的權重都與元資料相關聯。我們儲存的資訊型別取決於模型的任務(分類、檢測、分割等)。典型資訊包括訓練配方的連結、插值模式、類別和驗證指標等資訊。這些值可以透過 meta 屬性以程式設計方式訪問
from torchvision.prototype.models import ResNet50_Weights
# Accessing a single record
size = ResNet50_Weights.IMAGENET1K_V2.meta["size"]
# Iterating the items of the meta-data dictionary
for k, v in ResNet50_Weights.IMAGENET1K_V2.meta.items():
print(k, v)
此外,每個權重條目都與必要的預處理轉換相關聯。所有當前的預處理轉換都是 JIT 可指令碼化的,可以透過 transforms 屬性訪問。在使用它們處理資料之前,需要初始化/構造轉換。這種惰性初始化方案旨在確保解決方案的記憶體效率。轉換的輸入可以是 PIL.Image,也可以是使用 torchvision.io 讀取的 Tensor。
from torchvision.prototype.models import ResNet50_Weights
# Initializing preprocessing at standard 224x224 resolution
preprocess = ResNet50_Weights.IMAGENET1K_V2.transforms()
# Initializing preprocessing at 400x400 resolution
preprocess = ResNet50_Weights.IMAGENET1K_V2.transforms(crop_size=400, resize_size=400)
# Once initialized the callable can accept the image data:
# img_preprocessed = preprocess(img)
將權重與其元資料和預處理相關聯將提高透明度,改善可復現性,並使其更容易記錄一組權重是如何生成的。
按名稱獲取權重
能夠直接將權重與其屬性(元資料、預處理可呼叫物件等)關聯起來是我們的實現使用列舉而不是字串的原因。然而,對於僅知道權重名稱的情況,我們提供了一種能夠將權重名稱連結到其列舉的方法
from torchvision.prototype.models import get_weight
# Weights can be retrieved by name:
assert get_weight("ResNet50_Weights.IMAGENET1K_V1") == ResNet50_Weights.IMAGENET1K_V1
assert get_weight("ResNet50_Weights.IMAGENET1K_V2") == ResNet50_Weights.IMAGENET1K_V2
# Including using the DEFAULT alias:
assert get_weight("ResNet50_Weights.DEFAULT") == ResNet50_Weights.IMAGENET1K_V2
棄用
在新 API 中,以前用於將權重載入到完整模型或其骨幹的布林引數 pretrained 和 pretrained_backbone 已被棄用。當前的實現完全向後相容,因為它無縫地將舊引數對映到新引數。在新的構建器中使用舊引數會發出以下棄用警告
>>> model = torchvision.prototype.models.resnet50(pretrained=True)
UserWarning: The parameter 'pretrained' is deprecated, please use 'weights' instead.
UserWarning:
Arguments other than a weight enum or `None` for 'weights' are deprecated.
The current behavior is equivalent to passing `weights=ResNet50_Weights.IMAGENET1K_V1`.
You can also use `weights=ResNet50_Weights.DEFAULT` to get the most up-to-date weights.
此外,構建器方法要求使用關鍵字引數。位置引數的使用已被棄用,使用它們會發出以下警告
>>> model = torchvision.prototype.models.resnet50(None)
UserWarning:
Using 'weights' as positional parameter(s) is deprecated.
Please use keyword parameter(s) instead.
測試新 API
遷移到新 API 非常簡單。以下兩種 API 之間的方法呼叫是等效的
# Using pretrained weights:
torchvision.prototype.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
torchvision.models.resnet50(pretrained=True)
torchvision.models.resnet50(True)
# Using no weights:
torchvision.prototype.models.resnet50(weights=None)
torchvision.models.resnet50(pretrained=False)
torchvision.models.resnet50(False)
請注意,原型功能僅在 TorchVision 的夜間版本中可用,因此要使用它,您需要按如下方式安裝它
conda install torchvision -c pytorch-nightly
有關安裝夜間版本的替代方法,請參閱 PyTorch 下載頁面。您還可以從最新的主版本原始碼安裝 TorchVision;有關更多資訊,請參閱我們的 倉庫。
使用新 API 訪問最先進的模型權重
如果您仍然不相信嘗試新的 API,這裡有另一個理由。我們最近重新整理了 訓練配方,並從我們的許多模型中獲得了 SOTA 準確性。改進的權重可以透過新 API 輕鬆訪問。以下是模型改進的快速概述

| 模型 | 舊 Acc@1 | 新 Acc@1 |
|---|---|---|
| EfficientNet B1 | 78.642 | 79.838 |
| MobileNetV3 Large | 74.042 | 75.274 |
| 量化 ResNet50 | 75.92 | 80.282 |
| 量化 ResNeXt101 32x8d | 78.986 | 82.574 |
| RegNet X 400mf | 72.834 | 74.864 |
| RegNet X 800mf | 75.212 | 77.522 |
| RegNet X 1 6gf | 77.04 | 79.668 |
| RegNet X 3 2gf | 78.364 | 81.198 |
| RegNet X 8gf | 79.344 | 81.682 |
| RegNet X 16gf | 80.058 | 82.72 |
| RegNet X 32gf | 80.622 | 83.018 |
| RegNet Y 400mf | 74.046 | 75.806 |
| RegNet Y 800mf | 76.42 | 78.838 |
| RegNet Y 1 6gf | 77.95 | 80.882 |
| RegNet Y 3 2gf | 78.948 | 81.984 |
| RegNet Y 8gf | 80.032 | 82.828 |
| RegNet Y 16gf | 80.424 | 82.89 |
| RegNet Y 32gf | 80.878 | 83.366 |
| ResNet50 | 76.13 | 80.858 |
| ResNet101 | 77.374 | 81.886 |
| ResNet152 | 78.312 | 82.284 |
| ResNeXt50 32x4d | 77.618 | 81.198 |
| ResNeXt101 32x8d | 79.312 | 82.834 |
| Wide ResNet50 2 | 78.468 | 81.602 |
| Wide ResNet101 2 | 78.848 | 82.51 |
請花幾分鐘時間提供您對新 API 的反饋,這對將其從原型階段提升並在下個版本中包含至關重要。您可以在專門的 Github Issue 上進行此操作。我們期待您的評論!