• 文档 >
  • torch.utils.model_zoo
快捷方式

torch.utils.model_zoo

已移动到 torch.hub

torch.utils.model_zoo.load_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None, weights_only=False)[source]

加载给定 URL 处的 Torch 序列化对象。

如果下载的文件是 zip 文件,它将自动解压缩。

如果对象已存在于 model_dir 中,则会反序列化并返回。 model_dir 的默认值为 <hub_dir>/checkpoints,其中 hub_dir 是由 get_dir() 返回的目录。

参数
  • url (str) – 要下载的对象的 URL

  • model_dir (str, 可选) – 用于保存对象的目录

  • map_location (可选) – 一个函数或字典,用于指定如何重新映射存储位置(请参阅 torch.load)

  • progress (bool, 可选) – 是否向 stderr 显示进度条。默认值:True

  • check_hash (bool, 可选) – 如果为 True,则 URL 的文件名部分应遵循命名约定 filename-<sha256>.ext,其中 <sha256> 是文件内容 SHA256 哈希值的前八位或更多位数字。哈希值用于确保名称的唯一性并验证文件内容。默认值:False

  • file_name (str, 可选) – 下载文件的名称。如果未设置,将使用来自 url 的文件名。

  • weights_only (bool, 可选) – 如果为 True,则仅加载权重,而不加载复杂的 pickle 对象。建议用于不受信任的来源。有关更多详细信息,请参阅 load()

返回类型

Dict[str, Any]

示例

>>> state_dict = torch.hub.load_state_dict_from_url(
...     "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth"
... )

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得解答

查看资源