注意
点击此处下载完整的示例代码
使用 TorchScript 部署 Seq2Seq 模型¶
创建日期:2018 年 9 月 17 日 | 最后更新:2024 年 12 月 02 日 | 最后验证:2024 年 11 月 05 日
警告
TorchScript 已不再处于积极开发阶段。
本教程将逐步介绍使用 TorchScript API 将 sequence-to-sequence 模型迁移到 TorchScript 的过程。我们将转换的模型是来自聊天机器人教程的聊天机器人模型。你可以将本教程视为聊天机器人教程的“第二部分”,并部署你自己的预训练模型;或者你可以从本文档开始,使用我们托管的预训练模型。在后一种情况下,你可以参考原始的聊天机器人教程了解数据预处理、模型理论定义和模型训练的详细信息。
什么是 TorchScript?¶
在基于深度学习项目的研究和开发阶段,与 PyTorch 这样的 eager、命令式接口进行交互是有利的。这使用户能够编写熟悉、惯用的 Python 代码,允许使用 Python 数据结构、控制流操作、print 语句和调试工具。尽管 eager 接口是研究和实验应用的有利工具,但当需要在生产环境中部署模型时,拥有一个基于图的模型表示就非常有益了。延迟的图表示允许进行乱序执行等优化,并能够针对高度优化的硬件架构。此外,基于图的表示也支持与框架无关的模型导出。PyTorch 提供了将 eager 模式代码逐步转换为 TorchScript 的机制,TorchScript 是 Python 的一个静态可分析和可优化的子集,Torch 使用它来独立于 Python 运行时表示深度学习程序。
用于将 eager 模式的 PyTorch 程序转换为 TorchScript 的 API 位于 torch.jit
模块中。该模块有两种核心模式用于将 eager 模式模型转换为 TorchScript 图表示:追踪 (tracing) 和 脚本化 (scripting)。torch.jit.trace
函数接受一个 module 或函数以及一组示例输入。然后它将示例输入通过该函数或 module 运行,同时追踪遇到的计算步骤,并输出一个执行追踪到的操作的基于图的函数。追踪 (Tracing) 非常适合不涉及数据依赖控制流的简单 module 和函数,例如标准的卷积神经网络。然而,如果追踪包含数据依赖 if 语句和循环的函数,则只会记录示例输入所经过的执行路径上的操作。换句话说,控制流本身不会被捕获。为了转换包含数据依赖控制流的 module 和函数,提供了脚本化 (scripting) 机制。torch.jit.script
函数/装饰器接受一个 module 或函数,不需要示例输入。脚本化随后会显式地将 module 或函数代码转换为 TorchScript,包括所有的控制流。使用脚本化有一个需要注意的地方,它只支持 Python 的一个子集,因此你可能需要重写代码以使其与 TorchScript 语法兼容。
有关支持的所有功能的详细信息,请参阅 TorchScript 语言参考。为了提供最大的灵活性,你也可以将追踪和脚本化模式混合使用来表示整个程序,并且这些技术可以逐步应用。

致谢¶
本教程的灵感来自以下来源
Yuan-Kuei Wu 的 pytorch-chatbot 实现:https://github.com/ywk991112/pytorch-chatbot
Sean Robertson 的 practical-pytorch seq2seq-translation 示例:https://github.com/spro/practical-pytorch/tree/master/seq2seq-translation
FloydHub 的 Cornell Movie Corpus 预处理代码:https://github.com/floydhub/textutil-preprocess-cornell-movie-corpus
准备环境¶
首先,我们将导入所需的模块并设置一些常量。如果你打算使用自己的模型,请确保 MAX_LENGTH
常量设置正确。提醒一下,此常量定义了训练期间允许的最大句子长度以及模型能够产生的最大输出长度。
import torch
import torch.nn as nn
import torch.nn.functional as F
import re
import os
import unicodedata
import numpy as np
device = torch.device("cpu")
MAX_LENGTH = 10 # Maximum sentence length
# Default word tokens
PAD_token = 0 # Used for padding short sentences
SOS_token = 1 # Start-of-sentence token
EOS_token = 2 # End-of-sentence token
模型概览¶
如前所述,我们使用的模型是 sequence-to-sequence (seq2seq) 模型。当输入是变长序列,并且输出也是变长序列,且输出不一定与输入一一对应时,就会使用此类模型。Seq2seq 模型由两个协同工作的循环神经网络 (RNN) 组成:一个编码器 (encoder) 和一个解码器 (decoder)。

图片来源:https://jeddy92.github.io/JEddy92.github.io/ts_seq2seq_intro/
编码器¶
编码器 RNN 逐个 token(例如单词)遍历输入句子,在每个时间步输出一个“输出”向量和一个“隐藏状态”向量。然后隐藏状态向量被传递到下一个时间步,而输出向量被记录下来。编码器将它在序列中每个点看到的上下文转换为高维空间中的一组点,解码器将使用这些点为给定任务生成有意义的输出。
数据处理¶
尽管我们的模型概念上处理 token 序列,但实际上,它们像所有机器学习模型一样处理数字。在这种情况下,模型词汇表中的每个单词(在训练前建立)都被映射到一个整数索引。我们使用一个 Voc
对象来包含从单词到索引的映射,以及词汇表中的单词总数。我们稍后在运行模型之前加载该对象。
此外,为了能够运行评估,我们必须提供一个工具来处理我们的字符串输入。normalizeString
函数将字符串中的所有字符转换为小写并移除所有非字母字符。indexesFromSentence
函数接受一个单词组成的句子,并返回相应的单词索引序列。
class Voc:
def __init__(self, name):
self.name = name
self.trimmed = False
self.word2index = {}
self.word2count = {}
self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
self.num_words = 3 # Count SOS, EOS, PAD
def addSentence(self, sentence):
for word in sentence.split(' '):
self.addWord(word)
def addWord(self, word):
if word not in self.word2index:
self.word2index[word] = self.num_words
self.word2count[word] = 1
self.index2word[self.num_words] = word
self.num_words += 1
else:
self.word2count[word] += 1
# Remove words below a certain count threshold
def trim(self, min_count):
if self.trimmed:
return
self.trimmed = True
keep_words = []
for k, v in self.word2count.items():
if v >= min_count:
keep_words.append(k)
print('keep_words {} / {} = {:.4f}'.format(
len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index)
))
# Reinitialize dictionaries
self.word2index = {}
self.word2count = {}
self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
self.num_words = 3 # Count default tokens
for word in keep_words:
self.addWord(word)
# Lowercase and remove non-letter characters
def normalizeString(s):
s = s.lower()
s = re.sub(r"([.!?])", r" \1", s)
s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
return s
# Takes string sentence, returns sentence of word indexes
def indexesFromSentence(voc, sentence):
return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token]
定义编码器¶
我们使用 torch.nn.GRU
模块实现编码器的 RNN,我们将一批句子(词嵌入向量)输入给它,它在内部逐 token 遍历句子并计算隐藏状态。我们将此模块初始化为双向的,这意味着我们有两个独立的 GRU:一个按时间顺序遍历序列,另一个按逆序遍历。我们最终返回这两个 GRU 输出的总和。由于我们的模型是使用批处理进行训练的,因此 EncoderRNN
模型的 forward
函数需要一个填充的输入批次。为了对变长句子进行批处理,我们允许句子中最多包含 MAX_LENGTH 个 token,并且批次中所有 token 少于 MAX_LENGTH 的句子都在末尾用我们专用的 PAD_token 进行填充。要在 PyTorch RNN 模块中使用填充批次,我们必须将前向传播调用用 torch.nn.utils.rnn.pack_padded_sequence
和 torch.nn.utils.rnn.pad_packed_sequence
数据转换进行包装。请注意,forward
函数还接受一个 input_lengths
列表,其中包含批次中每个句子的长度。这个输入用于 torch.nn.utils.rnn.pack_padded_sequence
函数进行填充时使用。
TorchScript 注意事项:¶
由于编码器的 forward
函数不包含任何数据依赖的控制流,我们将使用追踪 (tracing) 将其转换为 script 模式。在追踪 module 时,我们可以保持 module 定义不变。我们将在本文档的末尾、运行评估之前初始化所有模型。
class EncoderRNN(nn.Module):
def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
super(EncoderRNN, self).__init__()
self.n_layers = n_layers
self.hidden_size = hidden_size
self.embedding = embedding
# Initialize GRU; the ``input_size`` and ``hidden_size`` parameters are both set to 'hidden_size'
# because our input size is a word embedding with number of features == hidden_size
self.gru = nn.GRU(hidden_size, hidden_size, n_layers,
dropout=(0 if n_layers == 1 else dropout), bidirectional=True)
def forward(self, input_seq, input_lengths, hidden=None):
# type: (Tensor, Tensor, Optional[Tensor]) -> Tuple[Tensor, Tensor]
# Convert word indexes to embeddings
embedded = self.embedding(input_seq)
# Pack padded batch of sequences for RNN module
packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
# Forward pass through GRU
outputs, hidden = self.gru(packed, hidden)
# Unpack padding
outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs)
# Sum bidirectional GRU outputs
outputs = outputs[:, :, :self.hidden_size] + outputs[:, : ,self.hidden_size:]
# Return output and final hidden state
return outputs, hidden
定义解码器的注意力模块¶
接下来,我们将定义我们的注意力模块 (Attn
)。请注意,此模块将用作我们解码器模型的子模块。Luong 等人考虑了各种“打分函数”,这些函数接受当前解码器 RNN 输出和整个编码器输出,并返回注意力“能量”。这个注意力能量张量的大小与编码器输出相同,两者最终相乘,得到一个加权张量,其中最大值表示在解码的特定时间步查询句子中最重要的部分。
# Luong attention layer
class Attn(nn.Module):
def __init__(self, method, hidden_size):
super(Attn, self).__init__()
self.method = method
if self.method not in ['dot', 'general', 'concat']:
raise ValueError(self.method, "is not an appropriate attention method.")
self.hidden_size = hidden_size
if self.method == 'general':
self.attn = nn.Linear(self.hidden_size, hidden_size)
elif self.method == 'concat':
self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
self.v = nn.Parameter(torch.FloatTensor(hidden_size))
def dot_score(self, hidden, encoder_output):
return torch.sum(hidden * encoder_output, dim=2)
def general_score(self, hidden, encoder_output):
energy = self.attn(encoder_output)
return torch.sum(hidden * energy, dim=2)
def concat_score(self, hidden, encoder_output):
energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1), encoder_output), 2)).tanh()
return torch.sum(self.v * energy, dim=2)
def forward(self, hidden, encoder_outputs):
# Calculate the attention weights (energies) based on the given method
if self.method == 'general':
attn_energies = self.general_score(hidden, encoder_outputs)
elif self.method == 'concat':
attn_energies = self.concat_score(hidden, encoder_outputs)
elif self.method == 'dot':
attn_energies = self.dot_score(hidden, encoder_outputs)
# Transpose max_length and batch_size dimensions
attn_energies = attn_energies.t()
# Return the softmax normalized probability scores (with added dimension)
return F.softmax(attn_energies, dim=1).unsqueeze(1)
定义解码器¶
与 EncoderRNN
类似,我们使用 torch.nn.GRU
模块作为解码器的 RNN。然而,这次我们使用单向 GRU。重要的是要注意,与编码器不同,我们将逐个单词输入解码器 RNN。我们首先获取当前单词的嵌入并应用 dropout。接下来,我们将嵌入和最后一个隐藏状态前向传播到 GRU,并获得当前的 GRU 输出和隐藏状态。然后我们使用 Attn
模块作为一层来获得注意力权重,我们将其乘以编码器的输出,得到我们的注意力加权编码器输出。我们将这个注意力加权编码器输出用作我们的 context
张量,它代表一个加权和,指示应注意编码器输出的哪些部分。从这里,我们使用线性层和 softmax 归一化来选择输出序列中的下一个单词。
# TorchScript Notes:
# ~~~~~~~~~~~~~~~~~~~~~~
#
# Similarly to the ``EncoderRNN``, this module does not contain any
# data-dependent control flow. Therefore, we can once again use
# **tracing** to convert this model to TorchScript after it
# is initialized and its parameters are loaded.
#
class LuongAttnDecoderRNN(nn.Module):
def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1):
super(LuongAttnDecoderRNN, self).__init__()
# Keep for reference
self.attn_model = attn_model
self.hidden_size = hidden_size
self.output_size = output_size
self.n_layers = n_layers
self.dropout = dropout
# Define layers
self.embedding = embedding
self.embedding_dropout = nn.Dropout(dropout)
self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=(0 if n_layers == 1 else dropout))
self.concat = nn.Linear(hidden_size * 2, hidden_size)
self.out = nn.Linear(hidden_size, output_size)
self.attn = Attn(attn_model, hidden_size)
def forward(self, input_step, last_hidden, encoder_outputs):
# Note: we run this one step (word) at a time
# Get embedding of current input word
embedded = self.embedding(input_step)
embedded = self.embedding_dropout(embedded)
# Forward through unidirectional GRU
rnn_output, hidden = self.gru(embedded, last_hidden)
# Calculate attention weights from the current GRU output
attn_weights = self.attn(rnn_output, encoder_outputs)
# Multiply attention weights to encoder outputs to get new "weighted sum" context vector
context = attn_weights.bmm(encoder_outputs.transpose(0, 1))
# Concatenate weighted context vector and GRU output using Luong eq. 5
rnn_output = rnn_output.squeeze(0)
context = context.squeeze(1)
concat_input = torch.cat((rnn_output, context), 1)
concat_output = torch.tanh(self.concat(concat_input))
# Predict next word using Luong eq. 6
output = self.out(concat_output)
output = F.softmax(output, dim=1)
# Return output and final hidden state
return output, hidden
定义评估¶
贪婪搜索解码器¶
正如在聊天机器人教程中所述,我们使用一个 GreedySearchDecoder
module 来辅助实际的解码过程。此 module 将训练好的编码器和解码器模型作为属性,并驱动编码输入句子(一个单词索引向量)以及迭代地逐个单词(单词索引)解码输出响应序列的过程。
编码输入序列很简单:只需将整个序列张量及其相应的长度向量前向传播到 encoder
。重要的是要注意,此 module 一次只处理一个输入序列,而不是一批序列。因此,当使用常量 1 声明张量大小时,这对应于批大小为 1。要解码给定的解码器输出,我们必须迭代地通过解码器模型运行前向传播,解码器会输出 softmax 分数,这些分数对应于序列中每个单词成为正确下一个单词的概率。我们将 decoder_input
初始化为一个包含 SOS_token 的张量。每次通过 decoder
后,我们都会贪婪地将 softmax 概率最高的单词附加到 decoded_words
列表中。我们还将此单词用作下一次迭代的 decoder_input
。解码过程在 decoded_words
列表达到 MAX_LENGTH 长度或预测的单词是 EOS_token 时终止。
TorchScript 注意事项:¶
此 module 的 forward
方法在逐个单词解码输出序列时,涉及迭代 \([0, max\_length)\) 的范围。因此,我们应该使用脚本化 (scripting) 将此 module 转换为 TorchScript。与可以追踪的编码器和解码器模型不同,我们必须对 GreedySearchDecoder
module 进行一些必要的修改,以便无误地初始化对象。换句话说,我们必须确保我们的 module 遵循 TorchScript 机制的规则,并且不使用 TorchScript 支持的 Python 子集之外的任何语言特性。
为了了解可能需要进行的一些修改,我们将回顾聊天机器人教程中的 GreedySearchDecoder
实现与下面单元格中使用的实现之间的差异。请注意,标红的行是从原始实现中移除的行,标绿的行是新增的行。

变更:¶
将
decoder_n_layers
添加到构造函数参数中此变更源于我们将传递给此 module 的编码器和解码器模型将是
TracedModule
的子类(而不是Module
)。因此,我们无法通过decoder.n_layers
访问解码器的层数。相反,我们对此进行了规划,并在 module 构建期间传入此值。
将新属性存储为常量
在原始实现中,我们可以在
GreedySearchDecoder
的forward
方法中自由使用来自周围(全局)作用域的变量。然而,现在我们使用脚本化,就没有这种自由了,因为脚本化假定我们不一定能保留 Python 对象,尤其是在导出时。一个简单的解决方案是在构造函数中将这些全局作用域的值作为属性存储到 module 中,并将它们添加到名为__constants__
的特殊列表中,以便在forward
方法中构建图时可以将它们用作字面值。这种用法的一个示例在 NEW 行 19,在那里我们没有使用全局值device
和SOS_token
,而是使用了我们的常量属性self._device
和self._SOS_token
。
强制指定
forward
方法参数的类型更改
decoder_input
的初始化方式在原始实现中,我们使用
torch.LongTensor([[SOS_token]])
初始化了我们的decoder_input
张量。在脚本化时,不允许像这样以字面方式初始化张量。相反,我们可以使用torch.ones
等显式的torch函数来初始化张量。在这种情况下,通过将 1 乘以存储在常量self._SOS_token
中的 SOS_token 值,我们可以轻松复制标量decoder_input
张量。
class GreedySearchDecoder(nn.Module):
def __init__(self, encoder, decoder, decoder_n_layers):
super(GreedySearchDecoder, self).__init__()
self.encoder = encoder
self.decoder = decoder
self._device = device
self._SOS_token = SOS_token
self._decoder_n_layers = decoder_n_layers
__constants__ = ['_device', '_SOS_token', '_decoder_n_layers']
def forward(self, input_seq : torch.Tensor, input_length : torch.Tensor, max_length : int):
# Forward input through encoder model
encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
# Prepare encoder's final hidden layer to be first hidden input to the decoder
decoder_hidden = encoder_hidden[:self._decoder_n_layers]
# Initialize decoder input with SOS_token
decoder_input = torch.ones(1, 1, device=self._device, dtype=torch.long) * self._SOS_token
# Initialize tensors to append decoded words to
all_tokens = torch.zeros([0], device=self._device, dtype=torch.long)
all_scores = torch.zeros([0], device=self._device)
# Iteratively decode one word token at a time
for _ in range(max_length):
# Forward pass through decoder
decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
# Obtain most likely word token and its softmax score
decoder_scores, decoder_input = torch.max(decoder_output, dim=1)
# Record token and score
all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
all_scores = torch.cat((all_scores, decoder_scores), dim=0)
# Prepare current token to be next decoder input (add a dimension)
decoder_input = torch.unsqueeze(decoder_input, 0)
# Return collections of word tokens and scores
return all_tokens, all_scores
评估输入¶
接下来,我们定义一些用于评估输入的函数。evaluate
函数接受一个标准化的字符串句子,将其处理成一个对应词索引的张量(批量大小为 1),并将此张量传递给一个名为 searcher
的 GreedySearchDecoder
实例来处理编码/解码过程。searcher 返回输出词索引向量和一个对应的分数张量,该张量包含每个解码词token的 softmax 分数。最后一步是使用 voc.index2word
将每个词索引转换回其字符串表示形式。
我们还定义了两个用于评估输入句子的函数。evaluateInput
函数提示用户输入,并评估它。它将继续要求输入,直到用户输入“q”或“quit”。
evaluateExample
函数仅接受一个字符串输入句子作为参数,对其进行标准化,评估它,并打印响应。
def evaluate(searcher, voc, sentence, max_length=MAX_LENGTH):
### Format input sentence as a batch
# words -> indexes
indexes_batch = [indexesFromSentence(voc, sentence)]
# Create lengths tensor
lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
# Transpose dimensions of batch to match models' expectations
input_batch = torch.LongTensor(indexes_batch).transpose(0, 1)
# Use appropriate device
input_batch = input_batch.to(device)
lengths = lengths.to(device)
# Decode sentence with searcher
tokens, scores = searcher(input_batch, lengths, max_length)
# indexes -> words
decoded_words = [voc.index2word[token.item()] for token in tokens]
return decoded_words
# Evaluate inputs from user input (``stdin``)
def evaluateInput(searcher, voc):
input_sentence = ''
while(1):
try:
# Get input sentence
input_sentence = input('> ')
# Check if it is quit case
if input_sentence == 'q' or input_sentence == 'quit': break
# Normalize sentence
input_sentence = normalizeString(input_sentence)
# Evaluate sentence
output_words = evaluate(searcher, voc, input_sentence)
# Format and print response sentence
output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
print('Bot:', ' '.join(output_words))
except KeyError:
print("Error: Encountered unknown word.")
# Normalize input sentence and call ``evaluate()``
def evaluateExample(sentence, searcher, voc):
print("> " + sentence)
# Normalize sentence
input_sentence = normalizeString(sentence)
# Evaluate sentence
output_words = evaluate(searcher, voc, input_sentence)
output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
print('Bot:', ' '.join(output_words))
加载预训练参数¶
好的,让我们加载模型!
使用托管模型¶
加载托管模型
在此处下载模型 here。
将
loadFilename
变量设置为下载的检查点文件的路径。保持
checkpoint = torch.load(loadFilename)
行不注释,因为托管模型是在 CPU 上训练的。
使用您自己的模型¶
加载您自己的预训练模型
将
loadFilename
变量设置为您希望加载的检查点文件的路径。请注意,如果您遵循了聊天机器人教程中保存模型的约定,这可能涉及更改model_name
、encoder_n_layers
、decoder_n_layers
、hidden_size
和checkpoint_iter
(因为这些值用于模型路径中)。如果您在 CPU 上训练了模型,请确保您使用
checkpoint = torch.load(loadFilename)
行打开检查点。如果您在 GPU 上训练了模型并且在 CPU 上运行本教程,请取消注释checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))
行。
TorchScript 注解:¶
请注意,我们像往常一样初始化并将参数加载到我们的编码器和解码器模型中。如果您对模型的某些部分使用追踪模式(torch.jit.trace
),则**在**追踪模型**之前**,您必须调用 .to(device)
来设置模型的设备选项,并调用 .eval()
来将 dropout 层设置为测试模式。TracedModule 对象不继承 to
或 eval
方法。由于在本教程中,我们仅使用脚本化而不是追踪,我们只需在进行评估之前执行此操作(这与我们在即时模式下通常所做的相同)。
save_dir = os.path.join("data", "save")
corpus_name = "cornell movie-dialogs corpus"
# Configure models
model_name = 'cb_model'
attn_model = 'dot'
#attn_model = 'general'``
#attn_model = 'concat'
hidden_size = 500
encoder_n_layers = 2
decoder_n_layers = 2
dropout = 0.1
batch_size = 64
# If you're loading your own model
# Set checkpoint to load from
checkpoint_iter = 4000
从检查点加载的示例代码
loadFilename = os.path.join(save_dir, model_name, corpus_name,
'{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size),
'{}_checkpoint.tar'.format(checkpoint_iter))
# If you're loading the hosted model
loadFilename = 'data/4000_checkpoint.tar'
# Load model
# Force CPU device options (to match tensors in this tutorial)
checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))
encoder_sd = checkpoint['en']
decoder_sd = checkpoint['de']
encoder_optimizer_sd = checkpoint['en_opt']
decoder_optimizer_sd = checkpoint['de_opt']
embedding_sd = checkpoint['embedding']
voc = Voc(corpus_name)
voc.__dict__ = checkpoint['voc_dict']
print('Building encoder and decoder ...')
# Initialize word embeddings
embedding = nn.Embedding(voc.num_words, hidden_size)
embedding.load_state_dict(embedding_sd)
# Initialize encoder & decoder models
encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, decoder_n_layers, dropout)
# Load trained model parameters
encoder.load_state_dict(encoder_sd)
decoder.load_state_dict(decoder_sd)
# Use appropriate device
encoder = encoder.to(device)
decoder = decoder.to(device)
# Set dropout layers to ``eval`` mode
encoder.eval()
decoder.eval()
print('Models built and ready to go!')
Building encoder and decoder ...
Models built and ready to go!
将模型转换为 TorchScript¶
编码器¶
如前所述,要将编码器模型转换为 TorchScript,我们使用**脚本化**。编码器模型接受输入序列和相应的长度张量。因此,我们创建了一个示例输入序列张量 test_seq
,其大小合适 (MAX_LENGTH, 1),包含合适范围内的数字 \[0, voc.num\_words)\],并且类型合适 (int64)。我们还创建了一个 test_seq_length
标量,它实际包含与 test_seq
中有多少词相对应的值。下一步是使用 torch.jit.trace
函数追踪模型。请注意,我们传递的第一个参数是要追踪的模块,第二个是用于模块 forward
方法的参数元组。
解码器¶
我们对解码器执行与对编码器相同的追踪过程。请注意,我们在 traced_encoder 的一组随机输入上调用 forward,以获取解码器所需的输出。这不是必需的,因为我们也可以简单地构造一个具有正确形状、类型和值范围的张量。这种方法是可行的,因为在我们的情况下,我们对张量的值没有任何限制,我们没有任何可能因超出范围的输入而导致故障的操作。
GreedySearchDecoder¶
回想一下,由于存在数据依赖的控制流,我们对 searcher 模块进行了脚本化。在脚本化的情况下,我们进行必要的语言更改,以确保实现符合 TorchScript。我们初始化脚本化的 searcher 的方式与我们初始化未脚本化的变体的方式相同。
### Compile the whole greedy search model to TorchScript model
# Create artificial inputs
test_seq = torch.LongTensor(MAX_LENGTH, 1).random_(0, voc.num_words).to(device)
test_seq_length = torch.LongTensor([test_seq.size()[0]]).to(device)
# Trace the model
traced_encoder = torch.jit.trace(encoder, (test_seq, test_seq_length))
### Convert decoder model
# Create and generate artificial inputs
test_encoder_outputs, test_encoder_hidden = traced_encoder(test_seq, test_seq_length)
test_decoder_hidden = test_encoder_hidden[:decoder.n_layers]
test_decoder_input = torch.LongTensor(1, 1).random_(0, voc.num_words)
# Trace the model
traced_decoder = torch.jit.trace(decoder, (test_decoder_input, test_decoder_hidden, test_encoder_outputs))
### Initialize searcher module by wrapping ``torch.jit.script`` call
scripted_searcher = torch.jit.script(GreedySearchDecoder(traced_encoder, traced_decoder, decoder.n_layers))
/var/lib/ci-user/.local/lib/python3.10/site-packages/torch/jit/_trace.py:165: UserWarning:
The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at /pytorch/build/aten/src/ATen/core/TensorBody.h:489.)
打印图¶
现在我们的模型已采用 TorchScript 形式,我们可以打印每个模型的图,以确保我们适当地捕获了计算图。由于 TorchScript 允许我们递归地编译整个模型层次结构,并将 encoder
和 decoder
图内联到单个图中,我们只需打印 scripted_searcher 图即可。
print('scripted_searcher graph:\n', scripted_searcher.graph)
scripted_searcher graph:
graph(%self : __torch__.GreedySearchDecoder,
%input_seq.1 : Tensor,
%input_length.1 : Tensor,
%max_length.1 : int):
%53 : bool = prim::Constant[value=0]()
%42 : bool = prim::Constant[value=1]() # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:563:8
%18 : int = prim::Constant[value=4]() # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:558:68
%17 : Device = prim::Constant[value="cpu"]() # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:558:48
%14 : NoneType = prim::Constant()
%12 : int = prim::Constant[value=2]() # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:556:41
%16 : int = prim::Constant[value=1]() # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:558:35
%26 : int = prim::Constant[value=0]() # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:560:34
%encoder : __torch__.EncoderRNN = prim::GetAttr[name="encoder"](%self)
%7 : (Tensor, Tensor) = prim::CallMethod[name="forward"](%encoder, %input_seq.1, %input_length.1) # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:554:42
%encoder_outputs.1 : Tensor, %encoder_hidden.1 : Tensor = prim::TupleUnpack(%7)
%decoder_hidden.1 : Tensor = aten::slice(%encoder_hidden.1, %26, %14, %12, %16) # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:556:25
%20 : int[] = prim::ListConstruct(%16, %16)
%23 : Tensor = aten::ones(%20, %18, %14, %17, %14) # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:558:24
%decoder_input.1 : Tensor = aten::mul(%23, %16) # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:558:24
%27 : int[] = prim::ListConstruct(%26)
%all_tokens.1 : Tensor = aten::zeros(%27, %18, %14, %17, %14) # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:560:21
%33 : int[] = prim::ListConstruct(%26)
%all_scores.1 : Tensor = aten::zeros(%33, %14, %14, %17, %14) # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:561:21
%all_tokens : Tensor, %all_scores : Tensor, %decoder_hidden : Tensor, %decoder_input : Tensor = prim::Loop(%max_length.1, %42, %all_tokens.1, %all_scores.1, %decoder_hidden.1, %decoder_input.1) # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:563:8
block0(%43 : int, %all_tokens.11 : Tensor, %all_scores.11 : Tensor, %decoder_hidden.9 : Tensor, %decoder_input.17 : Tensor):
%decoder : __torch__.LuongAttnDecoderRNN = prim::GetAttr[name="decoder"](%self)
%48 : (Tensor, Tensor) = prim::CallMethod[name="forward"](%decoder, %decoder_input.17, %decoder_hidden.9, %encoder_outputs.1) # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:565:45
%decoder_output.1 : Tensor, %decoder_hidden.5 : Tensor = prim::TupleUnpack(%48)
%decoder_scores.1 : Tensor, %decoder_input.5 : Tensor = aten::max(%decoder_output.1, %16, %53) # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:567:44
%61 : Tensor[] = prim::ListConstruct(%all_tokens.11, %decoder_input.5)
%all_tokens.5 : Tensor = aten::cat(%61, %26) # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:569:25
%67 : Tensor[] = prim::ListConstruct(%all_scores.11, %decoder_scores.1)
%all_scores.5 : Tensor = aten::cat(%67, %26) # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:570:25
%decoder_input.13 : Tensor = aten::unsqueeze(%decoder_input.5, %26) # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:572:28
-> (%42, %all_tokens.5, %all_scores.5, %decoder_hidden.5, %decoder_input.13)
%75 : (Tensor, Tensor) = prim::TupleConstruct(%all_tokens, %all_scores)
return (%75)
运行评估¶
最后,我们将使用 TorchScript 模型运行聊天机器人模型的评估。如果转换正确,模型的行为将与其在即时模式下的表现完全一致。
默认情况下,我们评估一些常见的查询句子。如果您想自己与机器人聊天,请取消注释 evaluateInput
行并试一试。
# Use appropriate device
scripted_searcher.to(device)
# Set dropout layers to ``eval`` mode
scripted_searcher.eval()
# Evaluate examples
sentences = ["hello", "what's up?", "who are you?", "where am I?", "where are you from?"]
for s in sentences:
evaluateExample(s, scripted_searcher, voc)
# Evaluate your input by running
# ``evaluateInput(traced_encoder, traced_decoder, scripted_searcher, voc)``
> hello
Bot: hello .
> what's up?
Bot: i m going to get my car .
> who are you?
Bot: i m the owner .
> where am I?
Bot: in the house .
> where are you from?
Bot: south america .
保存模型¶
现在我们已成功将模型转换为 TorchScript,我们将对其进行序列化,以便在非 Python 部署环境中使用。为此,我们可以简单地保存我们的 scripted_searcher
模块,因为这是用于对聊天机器人模型运行推理的用户界面。保存 Script 模块时,请使用 script_module.save(PATH)
而非 torch.save(model, PATH)
。
scripted_searcher.save("scripted_chatbot.pth")
脚本总运行时间: ( 0 minutes 0.661 seconds)