ResNeXt101


模型描述
ResNeXt101-32x4d 是在 Aggregated Residual Transformations for Deep Neural Networks 論文中介紹的模型。
它基於常規 ResNet 模型,將瓶頸塊內的 3×3 卷積替換為 3×3 分組卷積。
該模型使用 Tensor Cores 在 Volta、Turing 和 NVIDIA Ampere GPU 架構上以混合精度進行訓練。因此,研究人員可以獲得比不使用 Tensor Cores 訓練快 3 倍的結果,同時享受混合精度訓練的優勢。該模型針對每個 NGC 每月容器釋出進行測試,以確保隨著時間的推移保持一致的準確性和效能。
在使用混合精度進行訓練時,我們使用 NHWC 資料佈局。
請注意,ResNeXt101-32x4d 模型可以使用 TorchScript、ONNX Runtime 或 TensorRT 作為執行後端,在 NVIDIA Triton 推理伺服器上進行推理部署。有關詳細資訊,請檢視 NGC
模型架構

圖片來源:Aggregated Residual Transformations for Deep Neural Networks
圖片顯示了 ResNet 瓶頸塊和 ResNeXt 瓶頸塊之間的差異。
ResNeXt101-32x4d 模型的基數等於 32,瓶頸寬度等於 4。
示例
在下面的示例中,我們將使用預訓練的 ResNeXt101-32x4d 模型對影像執行推理並呈現結果。
要執行此示例,您需要安裝一些額外的 Python 包。這些包用於影像預處理和視覺化。
!pip install validators matplotlib
import torch
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import json
import requests
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
%matplotlib inline
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'Using {device} for inference')
載入在 ImageNet 資料集上預訓練的模型。
resneXt = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_resneXt')
utils = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_convnets_processing_utils')
resneXt.eval().to(device)
準備示例輸入資料。
uris = [
'http://images.cocodataset.org/test-stuff2017/000000024309.jpg',
'http://images.cocodataset.org/test-stuff2017/000000028117.jpg',
'http://images.cocodataset.org/test-stuff2017/000000006149.jpg',
'http://images.cocodataset.org/test-stuff2017/000000004954.jpg',
]
batch = torch.cat(
[utils.prepare_input_from_uri(uri) for uri in uris]
).to(device)
執行推理。使用輔助函式 pick_n_best(predictions=output, n=topN) 根據模型選擇 N 個最可能的假設。
with torch.no_grad():
output = torch.nn.functional.softmax(resneXt(batch), dim=1)
results = utils.pick_n_best(predictions=output, n=5)
顯示結果。
for uri, result in zip(uris, results):
img = Image.open(requests.get(uri, stream=True).raw)
img.thumbnail((256,256), Image.ANTIALIAS)
plt.imshow(img)
plt.show()
print(result)
詳情
有關模型輸入和輸出、訓練方法、推理和效能的詳細資訊,請訪問:github 和/或 NGC
參考文獻