• 教程 >
  • TorchScript 中的模型冻结
快捷方式

TorchScript 中的模型冻结

在本教程中,我们将介绍 TorchScript 中模型冻结的语法。冻结是将 Pytorch 模块的参数和属性值内联到 TorchScript 内部表示的过程。参数和属性值被视为最终值,并且在生成的冻结模块中不能修改它们。

基本语法

可以使用以下 API 调用模型冻结

torch.jit.freeze(mod : ScriptModule, names : str[]) -> ScriptModule

请注意,输入模块可以是脚本化或跟踪的结果。请参阅 https://pytorch.ac.cn/tutorials/beginner/Intro_to_TorchScript_tutorial.html

接下来,我们将使用一个示例演示冻结的工作原理

import torch, time

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 32, 3, 1)
        self.conv2 = torch.nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = torch.nn.Dropout2d(0.25)
        self.dropout2 = torch.nn.Dropout2d(0.5)
        self.fc1 = torch.nn.Linear(9216, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = torch.nn.functional.relu(x)
        x = self.conv2(x)
        x = torch.nn.functional.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = torch.nn.functional.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = torch.nn.functional.log_softmax(x, dim=1)
        return output

    @torch.jit.export
    def version(self):
        return 1.0

net = torch.jit.script(Net())
fnet = torch.jit.freeze(net)

print(net.conv1.weight.size())
print(net.conv1.bias)

try:
    print(fnet.conv1.bias)
    # without exception handling, prints:
    # RuntimeError: __torch__.z.___torch_mangle_3.Net does not have a field
    # with name 'conv1'
except RuntimeError:
    print("field 'conv1' is inlined. It does not exist in 'fnet'")

try:
    fnet.version()
    # without exception handling, prints:
    # RuntimeError: __torch__.z.___torch_mangle_3.Net does not have a field
    # with name 'version'
except RuntimeError:
    print("method 'version' is not deleted in fnet. Only 'forward' is preserved")

fnet2 = torch.jit.freeze(net, ["version"])

print(fnet2.version())

B=1
warmup = 1
iter = 1000
input = torch.rand(B, 1,28, 28)

start = time.time()
for i in range(warmup):
    net(input)
end = time.time()
print("Scripted - Warm up time: {0:7.4f}".format(end-start), flush=True)

start = time.time()
for i in range(warmup):
    fnet(input)
end = time.time()
print("Frozen   - Warm up time: {0:7.4f}".format(end-start), flush=True)

start = time.time()
for i in range(iter):
    input = torch.rand(B, 1,28, 28)
    net(input)
end = time.time()
print("Scripted - Inference: {0:5.2f}".format(end-start), flush=True)

start = time.time()
for i in range(iter):
    input = torch.rand(B, 1,28, 28)
    fnet2(input)
end = time.time()
print("Frozen    - Inference time: {0:5.2f}".format(end-start), flush =True)

在我的机器上,我测量了时间

  • 脚本化 - 预热时间:0.0107

  • 冻结 - 预热时间:0.0048

  • 脚本化 - 推理:1.35

  • 冻结 - 推理时间:1.17

在我们的示例中,预热时间测量前两次运行。冻结模型比脚本化模型快 50%。在一些更复杂的模型上,我们观察到预热时间的加速甚至更高。冻结实现了这种加速,因为它在启动前几次运行时执行了 TorchScript 需要执行的一些工作。

推理时间测量模型预热后推理执行时间。虽然我们观察到执行时间的显着变化,但冻结模型通常比脚本化模型快约 15%。当输入更大时,我们观察到加速较小,因为执行主要由张量运算决定。

结论

在本教程中,我们学习了模型冻结。冻结是一种用于优化模型以进行推理的有用技术,它还可以显着减少 TorchScript 预热时间。

脚本的总运行时间:(0 分钟 0.000 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得解答

查看资源