torch.hub¶
Pytorch Hub 是一个预训练模型存储库,旨在促进研究的可重复性。
发布模型¶
Pytorch Hub 支持通过添加一个简单的 hubconf.py
文件将预训练模型(模型定义和预训练权重)发布到 GitHub 存储库。
hubconf.py
可以有多个入口点。每个入口点都定义为一个 Python 函数(例如:要发布的预训练模型)。
def entrypoint_name(*args, **kwargs):
# args & kwargs are optional, for models which take positional/keyword arguments.
...
如何实现一个入口点?¶
以下代码片段指定了 resnet18
模型的入口点,如果我们在 pytorch/vision/hubconf.py
中扩展了实现。在大多数情况下,导入 hubconf.py
中的正确函数就足够了。这里我们只是想使用扩展版本作为示例来展示其工作原理。您可以在 pytorch/vision 存储库 中看到完整的脚本
dependencies = ['torch']
from torchvision.models.resnet import resnet18 as _resnet18
# resnet18 is the name of entrypoint
def resnet18(pretrained=False, **kwargs):
""" # This docstring shows up in hub.help()
Resnet18 model
pretrained (bool): kwargs, load pretrained weights into the model
"""
# Call the model, load pretrained weights
model = _resnet18(pretrained=pretrained, **kwargs)
return model
dependencies
变量是加载模型所需的包名称的列表。请注意,这可能与训练模型所需的依赖项略有不同。args
和kwargs
将传递给实际的可调用函数。函数的文档字符串用作帮助消息。它解释了模型的功能以及允许的位置/关键字参数。强烈建议在此添加一些示例。
入口点函数可以返回一个模型(nn.module),或者返回辅助工具以使用户工作流程更顺畅,例如令牌生成器。
以下划线开头的可调用函数被视为辅助函数,不会显示在
torch.hub.list()
中。预训练权重可以存储在 GitHub 存储库中的本地位置,或者可以通过
torch.hub.load_state_dict_from_url()
加载。如果小于 2GB,建议将其附加到 项目版本 并使用版本中的 URL。在上面的示例中,torchvision.models.resnet.resnet18
处理pretrained
,或者您也可以在入口点定义中添加以下逻辑。
if pretrained:
# For checkpoint saved in local GitHub repo, e.g. <RELATIVE_PATH_TO_CHECKPOINT>=weights/save.pth
dirname = os.path.dirname(__file__)
checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)
state_dict = torch.load(checkpoint)
model.load_state_dict(state_dict)
# For checkpoint saved elsewhere
checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))
重要通知¶
发布的模型至少应该在一个分支/标签中。它不能是一个随机提交。
从 Hub 加载模型¶
Pytorch Hub 提供方便的 API,可以通过 torch.hub.list()
探索 Hub 中所有可用的模型,通过 torch.hub.help()
显示文档字符串和示例,并使用 torch.hub.load()
加载预训练模型。
- torch.hub.list(github, force_reload=False, skip_validation=False, trust_repo=None, verbose=True)[source]¶
列出
github
指定的存储库中可用的所有可调用入口点。- 参数
github (str) – 格式为“repo_owner/repo_name[:ref]” 的字符串,包含可选的 ref(标签或分支)。如果未指定
ref
,则默认分支假定为main
(如果存在),否则为master
。示例:“pytorch/vision:0.10”force_reload (bool, 可选) – 是否丢弃现有缓存并强制重新下载。默认值为
False
。skip_validation (布尔值, 可选) – 如果为
False
,torchhub 将检查github
参数指定的分支或提交是否属于仓库所有者。这将向 GitHub API 发出请求;您可以通过设置GITHUB_TOKEN
环境变量来指定非默认 GitHub 令牌。默认值为False
。trust_repo (布尔值, 字符串 或 None) –
"check"
,True
,False
或None
。此参数在 v1.12 中引入,有助于确保用户只运行来自他们信任的仓库的代码。如果为
False
,系统将提示用户是否信任该仓库。如果为
True
,该仓库将被添加到信任列表中,并在加载时无需明确确认。如果为
"check"
,该仓库将与缓存中的信任仓库列表进行核对。如果不在该列表中,行为将回退到trust_repo=False
选项。如果为
None
:这将引发警告,提示用户将trust_repo
设置为False
、True
或"check"
。这仅用于向后兼容,将在 v2.0 中移除。
默认值为
None
,最终将在 v2.0 中更改为"check"
。verbose (布尔值, 可选) – 如果为
False
,将静默有关命中本地缓存的消息。请注意,无法静默有关首次下载的消息。默认值为True
。
- 返回值
可用的可调用入口点
- 返回类型
示例
>>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
- torch.hub.help(github, model, force_reload=False, skip_validation=False, trust_repo=None)[source]¶
显示入口点
model
的文档字符串。- 参数
github (字符串) – 格式为 <repo_owner/repo_name[:ref]> 的字符串,其中包含可选的 ref(标签或分支)。如果未指定
ref
,则假设默认分支为main
(如果存在),否则为master
。示例:“pytorch/vision:0.10”model (字符串) – 在仓库的
hubconf.py
中定义的可调用(入口点)名称的字符串force_reload (bool, 可选) – 是否丢弃现有缓存并强制重新下载。默认值为
False
。skip_validation (布尔值, 可选) – 如果为
False
,torchhub 将检查github
参数指定的 ref 是否属于仓库所有者。这将向 GitHub API 发出请求;您可以通过设置GITHUB_TOKEN
环境变量来指定非默认 GitHub 令牌。默认值为False
。trust_repo (布尔值, 字符串 或 None) –
"check"
,True
,False
或None
。此参数在 v1.12 中引入,有助于确保用户只运行来自他们信任的仓库的代码。如果为
False
,系统将提示用户是否信任该仓库。如果为
True
,该仓库将被添加到信任列表中,并在加载时无需明确确认。如果为
"check"
,该仓库将与缓存中的信任仓库列表进行核对。如果不在该列表中,行为将回退到trust_repo=False
选项。如果为
None
:这将引发警告,提示用户将trust_repo
设置为False
、True
或"check"
。这仅用于向后兼容,将在 v2.0 中移除。
默认值为
None
,最终将在 v2.0 中更改为"check"
。
示例
>>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
- torch.hub.load(repo_or_dir, model, *args, source='github', trust_repo=None, force_reload=False, verbose=True, skip_validation=False, **kwargs)[source]¶
从 GitHub 仓库或本地目录加载模型。
注意:加载模型是典型用例,但这也可以用于加载其他对象,例如令牌化器、损失函数等。
如果
source
为 ‘github’,则repo_or_dir
预期为repo_owner/repo_name[:ref]
格式,其中包含可选的 ref(标签或分支)。如果
source
为 ‘local’,则repo_or_dir
预期为指向本地目录的路径。- 参数
repo_or_dir (字符串) – 如果
source
为 ‘github’,则应对应于格式为repo_owner/repo_name[:ref]
的 GitHub 仓库,其中包含可选的 ref(标签或分支),例如 ‘pytorch/vision:0.10’。如果未指定ref
,则假设默认分支为main
(如果存在),否则为master
。如果source
为 ‘local’,则应指向本地目录的路径。model (字符串) – 在仓库/目录的
hubconf.py
中定义的可调用(入口点)的名称。*args (可选) – 可调用
model
的对应参数。source (字符串, 可选) – ‘github’ 或 ‘local’。指定如何解释
repo_or_dir
。默认值为 ‘github’。trust_repo (布尔值, 字符串 或 None) –
"check"
,True
,False
或None
。此参数在 v1.12 中引入,有助于确保用户只运行来自他们信任的仓库的代码。如果为
False
,系统将提示用户是否信任该仓库。如果为
True
,该仓库将被添加到信任列表中,并在加载时无需明确确认。如果为
"check"
,该仓库将与缓存中的信任仓库列表进行核对。如果不在该列表中,行为将回退到trust_repo=False
选项。如果为
None
:这将引发警告,提示用户将trust_repo
设置为False
、True
或"check"
。这仅用于向后兼容,将在 v2.0 中移除。
默认值为
None
,最终将在 v2.0 中更改为"check"
。force_reload (布尔值, 可选) – 是否强制无条件重新下载 GitHub 仓库。如果
source = 'local'
,则无效。默认值为False
。verbose (布尔值, 可选) – 如果为
False
,将静默有关命中本地缓存的消息。请注意,无法静默有关首次下载的消息。如果source = 'local'
,则无效。默认值为True
。skip_validation (布尔值, 可选) – 如果为
False
,torchhub 将检查github
参数指定的分支或提交是否属于仓库所有者。这将向 GitHub API 发出请求;您可以通过设置GITHUB_TOKEN
环境变量来指定非默认 GitHub 令牌。默认值为False
。**kwargs (可选) – 可调用
model
的对应关键字参数。
- 返回值
可调用
model
在使用给定的*args
和**kwargs
调用时的输出。
示例
>>> # from a github repo >>> repo = 'pytorch/vision' >>> model = torch.hub.load(repo, 'resnet50', weights='ResNet50_Weights.IMAGENET1K_V1') >>> # from a local directory >>> path = '/some/local/path/pytorch/vision' >>> model = torch.hub.load(path, 'resnet50', weights='ResNet50_Weights.DEFAULT')
- torch.hub.download_url_to_file(url, dst, hash_prefix=None, progress=True)[source]¶
将给定 URL 上的对象下载到本地路径。
- 参数
示例
>>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')
- torch.hub.load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None, weights_only=False)[source]¶
从给定 URL 加载 Torch 序列化对象。
如果下载的文件是 zip 文件,它将自动解压缩。
如果对象已存在于 model_dir 中,则将其反序列化并返回。
model_dir
的默认值为<hub_dir>/checkpoints
,其中hub_dir
是由get_dir()
返回的目录。- 参数
url (字符串) – 要下载的对象的 URL
model_dir (str, optional) – 保存对象的目录
map_location (optional) – 指定如何重新映射存储位置的函数或字典(参见 torch.load)
progress (bool, optional) – 是否向 stderr 显示进度条。默认:True
check_hash (bool, optional) – 如果为 True,则 URL 的文件名部分应遵循命名约定
filename-<sha256>.ext
,其中<sha256>
是文件内容的 SHA256 哈希的前八位或更多位数字。哈希用于确保唯一名称并验证文件内容。默认:Falsefile_name (str, optional) – 下载文件的名称。如果未设置,将使用
url
中的文件名。weights_only (bool, optional) – 如果为 True,则仅加载权重,不加载复杂的腌制对象。建议用于不可信来源。有关更多详细信息,请参阅
load()
。
- 返回类型
示例
>>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
运行加载的模型:¶
请注意,*args
和 **kwargs
在 torch.hub.load()
中用于实例化模型。加载模型后,如何找出可以对模型执行的操作?建议的工作流程是
dir(model)
查看模型的所有可用方法。help(model.foo)
检查model.foo
接受哪些参数以运行。
为了帮助用户在不反复参考文档的情况下进行探索,我们强烈建议仓库所有者使函数帮助消息清晰简洁。包含一个最小的工作示例也很有帮助。
我的下载模型保存在哪里?¶
位置按以下顺序使用:
调用
hub.set_dir(<PATH_TO_HUB_DIR>)
$TORCH_HOME/hub
,如果设置了环境变量TORCH_HOME
。$XDG_CACHE_HOME/torch/hub
,如果设置了环境变量XDG_CACHE_HOME
。~/.cache/torch/hub
缓存逻辑¶
默认情况下,我们不会在加载对象后清理文件。Hub 默认情况下使用缓存(如果它已存在于由 get_dir()
返回的目录中)。
用户可以通过调用 hub.load(..., force_reload=True)
强制重新加载。这将删除现有的 GitHub 文件夹和下载的权重,重新初始化一个新的下载。当发布到同一分支的更新时,这很有用,用户可以随时了解最新版本。
已知限制:¶
Torch hub 通过将包导入到好像已安装一样的方式来工作。在 Python 中导入有一些副作用。例如,您可以在 Python 缓存 sys.modules
和 sys.path_importer_cache
中看到新项目,这是 Python 的正常行为。这也意味着当从不同仓库导入不同的模型时,您可能会遇到导入错误,如果仓库具有相同的子包名称(通常是一个 model
子包)。解决这些导入错误的一种方法是从 sys.modules
字典中删除有问题的子包;有关更多详细信息,请参阅 此 GitHub 问题。
这里值得一提的一个已知限制:用户不能在同一个 Python 进程中加载同一个仓库的不同分支。这就像在 Python 中安装两个同名的包一样,这是不好的。如果您尝试这样做,缓存可能会加入并给您带来惊喜。当然,在不同的进程中加载它们是完全可以的。