ResNeXt101


模型描述
ResNeXt101-32x4d 是在《Aggregated Residual Transformations for Deep Neural Networks》(用于深度神经网络的聚合残差变换)论文中引入的模型。
它基于常规的 ResNet 模型,将瓶颈模块(bottleneck block)内部的 3×3 卷积替换为 3×3 分组卷积。
该模型在 Volta、Turing 和 NVIDIA Ampere GPU 架构上使用 Tensor Core 进行混合精度训练。因此,研究人员可以获得比不使用 Tensor Core 快 3 倍的训练结果,同时体验到混合精度训练带来的优势。该模型会针对每个 NGC 月度容器版本进行测试,以确保其准确性和性能随时间推移保持一致。
在使用混合精度进行训练时,我们使用 NHWC 数据布局。
请注意,ResNeXt101-32x4d 模型可以使用 TorchScript、ONNX Runtime 或 TensorRT 作为执行后端,部署在 NVIDIA Triton Inference Server 上进行推理。详情请查看 NGC。
模型架构

图片来源:Aggregated Residual Transformations for Deep Neural Networks
图片展示了 ResNet 瓶颈模块与 ResNeXt 瓶颈模块之间的区别。
ResNeXt101-32x4d 模型的基数(cardinality)等于 32,瓶颈宽度(bottleneck width)等于 4。
示例
在下面的示例中,我们将使用预训练的 ResNeXt101-32x4d 模型对图像执行推理并展示结果。
要运行此示例,您需要安装一些额外的 Python 包。这些包用于图像预处理和可视化。
!pip install validators matplotlib
import torch
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import json
import requests
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
%matplotlib inline
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'Using {device} for inference')
加载在 ImageNet 数据集上预训练的模型。
resneXt = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_resneXt')
utils = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_convnets_processing_utils')
resneXt.eval().to(device)
准备示例输入数据。
uris = [
'http://images.cocodataset.org/test-stuff2017/000000024309.jpg',
'http://images.cocodataset.org/test-stuff2017/000000028117.jpg',
'http://images.cocodataset.org/test-stuff2017/000000006149.jpg',
'http://images.cocodataset.org/test-stuff2017/000000004954.jpg',
]
batch = torch.cat(
[utils.prepare_input_from_uri(uri) for uri in uris]
).to(device)
运行推理。使用 pick_n_best(predictions=output, n=topN) 辅助函数来根据模型选取 N 个最可能的假设。
with torch.no_grad():
output = torch.nn.functional.softmax(resneXt(batch), dim=1)
results = utils.pick_n_best(predictions=output, n=5)
显示结果。
for uri, result in zip(uris, results):
img = Image.open(requests.get(uri, stream=True).raw)
img.thumbnail((256,256), Image.ANTIALIAS)
plt.imshow(img)
plt.show()
print(result)
详情
有关模型输入和输出、训练配方、推理和性能的详细信息,请访问:GitHub 和/或 NGC
参考文献