可复现性是许多研究领域(包括基于机器学习技术的领域)的基本要求。然而,许多机器学习出版物要么不可复现,要么难以复现。随着研究出版物数量的持续增长,包括 arXiv 上托管的数万篇论文以及达到历史新高的会议投稿量,研究可复现性比以往任何时候都更加重要。虽然这些出版物中有许多都附带了代码和训练好的模型,这很有帮助,但仍然需要用户自行解决许多步骤。
我们很高兴宣布推出 PyTorch Hub,这是一个简单的 API 和工作流程,提供了提高机器学习研究可复现性的基本构建模块。PyTorch Hub 包含一个专门为促进研究可复现性和支持新研究而设计的预训练模型库。它还内置了对 Colab 的支持,与 Papers With Code 集成,目前包含广泛的模型集,包括分类和分割、生成模型、Transformers 等。

[所有者] 发布模型
PyTorch Hub 支持将预训练模型(模型定义和预训练权重)发布到 GitHub 仓库,只需添加一个简单的 hubconf.py
文件。该文件列出了支持的模型以及运行这些模型所需的依赖项。示例可在 torchvision、huggingface-bert 和 gan-model-zoo 仓库中找到。
让我们来看看最简单的情况:torchvision
的 hubconf.py
# Optional list of dependencies required by the package
dependencies = ['torch']
from torchvision.models.alexnet import alexnet
from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161
from torchvision.models.inception import inception_v3
from torchvision.models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152,\
resnext50_32x4d, resnext101_32x8d
from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1
from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
from torchvision.models.segmentation import fcn_resnet101, deeplabv3_resnet101
from torchvision.models.googlenet import googlenet
from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0
from torchvision.models.mobilenet import mobilenet_v2
在 torchvision
中,模型具有以下属性
- 每个模型文件都可以独立运行和执行
- 它们不需要 PyTorch 之外的任何包(在
hubconf.py
中编码为dependencies['torch']
) - 它们不需要单独的入口点,因为创建的模型开箱即用
最小化包依赖可以减少用户加载您的模型进行即时实验的阻力。
一个更复杂的示例是 HuggingFace 的 BERT 模型。以下是它们的 hubconf.py
dependencies = ['torch', 'tqdm', 'boto3', 'requests', 'regex']
from hubconfs.bert_hubconf import (
bertTokenizer,
bertModel,
bertForNextSentencePrediction,
bertForPreTraining,
bertForMaskedLM,
bertForSequenceClassification,
bertForMultipleChoice,
bertForQuestionAnswering,
bertForTokenClassification
)
然后每个模型都需要创建一个入口点。这是一个代码片段,用于指定 bertForMaskedLM
模型的入口点,该入口点返回预训练模型权重。
def bertForMaskedLM(*args, **kwargs):
"""
BertForMaskedLM includes the BertModel Transformer followed by the
pre-trained masked language modeling head.
Example:
...
"""
model = BertForMaskedLM.from_pretrained(*args, **kwargs)
return model
这些入口点可以作为复杂模型工厂的包装器。它们可以提供清晰一致的帮助文档字符串,包含支持下载预训练权重(例如通过 pretrained=True
)的逻辑,或者具有额外的 Hub 特定功能,例如可视化。
准备好 hubconf.py
文件后,您可以根据此处提供的模板发送拉取请求。我们的目标是策划高质量、易于复现、对研究复现性最有益的模型。因此,我们可能会与您一起完善您的拉取请求,在某些情况下可能会拒绝一些低质量的模型发布。一旦我们接受您的拉取请求,您的模型将很快出现在 PyTorch Hub 网页上,供所有用户探索。
[用户] 工作流程
作为用户,PyTorch Hub 允许您遵循几个简单的步骤来完成以下操作:1)探索可用的模型;2)加载模型;以及 3)了解任何给定模型有哪些可用方法。让我们逐一 살펴一些示例。
探索可用的入口点。
用户可以使用 torch.hub.list()
API 列出仓库中所有可用的入口点。
>>> torch.hub.list('pytorch/vision')
>>>
['alexnet',
'deeplabv3_resnet101',
'densenet121',
...
'vgg16',
'vgg16_bn',
'vgg19',
'vgg19_bn']
请注意,PyTorch Hub 还允许辅助入口点(除了预训练模型),例如 BERT 模型中用于预处理的 bertTokenizer
,以使用户工作流程更流畅。
加载模型
既然我们知道 Hub 中有哪些模型可用,用户就可以使用 torch.hub.load()
API 加载模型入口点。这只需要一个简单的命令,无需安装 wheel 包。此外,torch.hub.help()
API 可以提供关于如何实例化模型的有用信息。
print(torch.hub.help('pytorch/vision', 'deeplabv3_resnet101'))
model = torch.hub.load('pytorch/vision', 'deeplabv3_resnet101', pretrained=True)
仓库所有者经常会希望持续添加错误修复或性能改进。PyTorch Hub 使通过调用以下命令获取最新更新变得非常简单:
model = torch.hub.load(..., force_reload=True)
我们相信这将有助于减轻仓库所有者重复发布包的负担,并让他们更专注于研究。它还确保作为用户,您正在获得最新可用的模型。
反之,稳定性对用户很重要。因此,一些模型所有者从特定的分支或标签(而不是 master
分支)提供模型,以确保代码的稳定性。例如,pytorch_GAN_zoo
从 hub
分支提供模型
model = torch.hub.load('facebookresearch/pytorch_GAN_zoo:hub', 'DCGAN', pretrained=True, useGPU=False)
请注意,传递给 hub.load()
的 *args
、**kwargs
用于实例化模型。在上面的示例中,pretrained=True
和 useGPU=False
被传递给模型的入口点。
探索已加载的模型
从 PyTorch Hub 加载模型后,您可以使用以下工作流程来查找受支持的可用方法,并更好地了解运行它需要哪些参数。
使用 dir(model)
查看模型所有可用方法。让我们来看看 bertForMaskedLM
的可用方法。
>>> dir(model)
>>>
['forward'
...
'to'
'state_dict',
]
help(model.forward)
提供了一个视图,说明运行已加载的模型需要哪些参数。
>>> help(model.forward)
>>>
Help on method forward in module pytorch_pretrained_bert.modeling:
forward(input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None)
...
仔细查看 BERT 和 DeepLabV3 页面,您可以在其中查看加载后如何使用这些模型。
其他探索方式
PyTorch Hub 中可用的模型还支持 Colab,并直接链接到 Papers With Code,您只需单击即可开始使用。此处提供了一个很好的入门示例(如下所示)。

额外资源
- PyTorch Hub API 文档可在此处找到。
- 在此处提交模型,以在 PyTorch Hub 中发布。
- 前往 https://pytorch.ac.cn/hub 了解更多关于可用模型的信息。
- 在 paperswithcode.com 上查找即将发布的更多模型。
非常感谢 HuggingFace、PapersWithCode 团队、fast.ai 和 Nvidia 的同仁们,以及 Morgane Riviere (FAIR Paris) 和许多其他帮助启动这项工作的人们!!
致敬!
PyTorch 团队
常见问题
问:如果我们想贡献一个 Hub 中已有的模型,但我的模型准确率可能更高,我仍然应该贡献吗?
答:是的!!Hub 的下一步是实现一个赞/踩系统来推荐最佳模型。
问:谁负责托管 PyTorch Hub 的模型权重?
答:作为贡献者,您负责托管模型权重。您可以在您喜欢的云存储中托管模型,或者如果容量允许,也可以在 GitHub 上托管。如果无法自行托管权重,请在 hub 仓库中提交 issue 与我们联系。
问:如果我的模型是在私有数据上训练的怎么办?我仍然应该贡献这个模型吗?
答:不!PyTorch Hub 围绕开放研究展开,这包括使用开放数据集来训练这些模型。如果提交了关于专有模型的拉取请求,我们将礼貌地要求您重新提交一个在开放可用数据上训练的模型。
问:我下载的模型保存在哪里?
答:我们遵循 XDG Base Directory Specification 规范,并遵守有关缓存文件和目录的常见标准。
使用位置的优先顺序为:
- 调用
hub.set_dir(<PATH_TO_HUB_DIR>)
$TORCH_HOME/hub
,如果设置了环境变量TORCH_HOME
。$XDG_CACHE_HOME/torch/hub
,如果设置了环境变量XDG_CACHE_HOME
。~/.cache/torch/hub