Models and pre-trained weights
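The torchvision.models subpackage contains definitions of models for addressing different tasks, including: image classification, pixelwise semantic segmentation, object detection, instance segmentation, person keypoint detection, video classification, and optical flow.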
General information on pre-trained weights
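TorchVision offers pre-trained weights for every provided architecture, using the PyTorch torch.hub. Instancing a pre-trained model will download its weights to a cache directory. This directory can be set using the TORCH_HOME environment variable. See torch.hub.load_state_dict_from_url() for details.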
Note
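The pre-trained models provided in this library may have their own licenses or terms and conditions derived from the dataset used for training. It is your responsibility to determine whether you have permission to use the models for your use case.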
Note
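Backward compatibility is guaranteed for loading a serialized state_dict into a model created with an older version of PyTorch. By contrast, loading entire saved models or serialized ScriptModules (serialized with older versions of PyTorch) may not preserve the historic behaviour. Refer to the PyTorch serialization documentation for details.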
Initializing pre-trained models
As of v0.13, TorchVision offers a new multi-weight support API for loading different weights into the existing model builder methods:
from torchvision.models import resnet50, ResNet50_Weights
# Old weights with accuracy 76.130%
resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
# New weights with accuracy 80.858%
resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
# Best available weights (currently alias for IMAGENET1K_V2)
# Note that these weights may change across versions
resnet50(weights=ResNet50_Weights.DEFAULT)
# Strings are also supported
resnet50(weights="IMAGENET1K_V2")
# No weights - random initialization
resnet50(weights=None)
Migrating to the new API is very straightforward. The following method calls are equivalent between the two APIs:
from torchvision.models import resnet50, ResNet50_Weights
# Using pretrained weights:
resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
resnet50(weights="IMAGENET1K_V1")
resnet50(pretrained=True) # deprecated
resnet50(True) # deprecated
# Using no weights:
resnet50(weights=None)
resnet50()
resnet50(pretrained=False) # deprecated
resnet50(False) # deprecated
Note that the pretrained parameter is now deprecated; using it will emit warnings, and it will be removed in v0.15.
Using the pre-trained models
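Before using the pre-trained models, one must preprocess the image (resize with the right resolution/interpolation, apply inference transforms, rescale the values, etc.). There is no standard way to do this, as it depends on how a given model was trained. It can vary across model families, variants, or even weight versions. Using the correct preprocessing method is critical; failing to do so may lead to decreased accuracy or incorrect outputs.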
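All the necessary information for the inference transforms of each pre-trained model is provided in its weights documentation. To simplify inference, TorchVision bundles the necessary preprocessing transforms into each model weight. These are accessible via the weight.transforms attribute: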
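from torchvision.models import ResNet50_Weights
# img: an image tensor, e.g. the uint8 [C, H, W] tensor returned by torchvision.io.decode_image
# Initialize the Weight Transforms
weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()
# Apply it to the input image
img_transformed = preprocess(img)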
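Some models use modules which have different training and evaluation behavior, such as batch normalization. To switch between these modes, use model.train() or model.eval() as appropriate. See train() or eval() for details.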
from torchvision.models import resnet50, ResNet50_Weights
# Initialize model
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
# Set model to eval mode
model.eval()
Listing and retrieving available models
As of v0.14, TorchVision offers a new mechanism which allows listing and retrieving models and weights by their names. Here are a few examples of how to use them:
import torchvision
from torchvision.models import get_model, get_model_weights, get_weight, list_models
from torchvision.models.quantization import MobileNet_V3_Large_QuantizedWeights
# List available models
all_models = list_models()
classification_models = list_models(module=torchvision.models)
# Initialize models
m1 = get_model("mobilenet_v3_large", weights=None)
m2 = get_model("quantized_mobilenet_v3_large", weights="DEFAULT")
# Fetch weights
weights = get_weight("MobileNet_V3_Large_QuantizedWeights.DEFAULT")
assert weights == MobileNet_V3_Large_QuantizedWeights.DEFAULT
weights_enum = get_model_weights("quantized_mobilenet_v3_large")
assert weights_enum == MobileNet_V3_Large_QuantizedWeights
weights_enum2 = get_model_weights(torchvision.models.quantization.mobilenet_v3_large)
assert weights_enum == weights_enum2
The following public functions are available to retrieve models and their corresponding weights:
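Gets the model name and configuration and returns an instantiated model.
Returns the weights enum class associated with the given model.
Gets the weights enum value by its full name.
Returns a list with the names of registered models.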
Using models from Hub
Most pre-trained models can be accessed directly via PyTorch Hub without having TorchVision installed:
import torch
# Option 1: passing weights param as string
model = torch.hub.load("pytorch/vision", "resnet50", weights="IMAGENET1K_V2")
# Option 2: passing weights param as enum
weights = torch.hub.load("pytorch/vision", "get_weight", name="ResNet50_Weights.IMAGENET1K_V2")
model = torch.hub.load("pytorch/vision", "resnet50", weights=weights)
You can also retrieve all the available weights of a specific model via PyTorch Hub by doing:
import torch
weight_enum = torch.hub.load("pytorch/vision", "get_model_weights", name="resnet50")
print([weight for weight in weight_enum])
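The only exception to the above are the detection models included in torchvision.models.detection. These models require TorchVision to be installed because they depend on custom C++ operators.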
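Classification
The following classification models are available, with or without pre-trained weights:
Here is an example of how to use the pre-trained image classification models: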
from torchvision.io import decode_image
from torchvision.models import resnet50, ResNet50_Weights
img = decode_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
# Step 1: Initialize model with the best available weights
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
model.eval()
# Step 2: Initialize the inference transforms
preprocess = weights.transforms()
# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)
# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score:.1f}%")
The classes of the pre-trained model outputs can be found at weights.meta["categories"].
Table of all available classification weights
Accuracies are reported on ImageNet-1K using single crops:
| Weight | Acc@1 | Acc@5 | Params | GFLOPS | Recipe |
|---|---|---|---|---|---|
| | 56.522 | 79.066 | 61.1M | 0.71 | |
| | 84.062 | 96.87 | 88.6M | 15.36 | |
| | 84.414 | 96.976 | 197.8M | 34.36 | |
| | 83.616 | 96.65 | 50.2M | 8.68 | |
| | 82.52 | 96.146 | 28.6M | 4.46 | |
| | 74.434 | 91.972 | 8.0M | 2.83 | |
| | 77.138 | 93.56 | 28.7M | 7.73 | |
| | 75.6 | 92.806 | 14.1M | 3.36 | |
| | 76.896 | 93.37 | 20.0M | 4.29 | |
| | 77.692 | 93.532 | 5.3M | 0.39 | |
| | 78.642 | 94.186 | 7.8M | 0.69 | |
| | 79.838 | 94.934 | 7.8M | 0.69 | |
| | 80.608 | 95.31 | 9.1M | 1.09 | |
| | 82.008 | 96.054 | 12.2M | 1.83 | |
| | 83.384 | 96.594 | 19.3M | 4.39 | |
| | 83.444 | 96.628 | 30.4M | 10.27 | |
| | 84.008 | 96.916 | 43.0M | 19.07 | |
| | 84.122 | 96.908 | 66.3M | 37.75 | |
| | 85.808 | 97.788 | 118.5M | 56.08 | |
| | 85.112 | 97.156 | 54.1M | 24.58 | |
| | 84.228 | 96.878 | 21.5M | 8.37 | |
| | 69.778 | 89.53 | 6.6M | 1.5 | |
| | 77.294 | 93.45 | 27.2M | 5.71 | |
| | 67.734 | 87.49 | 2.2M | 0.1 | |
| | 71.18 | 90.496 | 3.2M | 0.21 | |
| | 73.456 | 91.51 | 4.4M | 0.31 | |
| | 76.506 | 93.522 | 6.3M | 0.53 | |
| | 83.7 | 96.722 | 30.9M | 5.56 | |
| | 71.878 | 90.286 | 3.5M | 0.3 | |
| | 72.154 | 90.822 | 3.5M | 0.3 | |
| | 74.042 | 91.34 | 5.5M | 0.22 | |
| | 75.274 | 92.566 | 5.5M | 0.22 | |
| | 67.668 | 87.402 | 2.5M | 0.06 | |
| | 80.058 | 94.944 | 54.3M | 15.94 | |
| | 82.716 | 96.196 | 54.3M | 15.94 | |
| | 77.04 | 93.44 | 9.2M | 1.6 | |
| | 79.668 | 94.922 | 9.2M | 1.6 | |
| | 80.622 | 95.248 | 107.8M | 31.74 | |
| | 83.014 | 96.288 | 107.8M | 31.74 | |
| | 78.364 | 93.992 | 15.3M | 3.18 | |
| | 81.196 | 95.43 | 15.3M | 3.18 | |
| | 72.834 | 90.95 | 5.5M | 0.41 | |
| | 74.864 | 92.322 | 5.5M | 0.41 | |
| | 75.212 | 92.348 | 7.3M | 0.8 | |
| | 77.522 | 93.826 | 7.3M | 0.8 | |
| | 79.344 | 94.686 | 39.6M | 8 | |
| | 81.682 | 95.678 | 39.6M | 8 | |
| | 88.228 | 98.682 | 644.8M | 374.57 | |
| | 86.068 | 97.844 | 644.8M | 127.52 | |
| | 80.424 | 95.24 | 83.6M | 15.91 | |
| | 82.886 | 96.328 | 83.6M | 15.91 | |
| | 86.012 | 98.054 | 83.6M | 46.73 | |
| | 83.976 | 97.244 | 83.6M | 15.91 | |
| | 77.95 | 93.966 | 11.2M | 1.61 | |
| | 80.876 | 95.444 | 11.2M | 1.61 | |
| | 80.878 | 95.34 | 145.0M | 32.28 | |
| | 83.368 | 96.498 | 145.0M | 32.28 | |
| | 86.838 | 98.362 | 145.0M | 94.83 | |
| | 84.622 | 97.48 | 145.0M | 32.28 | |
| | 78.948 | 94.576 | 19.4M | 3.18 | |
| | 81.982 | 95.972 | 19.4M | 3.18 | |
| | 74.046 | 91.716 | 4.3M | 0.4 | |
| | 75.804 | 92.742 | 4.3M | 0.4 | |
| | 76.42 | 93.136 | 6.4M | 0.83 | |
| | 78.828 | 94.502 | 6.4M | 0.83 | |
| | 80.032 | 95.048 | 39.4M | 8.47 | |
| | 82.828 | 96.33 | 39.4M | 8.47 | |
| | 79.312 | 94.526 | 88.8M | 16.41 | |
| | 82.834 | 96.228 | 88.8M | 16.41 | |
| | 83.246 | 96.454 | 83.5M | 15.46 | |
| | 77.618 | 93.698 | 25.0M | 4.23 | |
| | 81.198 | 95.34 | 25.0M | 4.23 | |
| | 77.374 | 93.546 | 44.5M | 7.8 | |
| | 81.886 | 95.78 | 44.5M | 7.8 | |
| | 78.312 | 94.046 | 60.2M | 11.51 | |
| | 82.284 | 96.002 | 60.2M | 11.51 | |
| | 69.758 | 89.078 | 11.7M | 1.81 | |
| | 73.314 | 91.42 | 21.8M | 3.66 | |
| | 76.13 | 92.862 | 25.6M | 4.09 | |
| | 80.858 | 95.434 | 25.6M | 4.09 | |
| | 60.552 | 81.746 | 1.4M | 0.04 | |
| | 69.362 | 88.316 | 2.3M | 0.14 | |
| | 72.996 | 91.086 | 3.5M | 0.3 | |
| | 76.23 | 93.006 | 7.4M | 0.58 | |
| | 58.092 | 80.42 | 1.2M | 0.82 | |
| | 58.178 | 80.624 | 1.2M | 0.35 | |
| | 83.582 | 96.64 | 87.8M | 15.43 | |
| | 83.196 | 96.36 | 49.6M | 8.74 | |
| | 81.474 | 95.776 | 28.3M | 4.49 | |
| | 84.112 | 96.864 | 87.9M | 20.32 | |
| | 83.712 | 96.816 | 49.7M | 11.55 | |
| | 82.072 | 96.132 | 28.4M | 5.94 | |
| | 70.37 | 89.81 | 132.9M | 7.61 | |
| | 69.02 | 88.628 | 132.9M | 7.61 | |
| | 71.586 | 90.374 | 133.1M | 11.31 | |
| | 69.928 | 89.246 | 133.0M | 11.31 | |
| | 73.36 | 91.516 | 138.4M | 15.47 | |
| | 71.592 | 90.382 | 138.4M | 15.47 | |
| | nan | nan | 138.4M | 15.47 | |
| | 74.218 | 91.842 | 143.7M | 19.63 | |
| | 72.376 | 90.876 | 143.7M | 19.63 | |
| | 81.072 | 95.318 | 86.6M | 17.56 | |
| | 85.304 | 97.65 | 86.9M | 55.48 | |
| | 81.886 | 96.18 | 86.6M | 17.56 | |
| | 75.912 | 92.466 | 88.2M | 4.41 | |
| | 88.552 | 98.694 | 633.5M | 1016.72 | |
| | 85.708 | 97.73 | 632.0M | 167.29 | |
| | 79.662 | 94.638 | 304.3M | 61.55 | |
| | 88.064 | 98.512 | 305.2M | 361.99 | |
| | 85.146 | 97.422 | 304.3M | 61.55 | |
| | 76.972 | 93.07 | 306.5M | 15.38 | |
| | 78.848 | 94.284 | 126.9M | 22.75 | |
| | 82.51 | 96.02 | 126.9M | 22.75 | |
| | 78.468 | 94.086 | 68.9M | 11.4 | |
| | 81.602 | 95.758 | 68.9M | 11.4 | |
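Quantized models
The following architectures provide support for INT8 quantized models, with or without pre-trained weights:
Here is an example of how to use the pre-trained quantized image classification models: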
from torchvision.io import decode_image
from torchvision.models.quantization import resnet50, ResNet50_QuantizedWeights
img = decode_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
# Step 1: Initialize model with the best available weights
weights = ResNet50_QuantizedWeights.DEFAULT
model = resnet50(weights=weights, quantize=True)
model.eval()
# Step 2: Initialize the inference transforms
preprocess = weights.transforms()
# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)
# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score}%")
The classes of the pre-trained model outputs can be found at weights.meta["categories"].
Table of all available quantized classification weights
Accuracies are reported on ImageNet-1K using single crops:
| Weight | Acc@1 | Acc@5 | Params | GIPS | Recipe |
|---|---|---|---|---|---|
| | 69.826 | 89.404 | 6.6M | 1.5 | |
| | 77.176 | 93.354 | 27.2M | 5.71 | |
| | 71.658 | 90.15 | 3.5M | 0.3 | |
| | 73.004 | 90.858 | 5.5M | 0.22 | |
| | 78.986 | 94.48 | 88.8M | 16.41 | |
| | 82.574 | 96.132 | 88.8M | 16.41 | |
| | 82.898 | 96.326 | 83.5M | 15.46 | |
| | 69.494 | 88.882 | 11.7M | 1.81 | |
| | 75.92 | 92.814 | 25.6M | 4.09 | |
| | 80.282 | 94.976 | 25.6M | 4.09 | |
| | 57.972 | 79.78 | 1.4M | 0.04 | |
| | 68.36 | 87.582 | 2.3M | 0.14 | |
| | 72.052 | 90.7 | 3.5M | 0.3 | |
| | 75.354 | 92.488 | 7.4M | 0.58 | |
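Semantic Segmentation
Warning
The segmentation module is in Beta stage, and backward compatibility is not guaranteed.
The following semantic segmentation models are available, with or without pre-trained weights:
Here is an example of how to use the pre-trained semantic segmentation models: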
from torchvision.io.image import decode_image
from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights
from torchvision.transforms.functional import to_pil_image
img = decode_image("gallery/assets/dog1.jpg")
# Step 1: Initialize model with the best available weights
weights = FCN_ResNet50_Weights.DEFAULT
model = fcn_resnet50(weights=weights)
model.eval()
# Step 2: Initialize the inference transforms
preprocess = weights.transforms()
# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)
# Step 4: Use the model and visualize the prediction
prediction = model(batch)["out"]
normalized_masks = prediction.softmax(dim=1)
class_to_idx = {cls: idx for (idx, cls) in enumerate(weights.meta["categories"])}
mask = normalized_masks[0, class_to_idx["dog"]]
to_pil_image(mask).show()
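The classes of the pre-trained model outputs can be found at weights.meta["categories"]. The output format of the models is illustrated in Semantic segmentation models.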
Table of all available semantic segmentation weights
All models are evaluated on a subset of COCO val2017, on the 20 categories that are present in the Pascal VOC dataset:
| Weight | Mean IoU | pixelwise Acc | Params | GFLOPS | Recipe |
|---|---|---|---|---|---|
| | 60.3 | 91.2 | 11.0M | 10.45 | |
| | 67.4 | 92.4 | 61.0M | 258.74 | |
| | 66.4 | 92.4 | 42.0M | 178.72 | |
| | 63.7 | 91.9 | 54.3M | 232.74 | |
| | 60.5 | 91.4 | 35.3M | 152.72 | |
| | 57.9 | 91.2 | 3.2M | 2.09 | |
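Object Detection, Instance Segmentation and Person Keypoint Detection
The pre-trained models for detection, instance segmentation and keypoint detection are initialized with the classification models in torchvision. The models expect a list of Tensor[C, H, W]. Check the constructor of the models for more information.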
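Warning
The detection module is in Beta stage, and backward compatibility is not guaranteed.
Object Detection
The following object detection models are available, with or without pre-trained weights:
Here is an example of how to use the pre-trained object detection models: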
from torchvision.io.image import decode_image
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image
img = decode_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
# Step 1: Initialize model with the best available weights
weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9)
model.eval()
# Step 2: Initialize the inference transforms
preprocess = weights.transforms()
# Step 3: Apply inference preprocessing transforms
batch = [preprocess(img)]
# Step 4: Use the model and visualize the prediction
prediction = model(batch)[0]
labels = [weights.meta["categories"][i] for i in prediction["labels"]]
box = draw_bounding_boxes(img, boxes=prediction["boxes"],
labels=labels,
colors="red",
width=4, font_size=30)
im = to_pil_image(box.detach())
im.show()
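The classes of the pre-trained model outputs can be found at weights.meta["categories"]. For details on how to plot the bounding boxes of the models, you may refer to Instance segmentation models.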
Table of all available object detection weights
Box MAPs are reported on COCO val2017:
| Weight | Box MAP | Params | GFLOPS | Recipe |
|---|---|---|---|---|
| | 39.2 | 32.3M | 128.21 | |
| | 22.8 | 19.4M | 0.72 | |
| | 32.8 | 19.4M | 4.49 | |
| | 46.7 | 43.7M | 280.37 | |
| | 37 | 41.8M | 134.38 | |
| | 41.5 | 38.2M | 152.24 | |
| | 36.4 | 34.0M | 151.54 | |
| | 25.1 | 35.6M | 34.86 | |
| | 21.3 | 3.4M | 0.58 | |
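Instance Segmentation
The following instance segmentation models are available, with or without pre-trained weights:
For details on how to plot the masks of the models, you may refer to Instance segmentation models.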
Table of all available instance segmentation weights
Box and Mask MAPs are reported on COCO val2017:
| Weight | Box MAP | Mask MAP | Params | GFLOPS | Recipe |
|---|---|---|---|---|---|
| | 47.4 | 41.8 | 46.4M | 333.58 | |
| | 37.9 | 34.6 | 44.4M | 134.38 | |
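Keypoint Detection
The following person keypoint detection models are available, with or without pre-trained weights:
The names of the keypoints predicted by the pre-trained models can be found at weights.meta["keypoint_names"]. For details on how to plot the keypoints predicted by the models, you may refer to Visualizing keypoints.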
Table of all available keypoint detection weights
Box and Keypoint MAPs are reported on COCO val2017:
| Weight | Box MAP | Keypoint MAP | Params | GFLOPS | Recipe |
|---|---|---|---|---|---|
| | 50.6 | 61.1 | 59.1M | 133.92 | |
| | 54.6 | 65 | 59.1M | 137.42 | |
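Video Classification
Warning
The video module is in Beta stage, and backward compatibility is not guaranteed.
The following video classification models are available, with or without pre-trained weights:
Here is an example of how to use the pre-trained video classification models: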
from torchvision.io.video import read_video
from torchvision.models.video import r3d_18, R3D_18_Weights
vid, _, _ = read_video("test/assets/videos/v_SoccerJuggling_g23_c01.avi", output_format="TCHW")
vid = vid[:32] # optionally shorten duration
# Step 1: Initialize model with the best available weights
weights = R3D_18_Weights.DEFAULT
model = r3d_18(weights=weights)
model.eval()
# Step 2: Initialize the inference transforms
preprocess = weights.transforms()
# Step 3: Apply inference preprocessing transforms
batch = preprocess(vid).unsqueeze(0)
# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
label = prediction.argmax().item()
score = prediction[label].item()
category_name = weights.meta["categories"][label]
print(f"{category_name}: {100 * score}%")
The classes of the pre-trained model outputs can be found at weights.meta["categories"].
Table of all available video classification weights
Accuracies are reported on Kinetics-400 using single crops for clip length 16:
| Weight | Acc@1 | Acc@5 | Params | GFLOPS | Recipe |
|---|---|---|---|---|---|
| | 63.96 | 84.13 | 11.7M | 43.34 | |
| | 78.477 | 93.582 | 36.6M | 70.6 | |
| | 80.757 | 94.665 | 34.5M | 64.22 | |
| | 67.463 | 86.175 | 31.5M | 40.52 | |
| | 63.2 | 83.479 | 33.4M | 40.7 | |
| | 68.368 | 88.05 | 8.3M | 17.98 | |
| | 79.427 | 94.386 | 88.0M | 140.67 | |
| | 81.643 | 95.574 | 88.0M | 140.67 | |
| | 79.521 | 94.158 | 49.8M | 82.84 | |
| | 77.715 | 93.519 | 28.2M | 43.88 | |
Optical Flow
The following optical flow models are available, with or without pre-trained