
Models and pre-trained weights

The torchvision.models subpackage contains definitions of models for addressing different tasks, including: image classification, pixelwise semantic segmentation, object detection, instance segmentation, person keypoint detection, video classification, and optical flow.

General information on pre-trained weights

TorchVision offers pre-trained weights for every provided architecture, using the PyTorch torch.hub. Instantiating a pre-trained model will download its weights to a cache directory. This directory can be set using the TORCH_HOME environment variable. See torch.hub.load_state_dict_from_url() for details.
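
For instance, a minimal sketch of redirecting the cache (the path below is a placeholder; by default, downloaded weights end up under $TORCH_HOME/hub/checkpoints):

import os

# Must be set before the first download so the weights land in the custom cache
os.environ["TORCH_HOME"] = "/path/to/my/cache"  # placeholder path

from torchvision.models import resnet50, ResNet50_Weights

model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)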

Note

The pre-trained models provided in this library may have their own licenses or terms and conditions derived from the dataset used for training. It is your responsibility to determine whether you have permission to use the models for your use case.

Note

Backward compatibility is guaranteed for loading a serialized state_dict into a model created using an old PyTorch version. On the contrary, loading entire saved models or serialized ScriptModules (serialized using an older version of PyTorch) may not preserve the historic behaviour. Refer to the following documentation.

Initializing pre-trained models

As of v0.13, TorchVision offers a new Multi-weight support API for loading different weights into the existing model builder methods:

from torchvision.models import resnet50, ResNet50_Weights

# Old weights with accuracy 76.130%
resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

# New weights with accuracy 80.858%
resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

# Best available weights (currently alias for IMAGENET1K_V2)
# Note that these weights may change across versions
resnet50(weights=ResNet50_Weights.DEFAULT)

# Strings are also supported
resnet50(weights="IMAGENET1K_V2")

# No weights - random initialization
resnet50(weights=None)

Migrating to the new API is very straightforward. The following method calls between the two APIs are all equivalent:

from torchvision.models import resnet50, ResNet50_Weights

# Using pretrained weights:
resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
resnet50(weights="IMAGENET1K_V1")
resnet50(pretrained=True)  # deprecated
resnet50(True)  # deprecated

# Using no weights:
resnet50(weights=None)
resnet50()
resnet50(pretrained=False)  # deprecated
resnet50(False)  # deprecated

Note that the pretrained parameter is now deprecated; using it will emit a warning, and it will be removed in v0.15.

Using the pre-trained models

Before using the pre-trained models, one must preprocess the image (resize with the right resolution/interpolation, apply inference transforms, rescale the values, etc.). There is no standard way to do this, as it depends on how a given model was trained. It can vary across model families, variants, or even weight versions. Using the correct preprocessing method is critical, and failing to do so may lead to decreased accuracy or incorrect outputs.

All the necessary information for the inference transforms of each pre-trained model is provided in its weights documentation. To simplify inference, TorchVision bundles the necessary preprocessing transforms into each model weight. These are accessible via the weight.transforms attribute:

# Initialize the Weight Transforms
weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()

# Apply it to the input image
img_transformed = preprocess(img)

Some models use modules which have different training and evaluation behavior, such as batch normalization. To switch between these modes, use model.train() or model.eval() as appropriate. See train() or eval() for details.

# Initialize model
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)

# Set model to eval mode
model.eval()

Listing and retrieving available models

As of v0.14, TorchVision offers a new mechanism which allows listing and retrieving models and weights by their names. Here are a few examples on how to use them:

import torchvision
from torchvision.models import get_model, get_model_weights, get_weight, list_models
from torchvision.models.quantization import MobileNet_V3_Large_QuantizedWeights

# List available models
all_models = list_models()
classification_models = list_models(module=torchvision.models)

# Initialize models
m1 = get_model("mobilenet_v3_large", weights=None)
m2 = get_model("quantized_mobilenet_v3_large", weights="DEFAULT")

# Fetch weights
weights = get_weight("MobileNet_V3_Large_QuantizedWeights.DEFAULT")
assert weights == MobileNet_V3_Large_QuantizedWeights.DEFAULT

weights_enum = get_model_weights("quantized_mobilenet_v3_large")
assert weights_enum == MobileNet_V3_Large_QuantizedWeights

weights_enum2 = get_model_weights(torchvision.models.quantization.mobilenet_v3_large)
assert weights_enum == weights_enum2

Here are the available public functions to retrieve models and their corresponding weights:

get_model(name, **config)

Gets the model name and configuration and returns an instantiated model.

get_model_weights(name)

Returns the weights enum class associated to the given model.

get_weight(name)

Gets the weights enum value by its full name.

list_models([module, include, exclude])

Returns a list with the names of registered models.
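
For instance, a small sketch of how the extra **config keyword arguments of get_model() flow through to the underlying builder, using num_classes (a standard builder parameter):

from torchvision.models import get_model

# Keyword arguments after the name are forwarded to the resnet50() builder
model = get_model("resnet50", weights=None, num_classes=10)
print(model.fc.out_features)  # 10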

Using models from Hub

Most pre-trained models can be accessed directly via PyTorch Hub without having TorchVision installed:

import torch

# Option 1: passing weights param as string
model = torch.hub.load("pytorch/vision", "resnet50", weights="IMAGENET1K_V2")

# Option 2: passing weights param as enum
weights = torch.hub.load("pytorch/vision", "get_weight", weights="ResNet50_Weights.IMAGENET1K_V2")
model = torch.hub.load("pytorch/vision", "resnet50", weights=weights)

You can also retrieve all the available weights of a specific model via PyTorch Hub by doing:

import torch

weight_enum = torch.hub.load("pytorch/vision", "get_model_weights", name="resnet50")
print([weight for weight in weight_enum])

The only exception to the above are the detection models included in torchvision.models.detection. These models require TorchVision to be installed because they depend on custom C++ operators.

Classification

The following classification models are available, with or without pre-trained weights:


Here is an example of how to use a pre-trained image classification model:

from torchvision.io import decode_image
from torchvision.models import resnet50, ResNet50_Weights

img = decode_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

# Step 1: Initialize model with the best available weights
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)

# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score:.1f}%")

The classes of the pre-trained model outputs can be found at weights.meta["categories"].
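
Continuing the example above, a small sketch of reading out the top-5 categories from the same prediction tensor:

import torch

# `prediction` is the softmaxed 1000-way output from the example above
top5 = torch.topk(prediction, 5)
for score, class_id in zip(top5.values.tolist(), top5.indices.tolist()):
    print(f"{weights.meta['categories'][class_id]}: {100 * score:.1f}%")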

Table of all available classification weights

Accuracies are reported on ImageNet-1K using single crops:

| Weight | Acc@1 | Acc@5 | Params | GFLOPS | Recipe |
|---|---|---|---|---|---|
| AlexNet_Weights.IMAGENET1K_V1 | 56.522 | 79.066 | 61.1M | 0.71 | link |
| ConvNeXt_Base_Weights.IMAGENET1K_V1 | 84.062 | 96.87 | 88.6M | 15.36 | link |
| ConvNeXt_Large_Weights.IMAGENET1K_V1 | 84.414 | 96.976 | 197.8M | 34.36 | link |
| ConvNeXt_Small_Weights.IMAGENET1K_V1 | 83.616 | 96.65 | 50.2M | 8.68 | link |
| ConvNeXt_Tiny_Weights.IMAGENET1K_V1 | 82.52 | 96.146 | 28.6M | 4.46 | link |
| DenseNet121_Weights.IMAGENET1K_V1 | 74.434 | 91.972 | 8.0M | 2.83 | link |
| DenseNet161_Weights.IMAGENET1K_V1 | 77.138 | 93.56 | 28.7M | 7.73 | link |
| DenseNet169_Weights.IMAGENET1K_V1 | 75.6 | 92.806 | 14.1M | 3.36 | link |
| DenseNet201_Weights.IMAGENET1K_V1 | 76.896 | 93.37 | 20.0M | 4.29 | link |
| EfficientNet_B0_Weights.IMAGENET1K_V1 | 77.692 | 93.532 | 5.3M | 0.39 | link |
| EfficientNet_B1_Weights.IMAGENET1K_V1 | 78.642 | 94.186 | 7.8M | 0.69 | link |
| EfficientNet_B1_Weights.IMAGENET1K_V2 | 79.838 | 94.934 | 7.8M | 0.69 | link |
| EfficientNet_B2_Weights.IMAGENET1K_V1 | 80.608 | 95.31 | 9.1M | 1.09 | link |
| EfficientNet_B3_Weights.IMAGENET1K_V1 | 82.008 | 96.054 | 12.2M | 1.83 | link |
| EfficientNet_B4_Weights.IMAGENET1K_V1 | 83.384 | 96.594 | 19.3M | 4.39 | link |
| EfficientNet_B5_Weights.IMAGENET1K_V1 | 83.444 | 96.628 | 30.4M | 10.27 | link |
| EfficientNet_B6_Weights.IMAGENET1K_V1 | 84.008 | 96.916 | 43.0M | 19.07 | link |
| EfficientNet_B7_Weights.IMAGENET1K_V1 | 84.122 | 96.908 | 66.3M | 37.75 | link |
| EfficientNet_V2_L_Weights.IMAGENET1K_V1 | 85.808 | 97.788 | 118.5M | 56.08 | link |
| EfficientNet_V2_M_Weights.IMAGENET1K_V1 | 85.112 | 97.156 | 54.1M | 24.58 | link |
| EfficientNet_V2_S_Weights.IMAGENET1K_V1 | 84.228 | 96.878 | 21.5M | 8.37 | link |
| GoogLeNet_Weights.IMAGENET1K_V1 | 69.778 | 89.53 | 6.6M | 1.5 | link |
| Inception_V3_Weights.IMAGENET1K_V1 | 77.294 | 93.45 | 27.2M | 5.71 | link |
| MNASNet0_5_Weights.IMAGENET1K_V1 | 67.734 | 87.49 | 2.2M | 0.1 | link |
| MNASNet0_75_Weights.IMAGENET1K_V1 | 71.18 | 90.496 | 3.2M | 0.21 | link |
| MNASNet1_0_Weights.IMAGENET1K_V1 | 73.456 | 91.51 | 4.4M | 0.31 | link |
| MNASNet1_3_Weights.IMAGENET1K_V1 | 76.506 | 93.522 | 6.3M | 0.53 | link |
| MaxVit_T_Weights.IMAGENET1K_V1 | 83.7 | 96.722 | 30.9M | 5.56 | link |
| MobileNet_V2_Weights.IMAGENET1K_V1 | 71.878 | 90.286 | 3.5M | 0.3 | link |
| MobileNet_V2_Weights.IMAGENET1K_V2 | 72.154 | 90.822 | 3.5M | 0.3 | link |
| MobileNet_V3_Large_Weights.IMAGENET1K_V1 | 74.042 | 91.34 | 5.5M | 0.22 | link |
| MobileNet_V3_Large_Weights.IMAGENET1K_V2 | 75.274 | 92.566 | 5.5M | 0.22 | link |
| MobileNet_V3_Small_Weights.IMAGENET1K_V1 | 67.668 | 87.402 | 2.5M | 0.06 | link |
| RegNet_X_16GF_Weights.IMAGENET1K_V1 | 80.058 | 94.944 | 54.3M | 15.94 | link |
| RegNet_X_16GF_Weights.IMAGENET1K_V2 | 82.716 | 96.196 | 54.3M | 15.94 | link |
| RegNet_X_1_6GF_Weights.IMAGENET1K_V1 | 77.04 | 93.44 | 9.2M | 1.6 | link |
| RegNet_X_1_6GF_Weights.IMAGENET1K_V2 | 79.668 | 94.922 | 9.2M | 1.6 | link |
| RegNet_X_32GF_Weights.IMAGENET1K_V1 | 80.622 | 95.248 | 107.8M | 31.74 | link |
| RegNet_X_32GF_Weights.IMAGENET1K_V2 | 83.014 | 96.288 | 107.8M | 31.74 | link |
| RegNet_X_3_2GF_Weights.IMAGENET1K_V1 | 78.364 | 93.992 | 15.3M | 3.18 | link |
| RegNet_X_3_2GF_Weights.IMAGENET1K_V2 | 81.196 | 95.43 | 15.3M | 3.18 | link |
| RegNet_X_400MF_Weights.IMAGENET1K_V1 | 72.834 | 90.95 | 5.5M | 0.41 | link |
| RegNet_X_400MF_Weights.IMAGENET1K_V2 | 74.864 | 92.322 | 5.5M | 0.41 | link |
| RegNet_X_800MF_Weights.IMAGENET1K_V1 | 75.212 | 92.348 | 7.3M | 0.8 | link |
| RegNet_X_800MF_Weights.IMAGENET1K_V2 | 77.522 | 93.826 | 7.3M | 0.8 | link |
| RegNet_X_8GF_Weights.IMAGENET1K_V1 | 79.344 | 94.686 | 39.6M | 8 | link |
| RegNet_X_8GF_Weights.IMAGENET1K_V2 | 81.682 | 95.678 | 39.6M | 8 | link |
| RegNet_Y_128GF_Weights.IMAGENET1K_SWAG_E2E_V1 | 88.228 | 98.682 | 644.8M | 374.57 | link |
| RegNet_Y_128GF_Weights.IMAGENET1K_SWAG_LINEAR_V1 | 86.068 | 97.844 | 644.8M | 127.52 | link |
| RegNet_Y_16GF_Weights.IMAGENET1K_V1 | 80.424 | 95.24 | 83.6M | 15.91 | link |
| RegNet_Y_16GF_Weights.IMAGENET1K_V2 | 82.886 | 96.328 | 83.6M | 15.91 | link |
| RegNet_Y_16GF_Weights.IMAGENET1K_SWAG_E2E_V1 | 86.012 | 98.054 | 83.6M | 46.73 | link |
| RegNet_Y_16GF_Weights.IMAGENET1K_SWAG_LINEAR_V1 | 83.976 | 97.244 | 83.6M | 15.91 | link |
| RegNet_Y_1_6GF_Weights.IMAGENET1K_V1 | 77.95 | 93.966 | 11.2M | 1.61 | link |
| RegNet_Y_1_6GF_Weights.IMAGENET1K_V2 | 80.876 | 95.444 | 11.2M | 1.61 | link |
| RegNet_Y_32GF_Weights.IMAGENET1K_V1 | 80.878 | 95.34 | 145.0M | 32.28 | link |
| RegNet_Y_32GF_Weights.IMAGENET1K_V2 | 83.368 | 96.498 | 145.0M | 32.28 | link |
| RegNet_Y_32GF_Weights.IMAGENET1K_SWAG_E2E_V1 | 86.838 | 98.362 | 145.0M | 94.83 | link |
| RegNet_Y_32GF_Weights.IMAGENET1K_SWAG_LINEAR_V1 | 84.622 | 97.48 | 145.0M | 32.28 | link |
| RegNet_Y_3_2GF_Weights.IMAGENET1K_V1 | 78.948 | 94.576 | 19.4M | 3.18 | link |
| RegNet_Y_3_2GF_Weights.IMAGENET1K_V2 | 81.982 | 95.972 | 19.4M | 3.18 | link |
| RegNet_Y_400MF_Weights.IMAGENET1K_V1 | 74.046 | 91.716 | 4.3M | 0.4 | link |
| RegNet_Y_400MF_Weights.IMAGENET1K_V2 | 75.804 | 92.742 | 4.3M | 0.4 | link |
| RegNet_Y_800MF_Weights.IMAGENET1K_V1 | 76.42 | 93.136 | 6.4M | 0.83 | link |
| RegNet_Y_800MF_Weights.IMAGENET1K_V2 | 78.828 | 94.502 | 6.4M | 0.83 | link |
| RegNet_Y_8GF_Weights.IMAGENET1K_V1 | 80.032 | 95.048 | 39.4M | 8.47 | link |
| RegNet_Y_8GF_Weights.IMAGENET1K_V2 | 82.828 | 96.33 | 39.4M | 8.47 | link |
| ResNeXt101_32X8D_Weights.IMAGENET1K_V1 | 79.312 | 94.526 | 88.8M | 16.41 | link |
| ResNeXt101_32X8D_Weights.IMAGENET1K_V2 | 82.834 | 96.228 | 88.8M | 16.41 | link |
| ResNeXt101_64X4D_Weights.IMAGENET1K_V1 | 83.246 | 96.454 | 83.5M | 15.46 | link |
| ResNeXt50_32X4D_Weights.IMAGENET1K_V1 | 77.618 | 93.698 | 25.0M | 4.23 | link |
| ResNeXt50_32X4D_Weights.IMAGENET1K_V2 | 81.198 | 95.34 | 25.0M | 4.23 | link |
| ResNet101_Weights.IMAGENET1K_V1 | 77.374 | 93.546 | 44.5M | 7.8 | link |
| ResNet101_Weights.IMAGENET1K_V2 | 81.886 | 95.78 | 44.5M | 7.8 | link |
| ResNet152_Weights.IMAGENET1K_V1 | 78.312 | 94.046 | 60.2M | 11.51 | link |
| ResNet152_Weights.IMAGENET1K_V2 | 82.284 | 96.002 | 60.2M | 11.51 | link |
| ResNet18_Weights.IMAGENET1K_V1 | 69.758 | 89.078 | 11.7M | 1.81 | link |
| ResNet34_Weights.IMAGENET1K_V1 | 73.314 | 91.42 | 21.8M | 3.66 | link |
| ResNet50_Weights.IMAGENET1K_V1 | 76.13 | 92.862 | 25.6M | 4.09 | link |
| ResNet50_Weights.IMAGENET1K_V2 | 80.858 | 95.434 | 25.6M | 4.09 | link |
| ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1 | 60.552 | 81.746 | 1.4M | 0.04 | link |
| ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1 | 69.362 | 88.316 | 2.3M | 0.14 | link |
| ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1 | 72.996 | 91.086 | 3.5M | 0.3 | link |
| ShuffleNet_V2_X2_0_Weights.IMAGENET1K_V1 | 76.23 | 93.006 | 7.4M | 0.58 | link |
| SqueezeNet1_0_Weights.IMAGENET1K_V1 | 58.092 | 80.42 | 1.2M | 0.82 | link |
| SqueezeNet1_1_Weights.IMAGENET1K_V1 | 58.178 | 80.624 | 1.2M | 0.35 | link |
| Swin_B_Weights.IMAGENET1K_V1 | 83.582 | 96.64 | 87.8M | 15.43 | link |
| Swin_S_Weights.IMAGENET1K_V1 | 83.196 | 96.36 | 49.6M | 8.74 | link |
| Swin_T_Weights.IMAGENET1K_V1 | 81.474 | 95.776 | 28.3M | 4.49 | link |
| Swin_V2_B_Weights.IMAGENET1K_V1 | 84.112 | 96.864 | 87.9M | 20.32 | link |
| Swin_V2_S_Weights.IMAGENET1K_V1 | 83.712 | 96.816 | 49.7M | 11.55 | link |
| Swin_V2_T_Weights.IMAGENET1K_V1 | 82.072 | 96.132 | 28.4M | 5.94 | link |
| VGG11_BN_Weights.IMAGENET1K_V1 | 70.37 | 89.81 | 132.9M | 7.61 | link |
| VGG11_Weights.IMAGENET1K_V1 | 69.02 | 88.628 | 132.9M | 7.61 | link |
| VGG13_BN_Weights.IMAGENET1K_V1 | 71.586 | 90.374 | 133.1M | 11.31 | link |
| VGG13_Weights.IMAGENET1K_V1 | 69.928 | 89.246 | 133.0M | 11.31 | link |
| VGG16_BN_Weights.IMAGENET1K_V1 | 73.36 | 91.516 | 138.4M | 15.47 | link |
| VGG16_Weights.IMAGENET1K_V1 | 71.592 | 90.382 | 138.4M | 15.47 | link |
| VGG16_Weights.IMAGENET1K_FEATURES | nan | nan | 138.4M | 15.47 | link |
| VGG19_BN_Weights.IMAGENET1K_V1 | 74.218 | 91.842 | 143.7M | 19.63 | link |
| VGG19_Weights.IMAGENET1K_V1 | 72.376 | 90.876 | 143.7M | 19.63 | link |
| ViT_B_16_Weights.IMAGENET1K_V1 | 81.072 | 95.318 | 86.6M | 17.56 | link |
| ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1 | 85.304 | 97.65 | 86.9M | 55.48 | link |
| ViT_B_16_Weights.IMAGENET1K_SWAG_LINEAR_V1 | 81.886 | 96.18 | 86.6M | 17.56 | link |
| ViT_B_32_Weights.IMAGENET1K_V1 | 75.912 | 92.466 | 88.2M | 4.41 | link |
| ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1 | 88.552 | 98.694 | 633.5M | 1016.72 | link |
| ViT_H_14_Weights.IMAGENET1K_SWAG_LINEAR_V1 | 85.708 | 97.73 | 632.0M | 167.29 | link |
| ViT_L_16_Weights.IMAGENET1K_V1 | 79.662 | 94.638 | 304.3M | 61.55 | link |
| ViT_L_16_Weights.IMAGENET1K_SWAG_E2E_V1 | 88.064 | 98.512 | 305.2M | 361.99 | link |
| ViT_L_16_Weights.IMAGENET1K_SWAG_LINEAR_V1 | 85.146 | 97.422 | 304.3M | 61.55 | link |
| ViT_L_32_Weights.IMAGENET1K_V1 | 76.972 | 93.07 | 306.5M | 15.38 | link |
| Wide_ResNet101_2_Weights.IMAGENET1K_V1 | 78.848 | 94.284 | 126.9M | 22.75 | link |
| Wide_ResNet101_2_Weights.IMAGENET1K_V2 | 82.51 | 96.02 | 126.9M | 22.75 | link |
| Wide_ResNet50_2_Weights.IMAGENET1K_V1 | 78.468 | 94.086 | 68.9M | 11.4 | link |
| Wide_ResNet50_2_Weights.IMAGENET1K_V2 | 81.602 | 95.758 | 68.9M | 11.4 | link |

Quantized models

The following architectures provide support for INT8 quantized models, with or without pre-trained weights:


Here is an example of how to use a pre-trained quantized image classification model:

from torchvision.io import decode_image
from torchvision.models.quantization import resnet50, ResNet50_QuantizedWeights

img = decode_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

# Step 1: Initialize model with the best available weights
weights = ResNet50_QuantizedWeights.DEFAULT
model = resnet50(weights=weights, quantize=True)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)

# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score}%")

The classes of the pre-trained model outputs can be found at weights.meta["categories"].

Table of all available quantized classification weights

Accuracies are reported on ImageNet-1K using single crops:

| Weight | Acc@1 | Acc@5 | Params | GIPS | Recipe |
|---|---|---|---|---|---|
| GoogLeNet_QuantizedWeights.IMAGENET1K_FBGEMM_V1 | 69.826 | 89.404 | 6.6M | 1.5 | link |
| Inception_V3_QuantizedWeights.IMAGENET1K_FBGEMM_V1 | 77.176 | 93.354 | 27.2M | 5.71 | link |
| MobileNet_V2_QuantizedWeights.IMAGENET1K_QNNPACK_V1 | 71.658 | 90.15 | 3.5M | 0.3 | link |
| MobileNet_V3_Large_QuantizedWeights.IMAGENET1K_QNNPACK_V1 | 73.004 | 90.858 | 5.5M | 0.22 | link |
| ResNeXt101_32X8D_QuantizedWeights.IMAGENET1K_FBGEMM_V1 | 78.986 | 94.48 | 88.8M | 16.41 | link |
| ResNeXt101_32X8D_QuantizedWeights.IMAGENET1K_FBGEMM_V2 | 82.574 | 96.132 | 88.8M | 16.41 | link |
| ResNeXt101_64X4D_QuantizedWeights.IMAGENET1K_FBGEMM_V1 | 82.898 | 96.326 | 83.5M | 15.46 | link |
| ResNet18_QuantizedWeights.IMAGENET1K_FBGEMM_V1 | 69.494 | 88.882 | 11.7M | 1.81 | link |
| ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1 | 75.92 | 92.814 | 25.6M | 4.09 | link |
| ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V2 | 80.282 | 94.976 | 25.6M | 4.09 | link |
| ShuffleNet_V2_X0_5_QuantizedWeights.IMAGENET1K_FBGEMM_V1 | 57.972 | 79.78 | 1.4M | 0.04 | link |
| ShuffleNet_V2_X1_0_QuantizedWeights.IMAGENET1K_FBGEMM_V1 | 68.36 | 87.582 | 2.3M | 0.14 | link |
| ShuffleNet_V2_X1_5_QuantizedWeights.IMAGENET1K_FBGEMM_V1 | 72.052 | 90.7 | 3.5M | 0.3 | link |
| ShuffleNet_V2_X2_0_QuantizedWeights.IMAGENET1K_FBGEMM_V1 | 75.354 | 92.488 | 7.4M | 0.58 | link |

Semantic Segmentation

Warning

The segmentation module is in Beta stage, and backward compatibility is not guaranteed.

The following semantic segmentation models are available, with or without pre-trained weights:


Here is an example of how to use a pre-trained semantic segmentation model:

from torchvision.io.image import decode_image
from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights
from torchvision.transforms.functional import to_pil_image

img = decode_image("gallery/assets/dog1.jpg")

# Step 1: Initialize model with the best available weights
weights = FCN_ResNet50_Weights.DEFAULT
model = fcn_resnet50(weights=weights)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)

# Step 4: Use the model and visualize the prediction
prediction = model(batch)["out"]
normalized_masks = prediction.softmax(dim=1)
class_to_idx = {cls: idx for (idx, cls) in enumerate(weights.meta["categories"])}
mask = normalized_masks[0, class_to_idx["dog"]]
to_pil_image(mask).show()

The classes of the pre-trained model outputs can be found at weights.meta["categories"]. The output format of the models is illustrated in Semantic segmentation models.
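
Continuing the example above, a small sketch of collapsing the per-class scores into a single segmentation map:

# Each pixel gets the index of its highest-scoring class
# (index 0 is the "__background__" category).
class_map = normalized_masks[0].argmax(dim=0)  # shape (H, W)
print(class_map.unique())  # class indices present in the image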

Table of all available semantic segmentation weights

All models are evaluated on a subset of COCO val2017, on the 20 categories that are present in the Pascal VOC dataset:

| Weight | Mean IoU | Pixelwise Acc | Params | GFLOPS | Recipe |
|---|---|---|---|---|---|
| DeepLabV3_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1 | 60.3 | 91.2 | 11.0M | 10.45 | link |
| DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1 | 67.4 | 92.4 | 61.0M | 258.74 | link |
| DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1 | 66.4 | 92.4 | 42.0M | 178.72 | link |
| FCN_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1 | 63.7 | 91.9 | 54.3M | 232.74 | link |
| FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1 | 60.5 | 91.4 | 35.3M | 152.72 | link |
| LRASPP_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1 | 57.9 | 91.2 | 3.2M | 2.09 | link |

Object Detection, Instance Segmentation and Person Keypoint Detection

The pre-trained models for detection, instance segmentation and keypoint detection are initialized with the classification models in torchvision. The models expect a list of Tensor[C, H, W], as sketched below. Check the constructor of the models for more information.
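
A minimal sketch of this input convention (dummy images, random weights):

import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn

model = fasterrcnn_resnet50_fpn(weights=None)  # random initialization
model.eval()

# A list of 3D tensors rather than a stacked 4D batch:
# the images may have different sizes.
images = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
outputs = model(images)  # one dict per image with "boxes", "labels", "scores"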

Warning

The detection module is in Beta stage, and backward compatibility is not guaranteed.

Object Detection

The following object detection models are available, with or without pre-trained weights:


Here is an example of how to use a pre-trained object detection model:

from torchvision.io.image import decode_image
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image

img = decode_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

# Step 1: Initialize model with the best available weights
weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = [preprocess(img)]

# Step 4: Use the model and visualize the prediction
prediction = model(batch)[0]
labels = [weights.meta["categories"][i] for i in prediction["labels"]]
box = draw_bounding_boxes(img, boxes=prediction["boxes"],
                          labels=labels,
                          colors="red",
                          width=4, font_size=30)
im = to_pil_image(box.detach())
im.show()

The classes of the pre-trained model outputs can be found at weights.meta["categories"]. For details on how to plot the bounding boxes of the models, you may refer to Instance segmentation models.

Table of all available Object detection weights

Box MAPs are reported on COCO val2017:

| Weight | Box MAP | Params | GFLOPS | Recipe |
|---|---|---|---|---|
| FCOS_ResNet50_FPN_Weights.COCO_V1 | 39.2 | 32.3M | 128.21 | link |
| FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1 | 22.8 | 19.4M | 0.72 | link |
| FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1 | 32.8 | 19.4M | 4.49 | link |
| FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1 | 46.7 | 43.7M | 280.37 | link |
| FasterRCNN_ResNet50_FPN_Weights.COCO_V1 | 37 | 41.8M | 134.38 | link |
| RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1 | 41.5 | 38.2M | 152.24 | link |
| RetinaNet_ResNet50_FPN_Weights.COCO_V1 | 36.4 | 34.0M | 151.54 | link |
| SSD300_VGG16_Weights.COCO_V1 | 25.1 | 35.6M | 34.86 | link |
| SSDLite320_MobileNet_V3_Large_Weights.COCO_V1 | 21.3 | 3.4M | 0.58 | link |

Instance Segmentation

The following instance segmentation models are available, with or without pre-trained weights:


For details on how to plot the masks of the models, you may refer to Instance segmentation models; a minimal usage sketch follows below.
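
A minimal sketch, mirroring the object detection example above but overlaying thresholded masks with draw_segmentation_masks:

from torchvision.io import decode_image
from torchvision.models.detection import maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights
from torchvision.utils import draw_segmentation_masks
from torchvision.transforms.functional import to_pil_image

img = decode_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

weights = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT
model = maskrcnn_resnet50_fpn_v2(weights=weights)
model.eval()

preprocess = weights.transforms()
prediction = model([preprocess(img)])[0]

# Masks are returned as (N, 1, H, W) float probabilities;
# threshold them into boolean masks before drawing.
masks = prediction["masks"].squeeze(1) > 0.5
overlay = draw_segmentation_masks(img, masks=masks, alpha=0.7)
to_pil_image(overlay).show()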

Table of all available Instance segmentation weights

Box and Mask MAPs are reported on COCO val2017:

| Weight | Box MAP | Mask MAP | Params | GFLOPS | Recipe |
|---|---|---|---|---|---|
| MaskRCNN_ResNet50_FPN_V2_Weights.COCO_V1 | 47.4 | 41.8 | 46.4M | 333.58 | link |
| MaskRCNN_ResNet50_FPN_Weights.COCO_V1 | 37.9 | 34.6 | 44.4M | 134.38 | link |

Keypoint Detection

The following person keypoint detection models are available, with or without pre-trained weights:


The classes of the pre-trained model outputs can be found at weights.meta["keypoint_names"]. For details on how to plot the bounding boxes of the models, you may refer to Visualizing keypoints.
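
For instance, a small sketch of inspecting the keypoint metadata:

from torchvision.models.detection import KeypointRCNN_ResNet50_FPN_Weights

weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
# The 17 COCO keypoint names, starting with "nose", "left_eye", ...
print(weights.meta["keypoint_names"])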

Table of all available Keypoint detection weights

Box and Keypoint MAPs are reported on COCO val2017:

| Weight | Box MAP | Keypoint MAP | Params | GFLOPS | Recipe |
|---|---|---|---|---|---|
| KeypointRCNN_ResNet50_FPN_Weights.COCO_LEGACY | 50.6 | 61.1 | 59.1M | 133.92 | link |
| KeypointRCNN_ResNet50_FPN_Weights.COCO_V1 | 54.6 | 65 | 59.1M | 137.42 | link |

Video Classification

Warning

The video module is in Beta stage, and backward compatibility is not guaranteed.

The following video classification models are available, with or without pre-trained weights:


Here is an example of how to use a pre-trained video classification model:

from torchvision.io.video import read_video
from torchvision.models.video import r3d_18, R3D_18_Weights

vid, _, _ = read_video("test/assets/videos/v_SoccerJuggling_g23_c01.avi", output_format="TCHW")
vid = vid[:32]  # optionally shorten duration

# Step 1: Initialize model with the best available weights
weights = R3D_18_Weights.DEFAULT
model = r3d_18(weights=weights)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = preprocess(vid).unsqueeze(0)

# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
label = prediction.argmax().item()
score = prediction[label].item()
category_name = weights.meta["categories"][label]
print(f"{category_name}: {100 * score}%")

The classes of the pre-trained model outputs can be found at weights.meta["categories"].

Table of all available video classification weights

Accuracies are reported on Kinetics-400 using single crops for clip length 16:

| Weight | Acc@1 | Acc@5 | Params | GFLOPS | Recipe |
|---|---|---|---|---|---|
| MC3_18_Weights.KINETICS400_V1 | 63.96 | 84.13 | 11.7M | 43.34 | link |
| MViT_V1_B_Weights.KINETICS400_V1 | 78.477 | 93.582 | 36.6M | 70.6 | link |
| MViT_V2_S_Weights.KINETICS400_V1 | 80.757 | 94.665 | 34.5M | 64.22 | link |
| R2Plus1D_18_Weights.KINETICS400_V1 | 67.463 | 86.175 | 31.5M | 40.52 | link |
| R3D_18_Weights.KINETICS400_V1 | 63.2 | 83.479 | 33.4M | 40.7 | link |
| S3D_Weights.KINETICS400_V1 | 68.368 | 88.05 | 8.3M | 17.98 | link |
| Swin3D_B_Weights.KINETICS400_V1 | 79.427 | 94.386 | 88.0M | 140.67 | link |
| Swin3D_B_Weights.KINETICS400_IMAGENET22K_V1 | 81.643 | 95.574 | 88.0M | 140.67 | link |
| Swin3D_S_Weights.KINETICS400_V1 | 79.521 | 94.158 | 49.8M | 82.84 | link |
| Swin3D_T_Weights.KINETICS400_V1 | 77.715 | 93.519 | 28.2M | 43.88 | link |

Optical Flow

The following optical flow models are available, with or without pre-trained weights; a minimal usage sketch follows below.
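
No usage example is given in this section, so here is a minimal sketch of running RAFT (raft_large) on a pair of dummy frames, following the same weights/transforms pattern as the other tasks:

import torch
from torchvision.models.optical_flow import raft_large, Raft_Large_Weights

weights = Raft_Large_Weights.DEFAULT
model = raft_large(weights=weights)
model.eval()

# The bundled transforms take and return both frame batches together.
preprocess = weights.transforms()
frame1 = torch.rand(1, 3, 520, 960)  # dummy frames; H and W must be divisible by 8
frame2 = torch.rand(1, 3, 520, 960)
frame1, frame2 = preprocess(frame1, frame2)

# RAFT is recurrent: it returns one flow estimate per iteration,
# and the last one is the most accurate.
list_of_flows = model(frame1, frame2)
predicted_flow = list_of_flows[-1]  # shape (1, 2, H, W)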
