• 文档 >
  • 创建 TorchScript 模块
快捷方式

创建 TorchScript 模块

TorchScript 是一种从 PyTorch 代码创建可序列化和可优化模型的方法。PyTorch 提供了关于如何执行此操作的详细文档 https://pytorch.ac.cn/tutorials/beginner/Intro_to_TorchScript_tutorial.html 但这里简要介绍一下关键背景信息和流程

PyTorch 程序基于 Module 构建,模块可以用于组合更高级别的模块。Modules 包含一个构造函数来设置模块、参数和子模块,以及一个 forward 函数,该函数描述了在调用模块时如何使用参数和子模块。

例如,我们可以像这样定义一个 LeNet 模块

 1import torch.nn as nn
 2import torch.nn.functional as F
 3
 4
 5class LeNetFeatExtractor(nn.Module):
 6    def __init__(self):
 7        super(LeNetFeatExtractor, self).__init__()
 8        self.conv1 = nn.Conv2d(1, 6, 3)
 9        self.conv2 = nn.Conv2d(6, 16, 3)
10
11    def forward(self, x):
12        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
13        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
14        return x
15
16
17class LeNetClassifier(nn.Module):
18    def __init__(self):
19        super(LeNetClassifier, self).__init__()
20        self.fc1 = nn.Linear(16 * 6 * 6, 120)
21        self.fc2 = nn.Linear(120, 84)
22        self.fc3 = nn.Linear(84, 10)
23
24    def forward(self, x):
25        x = torch.flatten(x, 1)
26        x = F.relu(self.fc1(x))
27        x = F.relu(self.fc2(x))
28        x = self.fc3(x)
29        return x
30
31
32class LeNet(nn.Module):
33    def __init__(self):
34        super(LeNet, self).__init__()
35        self.feat = LeNetFeatExtractor()
36        self.classifier = LeNetClassifier()
37
38    def forward(self, x):
39        x = self.feat(x)
40        x = self.classifier(x)
41        return x

.

显然,您可能希望将如此简单的模型整合到一个模块中,但我们可以看到 PyTorch 在此处的组合性

从这里开始,有两种从 PyTorch Python 代码转到 TorchScript 代码的途径:追踪和脚本编写。

追踪会跟踪模块被调用时的执行路径,并记录发生的事情。要追踪我们的 LeNet 模块实例,我们可以使用示例输入调用 torch.jit.trace

import torch

model = LeNet()
input_data = torch.empty([1, 1, 32, 32])
traced_model = torch.jit.trace(model, input_data)

脚本实际上会用编译器检查您的代码,并生成等效的 TorchScript 程序。不同之处在于,由于追踪会跟踪模块的执行,因此它无法获取控制流实例。通过从 Python 代码入手,编译器可以包含这些组件。我们可以通过调用 torch.jit.script 在我们的 LeNet 模块上运行脚本编译器

import torch

model = LeNet()
script_model = torch.jit.script(model)

使用其中一种路径而不是另一种路径是有原因的,PyTorch 文档中有关于如何选择的信息。从 Torch-TensorRT 的角度来看,对追踪模块有更好的支持(即您的模块更可能编译),因为它不包含完整编程语言的所有复杂性,尽管两种路径都受支持。

在脚本编写或追踪您的模块后,您将获得一个 TorchScript 模块。这包含用于运行模块的代码和参数,这些代码和参数存储在 Torch-TensorRT 可以使用的中间表示中。

以下是 LeNet 追踪模块 IR 的外观

graph(%self.1 : __torch__.___torch_mangle_10.LeNet,
    %input.1 : Float(1, 1, 32, 32)):
    %129 : __torch__.___torch_mangle_9.LeNetClassifier = prim::GetAttr[name="classifier"](%self.1)
    %119 : __torch__.___torch_mangle_5.LeNetFeatExtractor = prim::GetAttr[name="feat"](%self.1)
    %137 : Tensor = prim::CallMethod[name="forward"](%119, %input.1)
    %138 : Tensor = prim::CallMethod[name="forward"](%129, %137)
    return (%138)

以及 LeNet 脚本模块 IR

graph(%self : __torch__.LeNet,
    %x.1 : Tensor):
    %2 : __torch__.LeNetFeatExtractor = prim::GetAttr[name="feat"](%self)
    %x.3 : Tensor = prim::CallMethod[name="forward"](%2, %x.1) # x.py:38:12
    %5 : __torch__.LeNetClassifier = prim::GetAttr[name="classifier"](%self)
    %x.5 : Tensor = prim::CallMethod[name="forward"](%5, %x.3) # x.py:39:12
    return (%x.5)

您可以看到 IR 保留了我们在 Python 代码中拥有的模块结构。

在 Python 中使用 TorchScript

TorchScript 模块的运行方式与您运行普通 PyTorch 模块的方式相同。您可以使用 forward 方法或仅调用模块 torch_script_module(in_tensor) 来运行前向传递。JIT 编译器将即时编译和优化模块,然后返回结果。

将 TorchScript 模块保存到磁盘

对于追踪或脚本模块,您都可以使用以下命令将模块保存到磁盘

import torch

model = LeNet()
script_model = torch.jit.script(model)
script_model.save("lenet_scripted.ts")

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源