torch.hub¶
Pytorch Hub 是一个经过预训练的模型存储库,旨在促进研究的可重复性。
发布模型¶
Pytorch Hub 支持通过添加一个简单的 hubconf.py 文件,将经过预训练的模型(模型定义和预训练权重)发布到 GitHub 存储库;
hubconf.py 可以有多个入口点。每个入口点都定义为一个 python 函数(例如:您要发布的经过预训练的模型)。
def entrypoint_name(*args, **kwargs):
    # args & kwargs are optional, for models which take positional/keyword arguments.
    ...
如何实现入口点?¶
如果我们在 pytorch/vision/hubconf.py 中扩展实现,这里有一个代码片段指定了 resnet18 模型的入口点。在大多数情况下,在 hubconf.py 中导入正确的函数就足够了。这里我们只想使用扩展版本作为一个示例来说明它是如何工作的。您可以在 pytorch/vision 存储库 中看到完整脚本
dependencies = ['torch']
from torchvision.models.resnet import resnet18 as _resnet18
# resnet18 is the name of entrypoint
def resnet18(pretrained=False, **kwargs):
    """ # This docstring shows up in hub.help()
    Resnet18 model
    pretrained (bool): kwargs, load pretrained weights into the model
    """
    # Call the model, load pretrained weights
    model = _resnet18(pretrained=pretrained, **kwargs)
    return model
- dependencies变量是加载模型所需的包名称的列表。请注意,这可能与训练模型所需的依赖项略有不同。
- args和- kwargs传递给真正的可调用函数。
- 函数的文档字符串用作帮助消息。它解释了模型的作用以及允许的位置/关键字参数。强烈建议在此处添加一些示例。 
- 入口点函数可以返回一个模型(nn.module),也可以返回辅助工具以使用户工作流更顺畅,例如标记器。 
- 以下划线为前缀的可调用函数被视为辅助函数,不会显示在 - torch.hub.list()中。
- 预训练权重可以存储在 GitHub 仓库中,也可以通过 - torch.hub.load_state_dict_from_url()加载。如果小于 2GB,建议将其附加到 项目版本 并使用版本中的 URL。在上面的示例中- torchvision.models.resnet.resnet18处理- pretrained,或者你可以在入口点定义中放置以下逻辑。
if pretrained:
    # For checkpoint saved in local GitHub repo, e.g. <RELATIVE_PATH_TO_CHECKPOINT>=weights/save.pth
    dirname = os.path.dirname(__file__)
    checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)
    state_dict = torch.load(checkpoint)
    model.load_state_dict(state_dict)
    # For checkpoint saved elsewhere
    checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
    model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))
重要提示¶
- 已发布的模型至少应位于分支/标签中。它不能是随机提交。 
从 Hub 加载模型¶
Pytorch Hub 提供了便捷的 API,可通过 torch.hub.list() 浏览 hub 中所有可用的模型,通过 torch.hub.help() 显示文档字符串和示例,并使用 torch.hub.load() 加载预训练模型。
- torch.hub.list(github, force_reload=False, skip_validation=False, trust_repo=None, verbose=True)[source]¶
- 列出 - github指定的仓库中可用的所有可调用入口点。- 参数
- github (str) – 格式为 “repo_owner/repo_name[:ref]” 的字符串,其中 ref(标签或分支)是可选的。如果未指定 - ref,则假定默认分支为- main(如果存在),否则为- master。示例:‘pytorch/vision:0.10’
- force_reload (bool, 可选) – 是否丢弃现有缓存并强制重新下载。默认为 - False。
- skip_validation (bool, 可选) – 如果 - False,torchhub 将检查- github参数指定的 branch 或 commit 是否正确属于 repo 所有者。这会向 GitHub API 发出请求;你可以通过设置- GITHUB_TOKEN环境变量来指定一个非默认的 GitHub 令牌。默认值为- False。
- trust_repo (bool, str 或 None) – - "check",- True,- False或- None。此参数在 v1.12 中引入,有助于确保用户仅运行来自他们信任的 repo 的代码。- 如果为 - False,系统会提示用户是否信任该 repo。
- 如果为 - True,repo 将被添加到信任列表中,无需明确确认即可加载。
- 如果为 - "check",repo 将与缓存中的信任 repo 列表进行比对。如果不在该列表中,行为将回退到- trust_repo=False选项。
- 如果为 - None:这会引发一个警告,提示用户将- trust_repo设置为- False、- True或- "check"。这仅出于向后兼容性考虑而存在,将在 v2.0 中移除。
 - 默认值为 - None,最终将在 v2.0 中更改为- "check"。
- verbose (bool, 可选) – 如果为 - False,则静音有关命中本地缓存的消息。请注意,无法静音有关首次下载的消息。默认值为- True。
 
- 返回
- 可用的可调用入口点 
- 返回类型
 - 示例 - >>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True) 
- torch.hub.help(github, model, force_reload=False, skip_validation=False, trust_repo=None)[source]¶
- 显示入口点 - model的文档字符串。- 参数
- github (str) – 格式为 <repo_owner/repo_name[:ref]> 的字符串,其中包含可选的 ref(标签或分支)。如果未指定 - ref,则假定默认分支为- main(如果存在),否则为- master。例如:‘pytorch/vision:0.10’
- model (str) – repo 的 - hubconf.py中定义的入口点名称的字符串
- force_reload (bool, 可选) – 是否丢弃现有缓存并强制重新下载。默认为 - False。
- skip_validation (bool, optional) – 如果 - False,torchhub 将检查- github参数指定的 ref 是否正确属于 repo 所有者。这会向 GitHub API 发出请求;你可以通过设置- GITHUB_TOKEN环境变量来指定非默认的 GitHub 令牌。默认值为- False。
- trust_repo (bool, str 或 None) – - "check",- True,- False或- None。此参数在 v1.12 中引入,有助于确保用户仅运行来自他们信任的 repo 的代码。- 如果为 - False,系统会提示用户是否信任该 repo。
- 如果为 - True,repo 将被添加到信任列表中,无需明确确认即可加载。
- 如果为 - "check",repo 将与缓存中的信任 repo 列表进行比对。如果不在该列表中,行为将回退到- trust_repo=False选项。
- 如果为 - None:这会引发一个警告,提示用户将- trust_repo设置为- False、- True或- "check"。这仅出于向后兼容性考虑而存在,将在 v2.0 中移除。
 - 默认值为 - None,最终将在 v2.0 中更改为- "check"。
 
 - 示例 - >>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True)) 
- torch.hub.load(repo_or_dir, model, *args, source='github', trust_repo=None, force_reload=False, verbose=True, skip_validation=False, **kwargs)[source]¶
- 从 github repo 或本地目录加载模型。 - 注意:加载模型是典型用例,但也可以用于加载其他对象,如 tokenizer、损失函数等。 - 如果 - source为 ‘github’,则- repo_or_dir的预期格式为- repo_owner/repo_name[:ref],其中包含可选的 ref(标签或分支)。- 如果 - source为 ‘local’,则- repo_or_dir的预期格式为本地目录的路径。- 参数
- repo_or_dir (str) – 如果 - source为 ‘github’,则应对应于格式为- repo_owner/repo_name[:ref]的 github repo,其中包含可选的 ref(标签或分支),例如 ‘pytorch/vision:0.10’。如果未指定- ref,则假定默认分支为- main(如果存在),否则为- master。如果- source为 ‘local’,则应为本地目录的路径。
- model (str) – repo/dir 中 - hubconf.py中定义的可调用项(入口点)的名称。
- *args (可选) – 可调用项 - model的对应参数。
- source (str, 可选) – ‘github’ 或 ‘local’。指定如何解释 - repo_or_dir。默认值为 ‘github’。
- trust_repo (bool, str 或 None) – - "check",- True,- False或- None。此参数在 v1.12 中引入,有助于确保用户仅运行来自他们信任的 repo 的代码。- 如果为 - False,系统会提示用户是否信任该 repo。
- 如果为 - True,repo 将被添加到信任列表中,无需明确确认即可加载。
- 如果为 - "check",repo 将与缓存中的信任 repo 列表进行比对。如果不在该列表中,行为将回退到- trust_repo=False选项。
- 如果为 - None:这会引发一个警告,提示用户将- trust_repo设置为- False、- True或- "check"。这仅出于向后兼容性考虑而存在,将在 v2.0 中移除。
 - 默认值为 - None,最终将在 v2.0 中更改为- "check"。
- force_reload (bool, 可选) – 是否强制无条件地重新下载 github 仓库。如果 - source = 'local',则不会产生任何影响。默认值为- False。
- verbose (bool, 可选) – 如果 - False,则静音有关命中本地缓存的消息。请注意,无法静音有关首次下载的消息。如果- source = 'local',则不会产生任何影响。默认值为- True。
- skip_validation (bool, 可选) – 如果 - False,torchhub 将检查- github参数指定的 branch 或 commit 是否正确属于 repo 所有者。这会向 GitHub API 发出请求;你可以通过设置- GITHUB_TOKEN环境变量来指定一个非默认的 GitHub 令牌。默认值为- False。
- **kwargs (可选) – 可调用项 - model的对应关键字参数。
 
- 返回
- 使用给定的 - *args和- **kwargs调用- model可调用项时的输出。
 - 示例 - >>> # from a github repo >>> repo = 'pytorch/vision' >>> model = torch.hub.load(repo, 'resnet50', weights='ResNet50_Weights.IMAGENET1K_V1') >>> # from a local directory >>> path = '/some/local/path/pytorch/vision' >>> model = torch.hub.load(path, 'resnet50', weights='ResNet50_Weights.DEFAULT') 
- torch.hub.download_url_to_file(url, dst, hash_prefix=None, progress=True)[source]¶
- 将给定 URL 中的对象下载到本地路径。 - 参数
 - 示例 - >>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file') 
- torch.hub.load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None, weights_only=False)[source]¶
- 加载给定 URL 中的 Torch 序列化对象。 - 如果下载的文件是 zip 文件,它将自动解压。 - 如果对象已存在于 model_dir 中,则对其进行反序列化并返回。 - model_dir的默认值为- <hub_dir>/checkpoints,其中- hub_dir是- get_dir()返回的目录。- 参数
- url (str) – 要下载的对象的 URL 
- model_dir (str, 可选) – 保存对象的目录 
- map_location (可选) – 指定如何重新映射存储位置的函数或字典(请参阅 torch.load) 
- progress (bool, 可选) – 是否在 stderr 中显示进度条。默认值:True 
- check_hash (bool, 可选) – 如果为 True,则 URL 的文件名部分应遵循命名约定 - filename-<sha256>.ext,其中- <sha256>是文件内容的 SHA256 哈希的前八位或更多位。哈希用于确保名称唯一并验证文件内容。默认值:False
- file_name (str, 可选) – 下载的文件的名称。如果未设置,将使用 - url中的文件名。
- weights_only (bool, 可选) – 如果为 True,则只加载权重,不加载复杂的腌制对象。建议用于不受信任的来源。有关更多详细信息,请参阅 - load()。
 
- 返回类型
 - 示例 - >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth') 
运行加载的模型:¶
请注意,torch.hub.load() 中的 *args 和 **kwargs 用于实例化模型。加载模型后,如何才能了解如何使用该模型?建议的工作流程如下:
- dir(model)查看模型的所有可用方法。
- help(model.foo)查看- model.foo运行所需的参数
为了帮助用户探索而不必反复参考文档,我们强烈建议仓库所有者使函数帮助信息清晰简洁。包括一个最小的工作示例也有帮助。
下载的模型保存在哪里?¶
按以下顺序使用位置
- 调用 - hub.set_dir(<PATH_TO_HUB_DIR>)
- $TORCH_HOME/hub,如果设置了环境变量- TORCH_HOME。
- $XDG_CACHE_HOME/torch/hub,如果设置了环境变量- XDG_CACHE_HOME。
- ~/.cache/torch/hub
缓存逻辑¶
默认情况下,我们不会在加载文件后对其进行清理。如果 Hub 已存在于 get_dir() 返回的目录中,则 Hub 默认使用缓存。
用户可以通过调用 hub.load(..., force_reload=True) 来强制重新加载。这将删除现有的 GitHub 文件夹和下载的权重,重新初始化新的下载。当更新发布到同一分支时,这很有用,用户可以随时了解最新版本。
已知限制:¶
Torch hub 通过导入包(就像已安装一样)来工作。在 Python 中导入会产生一些副作用。例如,您可以在 Python 缓存 sys.modules 和 sys.path_importer_cache 中看到新项目,这是正常的 Python 行为。这也意味着,如果您从不同的存储库导入不同的模型,而这些存储库具有相同的子包名称(通常是 model 子包),则在导入时可能会出现导入错误。解决此类导入错误的一种方法是从 sys.modules 字典中删除有问题的子包;更多详细信息,请参阅 此 GitHub 问题。
值得在此处提到的一个已知限制:用户不能在同一 Python 进程中加载同一存储库的两个不同分支。这就像在 Python 中安装两个名称相同的包,这是不好的。如果您真的尝试这样做,缓存可能会加入其中并给您带来惊喜。当然,在单独的进程中加载它们完全没问题。