快捷方式

Torchscript 支持

注意

Colab 上尝试或转到结尾下载完整的示例代码。

此示例说明了 torchvision 变换 在张量图像上的 torchscript 支持。

from pathlib import Path

import matplotlib.pyplot as plt

import torch
import torch.nn as nn

import torchvision.transforms as v1
from torchvision.io import decode_image

plt.rcParams["savefig.bbox"] = 'tight'
torch.manual_seed(1)

# If you're trying to run that on Colab, you can download the assets and the
# helpers from https://github.com/pytorch/vision/tree/main/gallery/
import sys
sys.path += ["../transforms"]
from helpers import plot
ASSETS_PATH = Path('../assets')

大多数变换都支持 torchscript。对于组合变换,我们使用 torch.nn.Sequential 而不是 Compose

dog1 = decode_image(str(ASSETS_PATH / 'dog1.jpg'))
dog2 = decode_image(str(ASSETS_PATH / 'dog2.jpg'))

transforms = torch.nn.Sequential(
    v1.RandomCrop(224),
    v1.RandomHorizontalFlip(p=0.3),
)

scripted_transforms = torch.jit.script(transforms)

plot([dog1, scripted_transforms(dog1), dog2, scripted_transforms(dog2)])
plot scripted tensor transforms

警告

上面我们使用了来自 torchvision.transforms 命名空间的变换,即“v1”变换。来自 torchvision.transforms.v2 命名空间的 v2 变换是在代码中使用变换的 推荐 方法。

v2 变换也支持 torchscript,但是如果你在 v2 **类**变换上调用 torch.jit.script(),你实际上会得到其(脚本化的)v1 等价物。由于 v1 和 v2 之间的实现差异,这可能导致脚本化执行和急切执行之间的结果略有不同。

如果你确实需要 v2 变换的 torchscript 支持,**我们建议对** torchvision.transforms.v2.functional 命名空间中的**函数进行脚本化**,以避免意外情况。

下面我们现在展示如何结合图像变换和模型前向传递,同时使用 torch.jit.script 获得单个脚本化模块。

让我们定义一个 Predictor 模块,它转换输入张量,然后在其上应用 ImageNet 模型。

from torchvision.models import resnet18, ResNet18_Weights


class Predictor(nn.Module):

    def __init__(self):
        super().__init__()
        weights = ResNet18_Weights.DEFAULT
        self.resnet18 = resnet18(weights=weights, progress=False).eval()
        self.transforms = weights.transforms(antialias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            x = self.transforms(x)
            y_pred = self.resnet18(x)
            return y_pred.argmax(dim=1)

现在,让我们定义 Predictor 的脚本化和非脚本化实例,并将其应用于多个相同大小的张量图像

device = "cuda" if torch.cuda.is_available() else "cpu"

predictor = Predictor().to(device)
scripted_predictor = torch.jit.script(predictor).to(device)

batch = torch.stack([dog1, dog2]).to(device)

res = predictor(batch)
res_scripted = scripted_predictor(batch)
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth

我们可以验证脚本化模型和非脚本化模型的预测是否相同

import json

with open(Path('../assets') / 'imagenet_class_index.json') as labels_file:
    labels = json.load(labels_file)

for i, (pred, pred_scripted) in enumerate(zip(res, res_scripted)):
    assert pred == pred_scripted
    print(f"Prediction for Dog {i + 1}: {labels[str(pred.item())]}")
Prediction for Dog 1: ['n02113023', 'Pembroke']
Prediction for Dog 2: ['n02106662', 'German_shepherd']

由于模型是脚本化的,因此可以轻松地将其转储到磁盘上并重新使用

import tempfile

with tempfile.NamedTemporaryFile() as f:
    scripted_predictor.save(f.name)

    dumped_scripted_predictor = torch.jit.load(f.name)
    res_scripted_dumped = dumped_scripted_predictor(batch)
assert (res_scripted_dumped == res_scripted).all()

脚本的总运行时间:(0 分钟 1.504 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发人员的深度教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源