跳转到主要内容
博客

使用 TorchVision 中的新 API 轻松列出和初始化模型

作者: 2022 年 8 月 18 日2024 年 11 月 15 日暂无评论

TorchVision 现在支持按名称列出和初始化所有可用的内置模型和权重。这个新的 API 基于最近推出的多权重支持 API,目前处于 Beta 阶段,它解决了社区长期以来的请求

您可以在 TorchVision 的最新每夜版中试用新的 API。我们希望在 TorchVision v0.14 中最终确定该功能之前收集反馈。我们创建了一个专门的Github Issue,您可以在其中发布您的评论、问题和建议!

查询和初始化可用模型

在新的模型注册 API 之前,开发人员必须查询模块的__dict__属性才能列出所有可用模型或按名称获取特定的模型构建器方法

# Initialize a model by its name:
model = torchvision.models.__dict__[model_name]()

# List available models:
available_models = [
    k for k, v in torchvision.models.__dict__.items()
    if callable(v) and k[0].islower() and k[0] != "_"
]

上述方法不总是能产生预期的结果,并且难以发现。例如,由于get_weight()方法在同一模块下公开,它将被包含在列表中,尽管它不是模型。总的来说,减少冗余(更少的导入、更短的名称等)并能够直接从名称初始化模型和权重(更好地支持配置、TorchHub 等)是社区之前提供的反馈。为了解决这个问题,我们开发了一个模型注册 API。

一种新方法

我们在 torchvision.models 模块下添加了 4 个新方法

from torchvision.models import get_model, get_model_weights, get_weight, list_models

样式和命名约定与 Philip Meier 为Datasets V2 API 提出的原型机制紧密对齐,旨在提供相似的用户体验。模型注册方法有意保持私有,因为我们目前只专注于支持 TorchVision 的内置模型。

列出模型

在 TorchVision 中列出所有可用模型可以通过单个函数调用完成

>>> list_models()
['alexnet', 'mobilenet_v3_large', 'mobilenet_v3_small', 'quantized_mobilenet_v3_large', ...]

要列出特定子模块的可用模型

>>> list_models(module=torchvision.models)
['alexnet', 'mobilenet_v3_large', 'mobilenet_v3_small', ...]
>>> list_models(module=torchvision.models.quantization)
['quantized_mobilenet_v3_large', ...]

初始化模型

现在您知道哪些模型可用,您可以轻松地使用预训练权重初始化模型

>>> get_model("quantized_mobilenet_v3_large", weights="DEFAULT")
QuantizableMobileNetV3(
  (features): Sequential(
   ....
   )
)

获取权重

有时,在使用配置文件或 TorchHub 时,您可能有一个特定权重条目的名称,并希望获取其实例。这可以通过以下方法轻松完成

>>> get_weight("ResNet50_Weights.IMAGENET1K_V2")
ResNet50_Weights.IMAGENET1K_V2

要获取包含特定模型所有可用权重的枚举类,您可以使用其名称

>>> get_model_weights("quantized_mobilenet_v3_large")
<enum 'MobileNet_V3_Large_QuantizedWeights'>

或者其模型构建器方法

>>> get_model_weights(torchvision.models.quantization.mobilenet_v3_large)
<enum 'MobileNet_V3_Large_QuantizedWeights'>

TorchHub 支持

新方法也可以通过 TorchHub 使用

import torch

# Fetching a specific weight entry by its name:
weights = torch.hub.load("pytorch/vision", "get_weight", weights="ResNet50_Weights.IMAGENET1K_V2")

# Fetching the weights enum class to list all available entries:
weight_enum = torch.hub.load("pytorch/vision", "get_model_weights", name="resnet50")
print([weight for weight in weight_enum])

整合起来

例如,如果您想检索所有带有预训练权重的中小型模型并初始化其中一个,只需使用上述 API 即可

import torchvision
from torchvision.models import get_model, get_model_weights, list_models


max_params = 5000000

tiny_models = []
for model_name in list_models(module=torchvision.models):
    weights_enum = get_model_weights(model_name)
    if len([w for w in weights_enum if w.meta["num_params"] <= max_params]) > 0:
        tiny_models.append(model_name)

print(tiny_models)
# ['mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mobilenet_v2', ...]

model = get_model(tiny_models[0], weights="DEFAULT")
print(sum(x.numel() for x in model.state_dict().values()))
# 2239188

有关更多技术细节,请参阅原始RFC。请花几分钟时间就新 API 提供您的反馈,因为这对于将其从测试版升级并包含在下一次发布中至关重要。您可以在专门的Github Issue上进行此操作。我们期待阅读您的评论!