适用于移动设备的脚本和优化食谱¶
此食谱演示了如何将 PyTorch 模型转换为 TorchScript,该脚本可以在 iOS 和 Android 等高性能 C++ 环境中运行,以及如何优化转换后的 TorchScript 模型以进行移动部署。
简介¶
在训练 PyTorch 模型并可选地(但最好)进行量化后(有关更多详细信息,请参阅 量化食谱),在模型可以在 iOS 和 Android 应用程序中使用之前,一个必不可少的步骤是将依赖 Python 的模型转换为 TorchScript,然后可以进一步针对移动应用程序进行优化。转换为 TorchScript 可以像简单地调用一次函数一样简单,也可以像在许多不同地方更改原始模型一样复杂。
先决条件¶
PyTorch 1.6.0 或 1.7.0
转换为 TorchScript¶
有两种基本方法可以将 PyTorch 模型转换为 TorchScript,使用 trace 和 script。在某些情况下,可能还需要混合使用 trace 和 script - 有关更多信息,请参阅 此处。
使用 trace 方法¶
要在一个模型上使用 trace 方法,需要指定一个模型的示例或虚拟输入,实际输入的大小需要与示例输入的大小相同,并且模型定义不能有控制流,例如 if 或 for。这些限制的原因是,在具有示例输入的模型上运行 trace 只是用该输入调用模型的 forward 方法,并记录模型层中执行的所有操作,从而创建模型的跟踪。
import torch
dummy_input = torch.rand(1, 3, 224, 224)
torchscript_model = torch.jit.trace(model_quantized, dummy_input)
使用 script 方法¶
对于上面的示例,在下面调用 script 没有任何区别
torchscript_model = torch.jit.script(model_quantized)
但是,如果模型有一些流程控制,则 trace 无法正确记录所有可能的跟踪。例如,请查看 此处 的示例模型定义代码片段
import torch
class MyDecisionGate(torch.nn.Module):
def forward(self, x):
if x.sum() > 0:
return x
else:
return -x
x = torch.rand(3, 4)
traced_cell = torch.jit.trace(MyDecisionGate(), x)
print(traced_cell.code)
上面的代码将输出
TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can''t record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if x.sum() > 0:
def forward(self,
x: Tensor) -> Tensor:
return x
请注意,上面“跟踪可能无法推广到其他输入”的警告意味着,如果模型有任何类型的数据相关控制流,则 trace 不是正确的答案。但是,如果我们将上面 Python 代码片段的最后两行(在代码输出之前)替换为
scripted_cell = torch.jit.script(MyDecisionGate())
print(scripted_cell.code)
如下面的 print 结果所示,脚本化的模型将涵盖所有可能的输入,从而推广到其他输入
def forward(self,
x: Tensor) -> Tensor:
_0 = bool(torch.gt(torch.sum(x, dtype=None), 0))
if _0:
_1 = x
else:
_1 = torch.neg(x)
return _1
这是另一个使用 trace 和 script 的示例 - 它转换了在 PyTorch 教程 NLP FROM SCRATCH: TRANSLATION WITH A SEQUENCE TO SEQUENCE NETWORK AND ATTENTION 中训练的模型
encoder = EncoderRNN(input_lang.n_words, hidden_size)
decoder = AttnDecoderRNN(hidden_size, output_lang.n_words)
# method 1: using trace with example inputs
encoder_input=torch.tensor([1])
encoder_hidden=torch.zeros(1, 1, hidden_size)
decoder_input1=torch.tensor([[0]])
decoder_input2=torch.zeros(1, 1, hidden_size)
decoder_input3=torch.zeros(MAX_LENGTH, hidden_size)
traced_encoder = torch.jit.trace(encoder, (encoder_input, encoder_hidden))
traced_decoder = torch.jit.trace(decoder, (decoder_input1, decoder_input2, decoder_input3))
# method 2: using script
scripted_encoder = torch.jit.script(encoder)
scripted_decoder = torch.jit.script(decoder)
那么,是否可以简单地始终使用 script 调用并将模型转换为 TorchScript?答案是否定的,因为 TorchScript 实际上是 Python 的一个子集,为了使 script 工作,PyTorch 模型定义必须仅使用该 TorchScript Python 子集的语言特性。 TorchScript 语言参考 涵盖了 TorchScript 中支持的所有内容的详细信息。下面我们将描述使用 script 方法时的一些常见错误。
修复使用 script 方法时的常见错误¶
如果您将 script 方法应用于非平凡模型,则可能会遇到几种类型的错误。查看 本教程 以获取将聊天机器人模型转换为 TorchScript 的完整示例。但在运行 script 方法时,请按照以下步骤修复常见错误
1. RuntimeError attribute lookup is not defined on python value of type¶
对于此错误,请将模型的值作为参数传递给构造函数。这是因为当在接受另一个模型作为参数的模型上调用 script 时,传递的模型实际上是 TracedModule 或 ScriptModule 类型,而不是 Module 类型,这使得在脚本化时模型属性未定义。
例如,上面教程中的 LuongAttnDecoderRNN 模块有一个属性 n_layers,而 GreedySearchDecoder 模块引用了 LuongAttnDecoderRNN 模块的 decoder 实例的 n_layers 属性,因此为了使 script 工作,GreedySearchDecoder 模块的构造函数需要从
def __init__(self, encoder, decoder):
更改为
def __init__(self, encoder, decoder, decoder_n_layers):
...
self._decoder_n_layers = decoder_n_layers
并且 GreedySearchDecoder 的 forward 方法需要引用 self._decoder_n_layers 而不是 decoder.n_layers。
2. RuntimeError python value of type ‘…’ cannot be used as a value.¶
此错误的完整消息继续显示 Perhaps it is a closed over global variable? If so, please consider passing it in as an argument or use a local variable instead.,将全局变量的值存储为模型构造函数中的属性(无需将它们添加到名为 __constants__ 的特殊列表中)。原因是全局值可以在正常的模型训练和推理中方便地使用,但在脚本化期间无法访问全局值。
例如,device 和 SOS_token 是全局变量,为了使 script 工作,需要将它们添加到 GreedySearchDecoder 的构造函数中
self._device = device
self._SOS_token = SOS_token
并在 GreedySearchDecoder 的 forward 方法中将其分别称为 self._device 和 self._SOS_token,而不是 device 和 SOS_token。
3. RuntimeError all inputs of range must be ‘…’, found Tensor (inferred) in argument¶
错误消息继续显示:add type definitions for each of the module’s forward method arguments. Because all parameters to a TorchScript function are of the `torch.Tensor type by default, you need to specifically declare the type for each parameter that is not of type ‘Tensor’. For a complete list of TorchScript-supported types, see here.
例如,GreedySearchDecoder 的 forward 方法签名需要从
def forward(self, input_seq, input_length, max_length):
更改为
def forward(self, input_seq, input_length, max_length : int):
在使用上面 trace 或 script 方法并修复可能的错误后,您应该拥有一个准备好在移动设备上优化的 TorchScript 模型。
优化 TorchScript 模型¶
只需运行以下代码片段即可优化使用 trace 和/或 script 方法生成的 TorchScript 模型
from torch.utils.mobile_optimizer import optimize_for_mobile
optimized_torchscript_model = optimize_for_mobile(torchscript_model)
然后可以保存优化后的模型并在移动应用程序中部署
optimized_torchscript_model.save("optimized_torchscript_model.pth")
默认情况下,对于 CPU 后端,optimize_for_mobile 执行以下类型的优化
Conv2D 和 BatchNorm 融合,它将 Conv2d-BatchNorm2d 折叠到 Conv2d 中;
插入和折叠预打包操作,它重写模型图以使用其预打包对应项替换 2D 卷积和线性操作。
ReLU 和 hardtanh 融合,它通过查找 ReLU/hardtanh 操作并将它们融合在一起重写图。
删除 Dropout,当训练为 false 时,从该模块中删除 dropout 节点。
Conv 打包参数提升,它将卷积打包参数移动到根模块,以便可以删除卷积结构。这减少了模型大小而不会影响数值。
对于 Vulkan 后端,`optimize_for_mobile` 执行以下类型的优化
自动 GPU 传输,它重写图,以便将输入和输出数据移入和移出 GPU 成为模型的一部分。
可以通过将优化黑名单作为参数传递给 optimize_for_mobile 来禁用优化类型。
了解更多¶
官方的 TorchScript 语言参考。
The torch.utils.mobile_optimizer API 文档。