TorchVision 现在支持按名称列出并初始化所有内置模型和权重。这一新 API 构建于近期引入的 多权重支持 API (Multi-weight support API) 之上,目前处于 Beta 测试阶段,旨在解决社区长期以来的 需求。

您可以在 TorchVision 的 最新 nightly 版本 中试用此新 API。我们希望在 TorchVision v0.14 正式发布前收集反馈。我们已在 GitHub 上创建了专门的 Issue,您可以在其中发表评论、提出问题和建议!
查询和初始化可用模型
在新的模型注册 API 出现之前,开发人员必须查询模块的 __dict__ 属性,才能列出所有可用模型或通过名称获取特定的模型构建方法。
# Initialize a model by its name:
model = torchvision.models.__dict__[model_name]()
# List available models:
available_models = [
k for k, v in torchvision.models.__dict__.items()
if callable(v) and k[0].islower() and k[0] != "_"
]
上述方法并不总是能产生预期的结果,且难以发现。例如,由于 get_weight() 方法是在同一模块下公开暴露的,因此即使它不是一个模型,也会被包含在列表中。总的来说,社区此前曾反馈希望减少冗余(更少的导入、更短的名称等),并能够直接通过名称初始化模型和权重(以更好地支持配置文件、TorchHub 等)。为了解决这个问题,我们开发了一套模型注册 API。
一种新方法
我们在 torchvision.models 模块下增加了 4 个新方法:
from torchvision.models import get_model, get_model_weights, get_weight, list_models
其风格和命名规范与 Philip Meier 为 Datasets V2 API 提出的原型机制非常一致,旨在提供相似的用户体验。我们特意将这些模型注册方法设为私有,因为目前我们仅专注于支持 TorchVision 的内置模型。
列出模型
通过单个函数调用即可列出 TorchVision 中所有可用的模型:
>>> list_models()
['alexnet', 'mobilenet_v3_large', 'mobilenet_v3_small', 'quantized_mobilenet_v3_large', ...]
要列出特定子模块的可用模型:
>>> list_models(module=torchvision.models)
['alexnet', 'mobilenet_v3_large', 'mobilenet_v3_small', ...]
>>> list_models(module=torchvision.models.quantization)
['quantized_mobilenet_v3_large', ...]
初始化模型
既然您已经知道哪些模型可用,就可以轻松地使用预训练权重来初始化模型:
>>> get_model("quantized_mobilenet_v3_large", weights="DEFAULT")
QuantizableMobileNetV3(
(features): Sequential(
....
)
)
获取权重
有时,在使用配置文件或使用 TorchHub 时,您可能已经有了特定权重条目的名称并希望获取其对应的实例。这可以通过以下方法轻松实现:
>>> get_weight("ResNet50_Weights.IMAGENET1K_V2")
ResNet50_Weights.IMAGENET1K_V2
要获取包含特定模型所有可用权重的枚举类,您可以使用其名称:
>>> get_model_weights("quantized_mobilenet_v3_large")
<enum 'MobileNet_V3_Large_QuantizedWeights'>
或者使用其模型构建方法:
>>> get_model_weights(torchvision.models.quantization.mobilenet_v3_large)
<enum 'MobileNet_V3_Large_QuantizedWeights'>
TorchHub 支持
这些新方法也通过 TorchHub 提供:
import torch
# Fetching a specific weight entry by its name:
weights = torch.hub.load("pytorch/vision", "get_weight", weights="ResNet50_Weights.IMAGENET1K_V2")
# Fetching the weights enum class to list all available entries:
weight_enum = torch.hub.load("pytorch/vision", "get_model_weights", name="resnet50")
print([weight for weight in weight_enum])
综合使用
例如,如果您想检索所有具有预训练权重的小型模型并初始化其中一个,只需使用上述 API 即可:
import torchvision
from torchvision.models import get_model, get_model_weights, list_models
max_params = 5000000
tiny_models = []
for model_name in list_models(module=torchvision.models):
weights_enum = get_model_weights(model_name)
if len([w for w in weights_enum if w.meta["num_params"] <= max_params]) > 0:
tiny_models.append(model_name)
print(tiny_models)
# ['mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mobilenet_v2', ...]
model = get_model(tiny_models[0], weights="DEFAULT")
print(sum(x.numel() for x in model.state_dict().values()))
# 2239188
欲了解更多技术细节,请参阅原始 RFC。请抽出几分钟时间对新 API 提供反馈,这对我们将其从 Beta 版转为正式版并包含在下一个版本中至关重要。您可以在专门的 Github Issue 中进行反馈。我们期待收到您的评论!