注意
点击 此处 下载完整的示例代码
使用 TorchScript 部署 Seq2Seq 模型¶
本教程将逐步介绍如何使用 TorchScript API 将序列到序列模型转换为 TorchScript。我们将转换的模型是 聊天机器人教程 中的聊天机器人模型。您可以将本教程视为聊天机器人教程的“第 2 部分”,并部署您自己的预训练模型,或者您可以从本文档开始,并使用我们托管的预训练模型。在后一种情况下,您可以参考原始的聊天机器人教程,了解有关数据预处理、模型理论和定义以及模型训练的详细信息。
什么是 TorchScript?¶
在基于深度学习的项目的研发阶段,使用 PyTorch 等急切的命令式接口是有利的。这使用户能够编写熟悉的、惯用的 Python 代码,从而可以使用 Python 数据结构、控制流操作、打印语句和调试实用程序。虽然急切接口对于研究和实验应用是一个有益的工具,但当需要在生产环境中部署模型时,拥有基于图的模型表示非常有益。延迟图表示允许进行诸如乱序执行之类的优化,以及定位高度优化的硬件架构的能力。此外,基于图的表示使框架无关的模型导出成为可能。PyTorch 提供了将急切模式代码逐步转换为 TorchScript 的机制,TorchScript 是 Python 的一个静态可分析和可优化的子集,Torch 使用它来独立于 Python 运行时表示深度学习程序。
将 Eager 模式下的 PyTorch 程序转换为 TorchScript 的 API 位于 torch.jit
模块中。此模块有两种核心方法可以将 Eager 模式模型转换为 TorchScript 图表示:**追踪** 和 **脚本化**。 torch.jit.trace
函数接收一个模块或函数以及一组示例输入。然后,它使用示例输入运行该函数或模块,同时追踪遇到的计算步骤,并输出一个执行追踪操作的基于图的函数。**追踪** 非常适合处理没有数据相关控制流的简单模块和函数,例如标准卷积神经网络。但是,如果追踪包含数据相关 if 语句和循环的函数,则只会记录示例输入执行路径上调用的操作。换句话说,控制流本身不会被捕获。为了转换包含数据相关控制流的模块和函数,提供了一种**脚本化**机制。 torch.jit.script
函数/装饰器接收一个模块或函数,并且不需要示例输入。脚本化会显式地将模块或函数代码转换为 TorchScript,包括所有控制流。使用脚本化需要注意的一点是,它只支持 Python 的一个子集,因此您可能需要重写代码以使其与 TorchScript 语法兼容。
有关支持的功能的所有详细信息,请参阅 TorchScript 语言参考。为了提供最大的灵活性,您还可以将追踪和脚本化模式混合在一起以表示您的整个程序,并且这些技术可以增量地应用。
致谢¶
本教程受到以下资料的启发
Yuan-Kuei Wu 的 pytorch-chatbot 实现: https://github.com/ywk991112/pytorch-chatbot
Sean Robertson 的 practical-pytorch seq2seq-translation 示例: https://github.com/spro/practical-pytorch/tree/master/seq2seq-translation
FloydHub 的 Cornell Movie Corpus 预处理代码: https://github.com/floydhub/textutil-preprocess-cornell-movie-corpus
准备环境¶
首先,我们将导入所需的模块并设置一些常量。如果您计划使用自己的模型,请确保 MAX_LENGTH
常量设置正确。提醒一下,此常量定义了训练期间允许的最大句子长度以及模型能够产生的最大长度输出。
import torch
import torch.nn as nn
import torch.nn.functional as F
import re
import os
import unicodedata
import numpy as np
device = torch.device("cpu")
MAX_LENGTH = 10 # Maximum sentence length
# Default word tokens
PAD_token = 0 # Used for padding short sentences
SOS_token = 1 # Start-of-sentence token
EOS_token = 2 # End-of-sentence token
模型概述¶
如前所述,我们使用的模型是 序列到序列 (seq2seq) 模型。这种类型的模型用于输入是可变长度序列,输出也是可变长度序列,并且不一定与输入一一对应的情况。seq2seq 模型由两个协同工作的循环神经网络 (RNN) 组成:**编码器** 和 **解码器**。
图片来源: https://jeddy92.github.io/JEddy92.github.io/ts_seq2seq_intro/
编码器¶
编码器 RNN 逐个标记(例如,单词)迭代输入句子,在每个时间步输出一个“输出”向量和一个“隐藏状态”向量。然后将隐藏状态向量传递到下一个时间步,同时记录输出向量。编码器将其在序列中每个点看到的上下文转换为高维空间中的一组点,解码器将使用这些点为给定任务生成有意义的输出。
数据处理¶
尽管我们的模型在概念上处理标记序列,但实际上,就像所有机器学习模型一样,它们处理数字。在这种情况下,模型词汇表中的每个单词(在训练前建立)都映射到一个整数索引。我们使用 Voc
对象来包含从单词到索引的映射,以及词汇表中单词的总数。稍后在运行模型之前,我们将加载该对象。
此外,为了能够运行评估,我们必须提供一个工具来处理我们的字符串输入。 normalizeString
函数将字符串中的所有字符转换为小写并删除所有非字母字符。 indexesFromSentence
函数接收一个单词句子并返回相应的单词索引序列。
class Voc:
def __init__(self, name):
self.name = name
self.trimmed = False
self.word2index = {}
self.word2count = {}
self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
self.num_words = 3 # Count SOS, EOS, PAD
def addSentence(self, sentence):
for word in sentence.split(' '):
self.addWord(word)
def addWord(self, word):
if word not in self.word2index:
self.word2index[word] = self.num_words
self.word2count[word] = 1
self.index2word[self.num_words] = word
self.num_words += 1
else:
self.word2count[word] += 1
# Remove words below a certain count threshold
def trim(self, min_count):
if self.trimmed:
return
self.trimmed = True
keep_words = []
for k, v in self.word2count.items():
if v >= min_count:
keep_words.append(k)
print('keep_words {} / {} = {:.4f}'.format(
len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index)
))
# Reinitialize dictionaries
self.word2index = {}
self.word2count = {}
self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
self.num_words = 3 # Count default tokens
for word in keep_words:
self.addWord(word)
# Lowercase and remove non-letter characters
def normalizeString(s):
s = s.lower()
s = re.sub(r"([.!?])", r" \1", s)
s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
return s
# Takes string sentence, returns sentence of word indexes
def indexesFromSentence(voc, sentence):
return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token]
定义编码器¶
我们使用 torch.nn.GRU
模块实现编码器的 RNN,我们向其中馈送一批句子(单词嵌入向量),它在内部逐个标记迭代句子,计算隐藏状态。我们将此模块初始化为双向的,这意味着我们有两个独立的 GRU:一个按时间顺序迭代序列,另一个按相反顺序迭代。我们最终返回这两个 GRU 输出的总和。由于我们的模型使用批处理进行训练,因此我们的 EncoderRNN
模型的 forward
函数需要一个填充的输入批次。为了对可变长度句子进行批处理,我们允许句子中最多有MAX_LENGTH个标记,并且批次中所有少于MAX_LENGTH个标记的句子都在末尾用我们专用的PAD_token标记进行填充。要将填充批次与 PyTorch RNN 模块一起使用,我们必须使用 torch.nn.utils.rnn.pack_padded_sequence
和 torch.nn.utils.rnn.pad_packed_sequence
数据转换来包装前向传递调用。请注意, forward
函数还接收一个 input_lengths
列表,其中包含批次中每个句子的长度。此输入在填充时由 torch.nn.utils.rnn.pack_padded_sequence
函数使用。
TorchScript 注释:¶
由于编码器的 forward
函数不包含任何数据相关控制流,因此我们将使用**追踪**将其转换为脚本模式。追踪模块时,我们可以按原样保留模块定义。我们将在本文档的最后初始化所有模型,然后运行评估。
class EncoderRNN(nn.Module):
def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
super(EncoderRNN, self).__init__()
self.n_layers = n_layers
self.hidden_size = hidden_size
self.embedding = embedding
# Initialize GRU; the ``input_size`` and ``hidden_size`` parameters are both set to 'hidden_size'
# because our input size is a word embedding with number of features == hidden_size
self.gru = nn.GRU(hidden_size, hidden_size, n_layers,
dropout=(0 if n_layers == 1 else dropout), bidirectional=True)
def forward(self, input_seq, input_lengths, hidden=None):
# type: (Tensor, Tensor, Optional[Tensor]) -> Tuple[Tensor, Tensor]
# Convert word indexes to embeddings
embedded = self.embedding(input_seq)
# Pack padded batch of sequences for RNN module
packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
# Forward pass through GRU
outputs, hidden = self.gru(packed, hidden)
# Unpack padding
outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs)
# Sum bidirectional GRU outputs
outputs = outputs[:, :, :self.hidden_size] + outputs[:, : ,self.hidden_size:]
# Return output and final hidden state
return outputs, hidden
定义解码器的注意力模块¶
接下来,我们将定义我们的注意力模块(Attn
)。请注意,此模块将用作解码器模型中的子模块。Luong 等人考虑了各种“评分函数”,这些函数接收当前解码器 RNN 输出和整个编码器输出,并返回注意力“能量”。此注意力能量张量与编码器输出的大小相同,最终两者相乘,得到一个加权张量,其最大值表示在解码的特定时间步长上查询句子的最重要部分。
# Luong attention layer
class Attn(nn.Module):
def __init__(self, method, hidden_size):
super(Attn, self).__init__()
self.method = method
if self.method not in ['dot', 'general', 'concat']:
raise ValueError(self.method, "is not an appropriate attention method.")
self.hidden_size = hidden_size
if self.method == 'general':
self.attn = nn.Linear(self.hidden_size, hidden_size)
elif self.method == 'concat':
self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
self.v = nn.Parameter(torch.FloatTensor(hidden_size))
def dot_score(self, hidden, encoder_output):
return torch.sum(hidden * encoder_output, dim=2)
def general_score(self, hidden, encoder_output):
energy = self.attn(encoder_output)
return torch.sum(hidden * energy, dim=2)
def concat_score(self, hidden, encoder_output):
energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1), encoder_output), 2)).tanh()
return torch.sum(self.v * energy, dim=2)
def forward(self, hidden, encoder_outputs):
# Calculate the attention weights (energies) based on the given method
if self.method == 'general':
attn_energies = self.general_score(hidden, encoder_outputs)
elif self.method == 'concat':
attn_energies = self.concat_score(hidden, encoder_outputs)
elif self.method == 'dot':
attn_energies = self.dot_score(hidden, encoder_outputs)
# Transpose max_length and batch_size dimensions
attn_energies = attn_energies.t()
# Return the softmax normalized probability scores (with added dimension)
return F.softmax(attn_energies, dim=1).unsqueeze(1)
定义解码器¶
与 EncoderRNN
类似,我们使用 torch.nn.GRU
模块作为解码器的 RNN。但是,这次我们使用单向 GRU。需要注意的是,与编码器不同,我们将逐个单词馈送解码器 RNN。我们首先获取当前单词的嵌入并应用 dropout。接下来,我们将嵌入和最后一个隐藏状态转发到 GRU,并获得当前 GRU 输出和隐藏状态。然后,我们使用 Attn
模块作为层来获取注意力权重,然后将这些权重乘以编码器的输出以获得我们的注意力编码器输出。我们将此注意力编码器输出用作我们的 context
张量,它表示一个加权和,指示编码器输出的哪些部分需要关注。从这里,我们使用线性层和 softmax 归一化来选择输出序列中的下一个单词。
# TorchScript Notes:
# ~~~~~~~~~~~~~~~~~~~~~~
#
# Similarly to the ``EncoderRNN``, this module does not contain any
# data-dependent control flow. Therefore, we can once again use
# **tracing** to convert this model to TorchScript after it
# is initialized and its parameters are loaded.
#
class LuongAttnDecoderRNN(nn.Module):
def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1):
super(LuongAttnDecoderRNN, self).__init__()
# Keep for reference
self.attn_model = attn_model
self.hidden_size = hidden_size
self.output_size = output_size
self.n_layers = n_layers
self.dropout = dropout
# Define layers
self.embedding = embedding
self.embedding_dropout = nn.Dropout(dropout)
self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=(0 if n_layers == 1 else dropout))
self.concat = nn.Linear(hidden_size * 2, hidden_size)
self.out = nn.Linear(hidden_size, output_size)
self.attn = Attn(attn_model, hidden_size)
def forward(self, input_step, last_hidden, encoder_outputs):
# Note: we run this one step (word) at a time
# Get embedding of current input word
embedded = self.embedding(input_step)
embedded = self.embedding_dropout(embedded)
# Forward through unidirectional GRU
rnn_output, hidden = self.gru(embedded, last_hidden)
# Calculate attention weights from the current GRU output
attn_weights = self.attn(rnn_output, encoder_outputs)
# Multiply attention weights to encoder outputs to get new "weighted sum" context vector
context = attn_weights.bmm(encoder_outputs.transpose(0, 1))
# Concatenate weighted context vector and GRU output using Luong eq. 5
rnn_output = rnn_output.squeeze(0)
context = context.squeeze(1)
concat_input = torch.cat((rnn_output, context), 1)
concat_output = torch.tanh(self.concat(concat_input))
# Predict next word using Luong eq. 6
output = self.out(concat_output)
output = F.softmax(output, dim=1)
# Return output and final hidden state
return output, hidden
定义评估¶
贪婪搜索解码器¶
与聊天机器人教程一样,我们使用 GreedySearchDecoder
模块来促进实际的解码过程。此模块将训练好的编码器和解码器模型作为属性,并驱动编码输入句子(单词索引向量)的过程,并以逐个单词(单词索引)的方式迭代解码输出响应序列。
编码输入序列很简单:只需将整个序列张量及其对应的长度向量转发到 encoder
。需要注意的是,此模块一次只处理一个输入序列,**不是**一批序列。因此,当使用常量1声明张量大小时,这对应于批次大小为 1。要解码给定的解码器输出,我们必须通过解码器模型迭代运行前向传递,该模型输出对应于每个单词是解码序列中正确下一个单词的概率的 softmax 分数。我们将 decoder_input
初始化为包含SOS_token的张量。在通过 decoder
进行每次传递后,我们贪婪地将具有最高 softmax 概率的单词附加到 decoded_words
列表中。我们还将此单词用作下一次迭代的 decoder_input
。解码过程在以下任一情况下终止: decoded_words
列表的长度达到MAX_LENGTH或预测的单词是EOS_token。
TorchScript 注释:¶
本模块的forward
方法在解码输出序列时,每次解码一个词,需要遍历范围\([0, max\_length)\)。因此,我们应该使用**脚本化**将此模块转换为TorchScript。与我们可以跟踪的编码器和解码器模型不同,我们必须对GreedySearchDecoder
模块进行一些必要的更改,以便在不报错的情况下初始化对象。换句话说,我们必须确保我们的模块遵循TorchScript机制的规则,并且不使用TorchScript包含的Python子集之外的任何语言特性。
为了了解可能需要进行的一些操作,我们将回顾聊天机器人教程中的GreedySearchDecoder
实现与我们在下面单元格中使用的实现之间的差异。请注意,红色突出显示的行是从原始实现中删除的行,绿色突出显示的行是新添加的行。
更改:¶
在构造函数参数中添加了
decoder_n_layers
此更改源于我们将传递给此模块的编码器和解码器模型将是
TracedModule
(而不是Module
)的子类。因此,我们无法使用decoder.n_layers
访问解码器的层数。相反,我们对此进行了规划,并在模块构造期间传递此值。
将新属性存储为常量
在原始实现中,我们可以在
GreedySearchDecoder
的forward
方法中自由使用周围(全局)范围内的变量。但是,现在我们使用脚本化,我们没有这种自由,因为脚本化的假设是我们不一定能够保留Python对象,尤其是在导出时。一个简单的解决方案是在构造函数中将这些全局范围的值存储为模块的属性,并将它们添加到一个名为__constants__
的特殊列表中,以便在forward
方法中构造图时可以将它们用作字面量值。此用法的示例位于新第19行,其中我们使用常量属性self._device
和self._SOS_token
,而不是使用device
和SOS_token
全局值。
强制执行
forward
方法参数的类型更改
decoder_input
的初始化在原始实现中,我们使用
torch.LongTensor([[SOS_token]])
初始化了decoder_input
张量。在脚本化中,不允许我们像这样以字面量方式初始化张量。相反,我们可以使用显式的torch函数(如torch.ones
)初始化张量。在本例中,我们可以通过将1乘以存储在常量self._SOS_token
中的SOS_token值来轻松复制标量decoder_input
张量。
class GreedySearchDecoder(nn.Module):
def __init__(self, encoder, decoder, decoder_n_layers):
super(GreedySearchDecoder, self).__init__()
self.encoder = encoder
self.decoder = decoder
self._device = device
self._SOS_token = SOS_token
self._decoder_n_layers = decoder_n_layers
__constants__ = ['_device', '_SOS_token', '_decoder_n_layers']
def forward(self, input_seq : torch.Tensor, input_length : torch.Tensor, max_length : int):
# Forward input through encoder model
encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
# Prepare encoder's final hidden layer to be first hidden input to the decoder
decoder_hidden = encoder_hidden[:self._decoder_n_layers]
# Initialize decoder input with SOS_token
decoder_input = torch.ones(1, 1, device=self._device, dtype=torch.long) * self._SOS_token
# Initialize tensors to append decoded words to
all_tokens = torch.zeros([0], device=self._device, dtype=torch.long)
all_scores = torch.zeros([0], device=self._device)
# Iteratively decode one word token at a time
for _ in range(max_length):
# Forward pass through decoder
decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
# Obtain most likely word token and its softmax score
decoder_scores, decoder_input = torch.max(decoder_output, dim=1)
# Record token and score
all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
all_scores = torch.cat((all_scores, decoder_scores), dim=0)
# Prepare current token to be next decoder input (add a dimension)
decoder_input = torch.unsqueeze(decoder_input, 0)
# Return collections of word tokens and scores
return all_tokens, all_scores
评估输入¶
接下来,我们定义一些用于评估输入的函数。evaluate
函数接收一个标准化的字符串句子,将其处理成其对应单词索引的张量(批大小为1),并将此张量传递给名为searcher
的GreedySearchDecoder
实例以处理编码/解码过程。搜索器返回输出单词索引向量和一个分数张量,该张量对应于每个解码单词标记的softmax分数。最后一步是使用voc.index2word
将每个单词索引转换回其字符串表示形式。
我们还定义了两个用于评估输入句子的函数。evaluateInput
函数提示用户输入,并对其进行评估。它将继续提示用户输入,直到用户输入“q”或“quit”。
evaluateExample
函数仅将字符串输入句子作为参数,对其进行标准化、评估并打印响应。
def evaluate(searcher, voc, sentence, max_length=MAX_LENGTH):
### Format input sentence as a batch
# words -> indexes
indexes_batch = [indexesFromSentence(voc, sentence)]
# Create lengths tensor
lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
# Transpose dimensions of batch to match models' expectations
input_batch = torch.LongTensor(indexes_batch).transpose(0, 1)
# Use appropriate device
input_batch = input_batch.to(device)
lengths = lengths.to(device)
# Decode sentence with searcher
tokens, scores = searcher(input_batch, lengths, max_length)
# indexes -> words
decoded_words = [voc.index2word[token.item()] for token in tokens]
return decoded_words
# Evaluate inputs from user input (``stdin``)
def evaluateInput(searcher, voc):
input_sentence = ''
while(1):
try:
# Get input sentence
input_sentence = input('> ')
# Check if it is quit case
if input_sentence == 'q' or input_sentence == 'quit': break
# Normalize sentence
input_sentence = normalizeString(input_sentence)
# Evaluate sentence
output_words = evaluate(searcher, voc, input_sentence)
# Format and print response sentence
output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
print('Bot:', ' '.join(output_words))
except KeyError:
print("Error: Encountered unknown word.")
# Normalize input sentence and call ``evaluate()``
def evaluateExample(sentence, searcher, voc):
print("> " + sentence)
# Normalize sentence
input_sentence = normalizeString(sentence)
# Evaluate sentence
output_words = evaluate(searcher, voc, input_sentence)
output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
print('Bot:', ' '.join(output_words))
加载预训练参数¶
现在,让我们加载我们的模型!
使用托管模型¶
要加载托管模型
从此处下载模型here。
将
loadFilename
变量设置为下载的检查点文件的路径。保持
checkpoint = torch.load(loadFilename)
行未注释,因为托管模型是在CPU上训练的。
使用您自己的模型¶
要加载您自己的预训练模型
将
loadFilename
变量设置为要加载的检查点文件的路径。请注意,如果您遵循聊天机器人教程中保存模型的约定,这可能需要更改model_name
、encoder_n_layers
、decoder_n_layers
、hidden_size
和checkpoint_iter
(因为这些值用于模型路径)。如果您在CPU上训练了模型,请确保使用
checkpoint = torch.load(loadFilename)
行打开检查点。如果您在GPU上训练了模型,并且正在CPU上运行本教程,请取消注释checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))
行。
TorchScript注释:¶
请注意,我们像往常一样初始化并将参数加载到编码器和解码器模型中。如果您在模型的某些部分使用跟踪模式(torch.jit.trace
),则必须调用.to(device)
来设置模型的设备选项,并调用.eval()
将dropout层设置为测试模式,**然后**才能跟踪模型。TracedModule对象不继承to
或eval
方法。由于在本教程中我们仅使用脚本化而不是跟踪,因此我们只需要在进行评估之前执行此操作(这与我们在渴望模式下通常执行的操作相同)。
save_dir = os.path.join("data", "save")
corpus_name = "cornell movie-dialogs corpus"
# Configure models
model_name = 'cb_model'
attn_model = 'dot'
#attn_model = 'general'``
#attn_model = 'concat'
hidden_size = 500
encoder_n_layers = 2
decoder_n_layers = 2
dropout = 0.1
batch_size = 64
# If you're loading your own model
# Set checkpoint to load from
checkpoint_iter = 4000
从检查点加载的示例代码
loadFilename = os.path.join(save_dir, model_name, corpus_name,
'{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size),
'{}_checkpoint.tar'.format(checkpoint_iter))
# If you're loading the hosted model
loadFilename = 'data/4000_checkpoint.tar'
# Load model
# Force CPU device options (to match tensors in this tutorial)
checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))
encoder_sd = checkpoint['en']
decoder_sd = checkpoint['de']
encoder_optimizer_sd = checkpoint['en_opt']
decoder_optimizer_sd = checkpoint['de_opt']
embedding_sd = checkpoint['embedding']
voc = Voc(corpus_name)
voc.__dict__ = checkpoint['voc_dict']
print('Building encoder and decoder ...')
# Initialize word embeddings
embedding = nn.Embedding(voc.num_words, hidden_size)
embedding.load_state_dict(embedding_sd)
# Initialize encoder & decoder models
encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, decoder_n_layers, dropout)
# Load trained model parameters
encoder.load_state_dict(encoder_sd)
decoder.load_state_dict(decoder_sd)
# Use appropriate device
encoder = encoder.to(device)
decoder = decoder.to(device)
# Set dropout layers to ``eval`` mode
encoder.eval()
decoder.eval()
print('Models built and ready to go!')
/var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:730: FutureWarning:
You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
Building encoder and decoder ...
Models built and ready to go!
将模型转换为TorchScript¶
编码器¶
如前所述,要将编码器模型转换为TorchScript,我们使用**脚本化**。编码器模型接收输入序列和相应的长度张量。因此,我们创建了一个示例输入序列张量test_seq
,它具有适当的大小(MAX_LENGTH,1),包含适当范围\([0, voc.num\_words)\)内的数字,并且是适当的类型(int64)。我们还创建了一个test_seq_length
标量,它实际上包含对应于test_seq
中单词数量的值。下一步是使用torch.jit.trace
函数跟踪模型。请注意,我们传递的第一个参数是我们想要跟踪的模块,第二个参数是模块的forward
方法的参数元组。
解码器¶
我们对解码器的跟踪过程与对编码器的跟踪过程相同。请注意,我们在跟踪的编码器上调用forward以获取解码器所需输出。这并不是必需的,因为我们也可以简单地制造一个具有正确形状、类型和值范围的张量。这种方法是可行的,因为在我们的例子中,我们对张量的值没有任何限制,因为我们没有任何操作可能会因超出范围的输入而导致错误。
GreedySearchDecoder¶
回想一下,由于存在数据相关的控制流,我们对搜索器模块进行了脚本化。在脚本化的情况下,我们进行了必要的语言更改以确保实现符合TorchScript。我们初始化脚本化搜索器的方式与初始化非脚本化变体的方式相同。
### Compile the whole greedy search model to TorchScript model
# Create artificial inputs
test_seq = torch.LongTensor(MAX_LENGTH, 1).random_(0, voc.num_words).to(device)
test_seq_length = torch.LongTensor([test_seq.size()[0]]).to(device)
# Trace the model
traced_encoder = torch.jit.trace(encoder, (test_seq, test_seq_length))
### Convert decoder model
# Create and generate artificial inputs
test_encoder_outputs, test_encoder_hidden = traced_encoder(test_seq, test_seq_length)
test_decoder_hidden = test_encoder_hidden[:decoder.n_layers]
test_decoder_input = torch.LongTensor(1, 1).random_(0, voc.num_words)
# Trace the model
traced_decoder = torch.jit.trace(decoder, (test_decoder_input, test_decoder_hidden, test_encoder_outputs))
### Initialize searcher module by wrapping ``torch.jit.script`` call
scripted_searcher = torch.jit.script(GreedySearchDecoder(traced_encoder, traced_decoder, decoder.n_layers))
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/jit/_trace.py:166: UserWarning:
The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at aten/src/ATen/core/TensorBody.h:489.)
打印图¶
现在我们的模型已采用TorchScript形式,我们可以打印每个模型的图,以确保我们正确捕获了计算图。由于TorchScript允许我们递归地编译整个模型层次结构并将encoder
和decoder
图内联到单个图中,因此我们只需要打印scripted_searcher图。
print('scripted_searcher graph:\n', scripted_searcher.graph)
scripted_searcher graph:
graph(%self : __torch__.GreedySearchDecoder,
%input_seq.1 : Tensor,
%input_length.1 : Tensor,
%max_length.1 : int):
%53 : bool = prim::Constant[value=0]()
%42 : bool = prim::Constant[value=1]() # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:561:8
%18 : int = prim::Constant[value=4]() # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:556:68
%17 : Device = prim::Constant[value="cpu"]() # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:556:48
%14 : NoneType = prim::Constant()
%12 : int = prim::Constant[value=2]() # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:554:41
%16 : int = prim::Constant[value=1]() # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:556:35
%26 : int = prim::Constant[value=0]() # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:558:34
%encoder : __torch__.EncoderRNN = prim::GetAttr[name="encoder"](%self)
%7 : (Tensor, Tensor) = prim::CallMethod[name="forward"](%encoder, %input_seq.1, %input_length.1) # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:552:42
%encoder_outputs.1 : Tensor, %encoder_hidden.1 : Tensor = prim::TupleUnpack(%7)
%decoder_hidden.1 : Tensor = aten::slice(%encoder_hidden.1, %26, %14, %12, %16) # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:554:25
%20 : int[] = prim::ListConstruct(%16, %16)
%23 : Tensor = aten::ones(%20, %18, %14, %17, %14) # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:556:24
%decoder_input.1 : Tensor = aten::mul(%23, %16) # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:556:24
%27 : int[] = prim::ListConstruct(%26)
%all_tokens.1 : Tensor = aten::zeros(%27, %18, %14, %17, %14) # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:558:21
%33 : int[] = prim::ListConstruct(%26)
%all_scores.1 : Tensor = aten::zeros(%33, %14, %14, %17, %14) # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:559:21
%all_tokens : Tensor, %all_scores : Tensor, %decoder_hidden : Tensor, %decoder_input : Tensor = prim::Loop(%max_length.1, %42, %all_tokens.1, %all_scores.1, %decoder_hidden.1, %decoder_input.1) # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:561:8
block0(%43 : int, %all_tokens.11 : Tensor, %all_scores.11 : Tensor, %decoder_hidden.9 : Tensor, %decoder_input.17 : Tensor):
%decoder : __torch__.LuongAttnDecoderRNN = prim::GetAttr[name="decoder"](%self)
%48 : (Tensor, Tensor) = prim::CallMethod[name="forward"](%decoder, %decoder_input.17, %decoder_hidden.9, %encoder_outputs.1) # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:563:45
%decoder_output.1 : Tensor, %decoder_hidden.5 : Tensor = prim::TupleUnpack(%48)
%decoder_scores.1 : Tensor, %decoder_input.5 : Tensor = aten::max(%decoder_output.1, %16, %53) # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:565:44
%61 : Tensor[] = prim::ListConstruct(%all_tokens.11, %decoder_input.5)
%all_tokens.5 : Tensor = aten::cat(%61, %26) # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:567:25
%67 : Tensor[] = prim::ListConstruct(%all_scores.11, %decoder_scores.1)
%all_scores.5 : Tensor = aten::cat(%67, %26) # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:568:25
%decoder_input.13 : Tensor = aten::unsqueeze(%decoder_input.5, %26) # /var/lib/workspace/beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py:570:28
-> (%42, %all_tokens.5, %all_scores.5, %decoder_hidden.5, %decoder_input.13)
%75 : (Tensor, Tensor) = prim::TupleConstruct(%all_tokens, %all_scores)
return (%75)
运行评估¶
最后,我们将使用TorchScript模型运行聊天机器人模型的评估。如果转换正确,模型的行为将与在渴望模式表示中的行为完全相同。
默认情况下,我们评估一些常见的查询句子。如果您想自己与机器人聊天,请取消注释evaluateInput
行并试一试。
# Use appropriate device
scripted_searcher.to(device)
# Set dropout layers to ``eval`` mode
scripted_searcher.eval()
# Evaluate examples
sentences = ["hello", "what's up?", "who are you?", "where am I?", "where are you from?"]
for s in sentences:
evaluateExample(s, scripted_searcher, voc)
# Evaluate your input by running
# ``evaluateInput(traced_encoder, traced_decoder, scripted_searcher, voc)``
> hello
Bot: hello .
> what's up?
Bot: i m going to get my car .
> who are you?
Bot: i m the owner .
> where am I?
Bot: in the house .
> where are you from?
Bot: south america .
保存模型¶
现在我们已成功将模型转换为TorchScript,我们将对其进行序列化以便在非Python部署环境中使用。为此,我们可以简单地保存scripted_searcher
模块,因为这是对聊天机器人模型运行推理的用户界面。保存脚本模块时,请使用script_module.save(PATH)而不是torch.save(model, PATH)。
scripted_searcher.save("scripted_chatbot.pth")
脚本总运行时间:(0 分钟 0.758 秒)