本周,我们正式发布了 PyTorch 1.1,这是 PyTorch 1.0 的一个大型功能更新。我们添加的新功能之一是更好地支持使用 TorchScript(PyTorch JIT)实现快速、定制的循环神经网络(fastrnns)(https://pytorch.ac.cn/docs/stable/jit.html)。
RNN 是一种流行的模型,在各种不同形态的自然语言处理 (NLP) 任务中表现良好。PyTorch 实现了其中许多流行的模型,如 Elman RNN、GRU 和 LSTM,以及多层和双向变体。
然而,许多用户希望实现自己的定制 RNN,借鉴最新的文献思想。将 层归一化 (Layer Normalization) 应用于 LSTM 就是一个这样的用例。由于 PyTorch 的 CUDA LSTM 实现使用了融合内核,因此很难插入归一化或修改基础 LSTM 实现。许多用户转而使用标准的 PyTorch 运算符编写定制实现,但这类代码存在高开销问题:大多数 PyTorch 操作会在 GPU 上启动至少一个内核,而 RNN 由于其循环特性通常会运行许多操作。然而,我们可以应用 TorchScript 自动融合操作并优化代码,从而在 GPU 上启动更少、更优化的内核。
我们的目标是让用户能够使用 TorchScript 编写快速、定制的 RNN,而无需编写专门的 CUDA 内核即可获得相似的性能。在本文中,我们将提供一个教程,介绍如何使用 TorchScript 编写您自己的快速 RNN。为了更好地理解 TorchScript 应用的优化,我们将以标准 LSTM 实现为例来探讨它们的工作原理,但大多数优化也适用于一般的 RNN。
编写定制 RNN
首先,您可以使用 此文件 作为模板来编写您自己的定制 RNN。
我们正在不断改进我们的基础设施,以提升性能。如果您希望获得 TorchScript 当前提供的速度/优化(例如运算符融合、批量矩阵乘法等),请遵循以下一些准则。下一节将深入解释这些优化。
-
如果定制的操作都是逐元素的 (element-wise),那就太好了,因为您可以自动获得 PyTorch JIT 的运算符融合 (operator fusion) 的好处!
-
如果您有更复杂的操作(例如,规约 (reduce) 操作与逐元素操作混合),请考虑将规约操作和逐元素操作分开分组,以便将逐元素操作融合到一个融合组中。
-
如果您想了解您的定制 RNN 中有哪些部分已被融合,您可以使用
graph_for
来检查操作的优化图。以LSTMCell
为例# get inputs and states for LSTMCell inputs = get_lstm_inputs() # instantiate a ScriptModule cell = LSTMCell(input_size, hidden_size) # print the optimized graph using graph_for out = cell(inputs) print(cell.graph_for(inputs))
这将为您提供的特定输入生成优化的 TorchScript 图(即 PyTorch JIT IR)
graph(%x : Float(*, *), %hx : Float(*, *), %cx : Float(*, *), %w_ih : Float(*, *), %w_hh : Float(*, *), %b_ih : Float(*), %b_hh : Float(*)): %hy : Float(*, *), %cy : Float(*, *) = prim::DifferentiableGraph_0(%cx, %b_hh, %b_ih, %hx, %w_hh, %x, %w_ih) %30 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%hy, %cy) return (%30) with prim::DifferentiableGraph_0 = graph(%13 : Float(*, *), %29 : Float(*), %33 : Float(*), %40 : Float(*, *), %43 : Float(*, *), %45 : Float(*, *), %48 : Float(*, *)): %49 : Float(*, *) = aten::t(%48) %47 : Float(*, *) = aten::mm(%45, %49) %44 : Float(*, *) = aten::t(%43) %42 : Float(*, *) = aten::mm(%40, %44) ...some broadcast sizes operations... %hy : Float(*, *), %287 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%13, %346, %345, %344, %343) ...some broadcast sizes operations... return (%hy, %cy, %49, %44, %196, %199, %340, %192, %325, %185, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %395, %396, %287) with prim::FusionGroup_0 = graph(%13 : Float(*, *), %71 : Tensor, %76 : Tensor, %81 : Tensor, %86 : Tensor): ...some chunks, constants, and add operations... %ingate.1 : Float(*, *) = aten::sigmoid(%38) %forgetgate.1 : Float(*, *) = aten::sigmoid(%34) %cellgate.1 : Float(*, *) = aten::tanh(%30) %outgate.1 : Float(*, *) = aten::sigmoid(%26) %14 : Float(*, *) = aten::mul(%forgetgate.1, %13) %11 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1) %cy : Float(*, *) = aten::add(%14, %11, %69) %4 : Float(*, *) = aten::tanh(%cy) %hy : Float(*, *) = aten::mul(%outgate.1, %4) return (%hy, %4, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1)
从上面的图中,我们可以看到它有一个 prim::FusionGroup_0
子图,该子图融合了 LSTMCell
中的所有逐元素操作(转置和矩阵乘法不是逐元素操作)。某些图节点起初可能难以理解,但我们会在优化部分解释其中一些,本文中也省略了一些仅为正确性而存在的冗长运算符。
可变长度序列的最佳实践
TorchScript 不支持 PackedSequence。通常,处理可变长度序列时,最好将它们填充到单个张量中,然后将该张量通过 TorchScript LSTM 发送。这里有一个例子
sequences = [...] # List[Tensor], each Tensor is T' x C
padded = torch.utils.rnn.pad_sequence(sequences)
lengths = [seq.size(0) for seq in sequences]
padded # T x N x C, where N is batch size and T is the max of all T'
model = LSTM(...)
output, hiddens = model(padded)
output # T x N x C
当然,output
在填充区域可能包含一些冗余数据;使用 lengths
来跟踪您不需要的部分。
优化
现在我们将解释 PyTorch JIT 为加速定制 RNN 所执行的优化。我们将使用一个简单的 TorchScript 定制 LSTM 模型来演示这些优化,但其中许多是通用的,也适用于其他 RNN。
为了说明我们进行的优化以及如何从中获益,我们将运行一个用 TorchScript 编写的简单定制 LSTM 模型(您可以参考 custom_lstm.py 中的代码或下面的代码片段)并记录我们的更改所花费的时间。
我们在配备 2 个 Intel Xeon 芯片和 1 个 Nvidia P100 的机器上设置了环境,安装了 cuDNN v7.3 和 CUDA 9.2。LSTM 模型的基本设置如下:
input_size = 512
hidden_size = 512
mini_batch = 64
numLayers = 1
seq_length = 100
PyTorch JIT 最重要的一点是它将 Python 程序编译成 PyTorch JIT IR,这是一种用于建模程序图结构的中间表示。这种 IR 可以受益于全程序优化、硬件加速,并且总体上具有提供巨大计算增益的潜力。在这个例子中,我们仅使用 JIT 提供的编译器优化过程来运行初始的 TorchScript 模型,包括公共子表达式消除、常量合并、常量传播、死代码消除以及一些窥孔优化。我们在热身后运行模型训练 100 次,并对训练时间取平均。模型前向传播的初始结果约为 27ms,后向传播时间约为 64ms,这与 PyTorch cuDNN LSTM 提供的性能还有一定差距。接下来,我们将解释我们在训练或推理性能方面所做的主要优化,从 LSTMCell 和 LSTMLayer 开始,以及一些杂项优化。
LSTM Cell(前向传播)
LSTM 中的几乎所有计算都发生在 LSTMCell 中,因此了解它包含的计算以及如何提高其速度非常重要。下面是一个 TorchScript 中的 LSTMCell 示例实现
class LSTMCell(jit.ScriptModule):
def __init__(self, input_size, hidden_size):
super(LSTMCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
self.bias_ih = Parameter(torch.randn(4 * hidden_size))
self.bias_hh = Parameter(torch.randn(4 * hidden_size))
@jit.script_method
def forward(self, input, state):
# type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
hx, cx = state
gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih +
torch.mm(hx, self.weight_hh.t()) + self.bias_hh)
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
ingate = torch.sigmoid(ingate)
forgetgate = torch.sigmoid(forgetgate)
cellgate = torch.tanh(cellgate)
outgate = torch.sigmoid(outgate)
cy = (forgetgate * cx) + (ingate * cellgate)
hy = outgate * torch.tanh(cy)
return hy, (hy, cy)
TorchScript 生成的这种图表示 (IR) 支持多种优化和可扩展计算。除了我们可以进行的典型编译器优化(CSE、常量传播等)之外,我们还可以运行其他 IR 转换来使我们的代码运行得更快。
- 逐元素运算符融合 (Element-wise operator fusion)。PyTorch JIT 会自动融合逐元素操作,因此当您有相邻且都是逐元素操作的运算符时,JIT 会自动将所有这些操作组合到一个 FusionGroup 中,然后这个 FusionGroup 可以通过一个单独的 GPU/CPU 内核启动并在一次处理中完成。这避免了每次操作昂贵的内存读写。
- 重新排序 chunk 和 pointwise 操作以实现更多融合。LSTM Cell 将门控相加(pointwise 操作),然后将门控分成四部分:ifco 门。然后,它像上面一样对 ifco 门执行 pointwise 操作。这在实践中导致了两个融合组:一个用于 chunk 前的逐元素操作,一个用于 chunk 后的逐元素操作。这里有趣的一点是,pointwise 操作与
torch.chunk
是可交换的 (commute):与其对某些输入张量执行 pointwise 操作然后对输出进行 chunk,我们可以先对输入张量进行 chunk,然后在输出张量上执行相同的 pointwise 操作。通过将 chunk 移到第一个融合组之前,我们可以将第一个和第二个融合组合并成一个大组。

- 在 CPU 上创建张量是昂贵的,但目前正在进行工作来使其更快。此时,一个 LSTMCell 运行三个 CUDA 内核:两个
gemm
内核和一个用于单个 pointwise 组的内核。我们注意到的一件事是,第二个gemm
完成与单个 pointwise 组启动之间存在较大间隔。这个间隔是 GPU 空闲无事可做的一段时间。深入研究后,我们发现问题在于torch.chunk
会构造新的张量,而张量构造的速度不够快。我们没有构造新的 Tensor 对象,而是教融合编译器如何在将torch.chunk
发送给融合内核之前操作数据指针和跨度来执行它,从而缩短了第二个 gemm 和逐元素融合组启动之间的空闲时间。这使得 LSTM 前向传播的速度提升了约 1.2 倍。
通过上述技巧,我们能够将几乎所有的 LSTMCell
前向图(除了两个 gemm 内核)融合到一个融合组中,这对应于上述 IR 图中的 prim::FusionGroup_0
。然后它将作为单个融合内核启动执行。通过这些优化,模型性能显著提升,平均前向传播时间从约 17ms 减少到 10ms(提速 1.7 倍),平均后向传播时间从 37ms 减少到 27ms(提速 1.37 倍)。
LSTM Layer(前向传播)
class LSTMLayer(jit.ScriptModule):
def __init__(self, cell, *cell_args):
super(LSTMLayer, self).__init__()
self.cell = cell(*cell_args)
@jit.script_method
def forward(self, input, state):
# type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
inputs = input.unbind(0)
outputs = torch.jit.annotate(List[Tensor], [])
for i in range(len(inputs)):
out, state = self.cell(inputs[i], state)
outputs += [out]
return torch.stack(outputs), state
我们对为 TorchScript LSTM 生成的 IR 做了一些技巧来提升性能,以下是我们进行的一些优化示例:
- 循环展开 (Loop Unrolling):我们自动对代码中的循环进行展开(对于大型循环,我们展开其一小部分),这使得我们能够对 for 循环的控制流进行进一步优化。例如,融合器可以将循环体迭代之间的操作融合在一起,这对于像 LSTM 这样控制流密集型的模型来说带来了显著的性能提升。
- 批量矩阵乘法 (Batch Matrix Multiplication):对于输入被预先乘过的 RNN(即模型有很多与相同 LHS 或 RHS 的矩阵乘法),我们可以有效地将这些操作批处理到一个单独的矩阵乘法中,同时对输出进行 chunk 以实现等效的语义。
通过应用这些技术,我们将前向传播的时间额外减少了 1.6ms 至 8.4ms(提速 1.2 倍),并将后向传播的时间减少了 7ms 至约 20ms(提速 1.35 倍)。
LSTM Layer(后向传播)
-
“树形”批量矩阵乘法 (“Tree” Batch Matrix Multiplication):在 LSTM 后向图中有时会发生单个权重被多次重用的情况,形成一个树形结构,其中叶子节点是矩阵乘法,节点是加法。这些节点可以通过在不同维度上拼接 LHS 和 RHS 来组合,然后作为一个单独的矩阵乘法进行计算。等效公式可以表示如下:
$L1 * R1 + L2 * R2 = torch.cat((L1, L2), dim=1) * torch.cat((R1, R2), dim=0)$
-
Autograd 是使 PyTorch 成为如此优雅的 ML 框架的关键组成部分。因此,我们将它引入到 PyTorch JIT 中,但使用的是在 IR 层面工作的新型自动微分 (AD) 机制。JIT 自动微分会将前向图切分成符号可微分的子图,并为这些子图生成后向节点。以上述 IR 为例,我们将具有 AD 公式的操作图节点分组到单个
prim::DifferentiableGraph_0
中。对于尚未添加到 AD 公式中的操作,我们将在执行期间回退到 Autograd。 -
优化后向传播路径很困难,而隐式广播语义使得自动微分的优化更加棘手。PyTorch 通过为您广播张量,方便您编写张量操作而无需担心形状问题。对于性能而言,后向传播中的痛点在于,对于这种可广播操作,我们需要进行求和。这导致每个可广播操作的导数后面都跟着一个求和操作。由于我们目前无法融合规约 (reduce) 操作,这使得 FusionGroups 分裂成多个小分组,导致性能不佳。要解决这个问题,请参考 Thomas Viehmann 撰写的这篇很棒的 文章。
杂项优化
- 除了上面介绍的步骤外,我们还消除了 CUDA 内核启动和不必要的张量分配之间的开销。一个例子是当您进行张量设备查找时。这最初可能由于大量不必要的分配而导致性能不佳。当我们消除这些时,内核启动之间的间隔从毫秒级缩短到纳秒级。
- 最后,定制的 LSTMCell 中可能应用了归一化,如 LayerNorm。由于 LayerNorm 和其他归一化操作包含规约 (reduce) 操作,很难将其完全融合。相反,我们自动将 Layernorm 分解为统计计算(规约操作)+ 逐元素转换,然后将这些逐元素部分融合在一起。截至本文发布时,我们的自动微分和图融合器基础设施存在一些限制,将目前的支持范围限制在推理模式。我们计划在未来的版本中添加后向传播支持。
通过上述对操作融合、循环展开、批量矩阵乘法和一些杂项优化的应用,我们可以从下图中看到定制 TorchScript LSTM 前向和后向传播性能的显著提升

本文中还有许多其他优化未涵盖。除了本文中介绍的之外,我们现在看到定制 LSTM 的前向传播性能与 cuDNN 持平。我们也在努力进一步优化后向传播,并期望在未来版本中看到改进。除了 TorchScript 提供的速度外,我们还引入了更加灵活的 API,使您能够手动编写更多定制 RNN,这是 cuDNN 无法提供的。