跳转到主要内容
博客

使用 TorchScript 优化 CUDA 循环神经网络

作者: 2019 年 5 月 1 日2024 年 11 月 16 日暂无评论

本周,我们正式发布了 PyTorch 1.1,这是 PyTorch 1.0 的一个大型功能更新。我们添加的新功能之一是更好地支持使用 TorchScript(PyTorch JIT)进行快速、自定义的循环神经网络(fastrnns)(https://pytorch.ac.cn/docs/stable/jit.html)。

RNN 是一种流行的模型,在各种不同形状和大小的 NLP 任务中表现出色。PyTorch 实现了其中一些最流行的模型,如 Elman RNNGRULSTM,以及多层和双向变体。

然而,许多用户希望根据最新文献中的思想实现自己的自定义 RNN。将 层归一化 应用于 LSTM 就是一个这样的用例。由于 PyTorch CUDA LSTM 实现使用了融合内核,因此很难插入归一化或修改基础 LSTM 实现。许多用户已转向使用标准 PyTorch 运算符编写自定义实现,但此类代码开销很高:大多数 PyTorch 操作在 GPU 上至少启动一个内核,并且 RNN 通常由于其循环特性而运行许多操作。然而,我们可以应用 TorchScript 来融合操作并自动优化我们的代码,从而在 GPU 上启动更少、更优化的内核。

我们的目标是让用户能够用 TorchScript 编写快速、自定义的 RNN,而无需编写专门的 CUDA 内核即可实现相似的性能。在这篇文章中,我们将提供一个关于如何使用 TorchScript 编写自己的快速 RNN 的教程。为了更好地理解 TorchScript 应用的优化,我们将研究这些优化在标准 LSTM 实现上的工作方式,但大多数优化可以应用于通用 RNN。

编写自定义 RNN

首先,您可以使用此文件作为模板来编写您自己的自定义 RNN。

我们不断改进基础设施以提高性能。如果您想获得 TorchScript 目前提供的速度/优化(如运算符融合、批量矩阵乘法等),这里有一些遵循的指导原则。下一节将深入解释这些优化。

  1. 如果自定义操作都是逐元素的,那就太好了,因为您可以自动获得 PyTorch JIT 的运算符融合的好处!
  2. 如果您有更复杂的操作(例如,归约操作与逐元素操作混合),请考虑将归约操作和逐元素操作分开分组,以便将逐元素操作融合到一个融合组中。
  3. 如果您想了解您的自定义 RNN 中融合了哪些内容,您可以使用 `graph_for` 检查操作的优化图。以 `LSTMCell` 为例
# get inputs and states for LSTMCell

 inputs = get_lstm_inputs()

 # instantiate a ScriptModule

 cell = LSTMCell(input_size, hidden_size)

 # print the optimized graph using graph_for

 out = cell(inputs)
 print(cell.graph_for(inputs))

    这将为提供的特定输入生成优化的 TorchScript 图(即 PyTorch JIT IR)

    graph(%x : Float(*, *),
             %hx : Float(*, *),
             %cx : Float(*, *),
             %w_ih : Float(*, *),
             %w_hh : Float(*, *),
             %b_ih : Float(*),
             %b_hh : Float(*)):
         %hy : Float(*, *), %cy : Float(*, *) = prim::DifferentiableGraph_0(%cx, %b_hh, %b_ih, %hx, %w_hh, %x, %w_ih)
         %30 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%hy, %cy)
         return (%30)
         with prim::DifferentiableGraph_0 = graph(%13 : Float(*, *),
             %29 : Float(*),
             %33 : Float(*),
             %40 : Float(*, *),
             %43 : Float(*, *),
             %45 : Float(*, *),
             %48 : Float(*, *)):
         %49 : Float(*, *) = aten::t(%48)
         %47 : Float(*, *) = aten::mm(%45, %49)
         %44 : Float(*, *) = aten::t(%43)
         %42 : Float(*, *) = aten::mm(%40, %44)
         ...some broadcast sizes operations...
         %hy : Float(*, *), %287 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%13, %346, %345, %344, %343)
         ...some broadcast sizes operations...
         return (%hy, %cy, %49, %44, %196, %199, %340, %192, %325, %185, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %395, %396, %287)
         with prim::FusionGroup_0 = graph(%13 : Float(*, *),
             %71 : Tensor,
             %76 : Tensor,
             %81 : Tensor,
             %86 : Tensor):
         ...some chunks, constants, and add operations...
         %ingate.1 : Float(*, *) = aten::sigmoid(%38)
         %forgetgate.1 : Float(*, *) = aten::sigmoid(%34)
         %cellgate.1 : Float(*, *) = aten::tanh(%30)
         %outgate.1 : Float(*, *) = aten::sigmoid(%26)
         %14 : Float(*, *) = aten::mul(%forgetgate.1, %13)
         %11 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
         %cy : Float(*, *) = aten::add(%14, %11, %69)
         %4 : Float(*, *) = aten::tanh(%cy)
         %hy : Float(*, *) = aten::mul(%outgate.1, %4)
         return (%hy, %4, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1)

    从上面的图中我们可以看到,它有一个 `prim::FusionGroup_0` 子图,它融合了 LSTMCell 中所有逐元素操作(转置和矩阵乘法不是逐元素操作)。有些图节点一开始可能难以理解,但我们将在优化部分解释其中一些,我们还省略了本文中一些冗长的运算符,这些运算符仅用于正确性。

    变长序列最佳实践

    TorchScript 不支持 PackedSequence。通常,在处理变长序列时,最好将它们填充到一个张量中,然后将该张量通过 TorchScript LSTM 发送。这是一个例子

    sequences = [...] # List[Tensor], each Tensor is T' x C
    padded = torch.utils.rnn.pad_sequence(sequences)
    lengths = [seq.size(0) for seq in sequences]
    padded  # T x N x C, where N is batch size and T is the max of all T'
    
    model = LSTM(...)
    output, hiddens = model(padded)
    output  # T x N x C
    

    当然,`output` 在填充区域可能有一些垃圾数据;使用 `lengths` 来跟踪不需要的部分。

    优化

    我们现在将解释 PyTorch JIT 为加速自定义 RNN 所做的优化。我们将使用 TorchScript 中的一个简单自定义 LSTM 模型来演示这些优化,但其中许多优化是通用的,也适用于其他 RNN。

    为了说明我们所做的优化以及我们如何从这些优化中获益,我们将运行一个用 TorchScript 编写的简单自定义 LSTM 模型(您可以参考 custom_lstm.py 中的代码或下面的代码片段),并计时我们的更改。

    我们在配备 2 个 Intel Xeon 芯片和 1 个 Nvidia P100,并安装了 cuDNN v7.3 和 CUDA 9.2 的机器上设置了环境。LSTM 模型的基本设置如下

    input_size = 512
    hidden_size = 512
    mini_batch = 64
    numLayers = 1
    seq_length = 100 
    

    PyTorch JIT 最重要的事情是它将 Python 程序编译成 PyTorch JIT IR,这是一种用于建模程序图结构的中间表示。然后,该 IR 可以受益于整个程序优化、硬件加速,并总体上具有提供巨大计算增益的潜力。在此示例中,我们运行初始的 TorchScript 模型,仅使用 JIT 提供的编译器优化传递,包括公共子表达式消除、常量池化、常量传播、死代码消除和一些窥孔优化。我们在预热后运行模型训练 100 次,并平均训练时间。模型前向时间的初始结果约为 27 毫秒,后向时间约为 64 毫秒,这与 PyTorch cuDNN LSTM 提供的数据相去甚远。接下来,我们将解释我们在训练或推理方面如何提高性能的主要优化,从 LSTMCell 和 LSTMLayer 开始,以及一些杂项优化。

    LSTM 单元(前向)

    LSTM 中的几乎所有计算都发生在 LSTMCell 中,因此我们检查它包含的计算以及如何提高它们的速度非常重要。下面是 TorchScript 中一个示例 LSTMCell 实现

    class LSTMCell(jit.ScriptModule):
        def __init__(self, input_size, hidden_size):
            super(LSTMCell, self).__init__()
            self.input_size = input_size
            self.hidden_size = hidden_size
            self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
            self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
            self.bias_ih = Parameter(torch.randn(4 * hidden_size))
            self.bias_hh = Parameter(torch.randn(4 * hidden_size))
    
        @jit.script_method
        def forward(self, input, state):
            # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
            hx, cx = state
            gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih +
                     torch.mm(hx, self.weight_hh.t()) + self.bias_hh)
            ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
    
            ingate = torch.sigmoid(ingate)
            forgetgate = torch.sigmoid(forgetgate)
            cellgate = torch.tanh(cellgate)
            outgate = torch.sigmoid(outgate)
    
            cy = (forgetgate * cx) + (ingate * cellgate)
            hy = outgate * torch.tanh(cy)
    
            return hy, (hy, cy)
    

    TorchScript 生成的这种图表示(IR)支持多种优化和可伸缩计算。除了我们可以进行的典型编译器优化(CSE、常量传播等)之外,我们还可以运行其他 IR 转换以使我们的代码运行得更快。

    • 逐元素运算符融合。PyTorch JIT 将自动融合逐元素操作,因此当您有相邻的逐元素操作时,JIT 会自动将所有这些操作组合到一个 FusionGroup 中,然后可以使用单个 GPU/CPU 内核启动此 FusionGroup,并一次性执行。这避免了每次操作昂贵的内存读写。
    • 重新排列块和逐点操作以实现更多融合。LSTM 单元将门相加(逐点操作),然后将门分成四部分:ifco 门。然后,它像上面那样对 ifco 门执行逐点操作。这在实践中导致了两个融合组:一个用于分块前的逐元素操作,一个用于分块后的逐元素操作。这里有趣的是,逐点操作与 `torch.chunk` 可交换:与其对某些输入张量执行逐点操作并分块输出,我们可以分块输入张量,然后对输出张量执行相同的逐点操作。通过将分块移动到第一个融合组之前,我们可以将第一个和第二个融合组合并为一个大组。
    • 在 CPU 上创建张量很昂贵,但目前正在努力使其更快。此时,LSTMCell 运行三个 CUDA 内核:两个 `gemm` 内核和一个用于单个逐点组的内核。我们注意到的一件事是,第二个 `gemm` 结束和单个逐点组开始之间存在很大差距。这个差距是 GPU 空闲无所事事的一段时间。进一步调查后,我们发现问题是 `torch.chunk` 构造了新张量,而张量构造的速度不如它应有的快。我们没有构造新的 Tensor 对象,而是教导融合编译器如何操纵数据指针和步幅,以便在将其发送到融合内核之前执行 `torch.chunk`,从而缩短第二个 gemm 和逐元素融合组启动之间的空闲时间。这使 LSTM 前向传播的速度提高了约 1.2 倍。

    通过上述技巧,我们能够将几乎所有 `LSTMCell` 前向图(除了两个 gemm 内核)融合到一个单一的融合组中,这对应于上述 IR 图中的 `prim::FusionGroup_0`。然后,它将被启动到一个单一的融合内核中执行。通过这些优化,模型性能显著提高,平均前向时间减少约 17 毫秒(1.7 倍加速)至 10 毫秒,平均后向时间减少 37 毫秒至 27 毫秒(1.37 倍加速)。

    LSTM 层(前向)

    class LSTMLayer(jit.ScriptModule):
        def __init__(self, cell, *cell_args):
            super(LSTMLayer, self).__init__()
            self.cell = cell(*cell_args)
    
        @jit.script_method
        def forward(self, input, state):
            # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
            inputs = input.unbind(0)
            outputs = torch.jit.annotate(List[Tensor], [])
            for i in range(len(inputs)):
                out, state = self.cell(inputs[i], state)
                outputs += [out]
            return torch.stack(outputs), state
    

    我们对为 TorchScript LSTM 生成的 IR 做了几项技巧以提高性能,一些示例优化如下:

    • 循环展开:我们自动展开代码中的循环(对于大循环,我们展开其一小部分),这使我们能够对 for 循环控制流进行进一步优化。例如,融合器可以将循环体迭代中的操作融合在一起,这对于像 LSTM 这样控制流密集型模型来说,可以显著提高性能。
    • 批量矩阵乘法:对于输入预乘的 RNN(即模型有许多具有相同 LHS 或 RHS 的矩阵乘法),我们可以有效地将这些操作批量处理成单个矩阵乘法,同时对输出进行分块以实现等效语义。

    通过应用这些技术,我们将前向传递时间额外减少了 1.6 毫秒,降至 8.4 毫秒(提速 1.2 倍),后向传递时间减少了 7 毫秒,降至约 20 毫秒(提速 1.35 倍)。

    LSTM 层(反向)

    • “树形”批量矩阵乘法:通常情况下,在 LSTM 反向图中,单个权重会被多次重用,形成一棵树,其中叶子是矩阵乘法,节点是加法。这些节点可以通过在不同维度连接 LHS 和 RHS 来组合,然后计算为单个矩阵乘法。等效公式可以表示为:L1∗R1+L2∗R2=torch.cat((L1,L2),dim=1)∗torch.cat((R1,R2),dim=0)L1∗R1+L2∗R2=torch.cat((L1,L2),dim=1)∗torch.cat((R1,R2),dim=0)
    • 自动求导是 PyTorch 成为如此优雅的机器学习框架的关键组成部分。因此,我们将其延续到了 PyTorch JIT,但使用的是一种在 IR 级别工作的新型**自动微分**(AD)机制。JIT 自动微分会将前向图切片成符号可微分的子图,并为这些子图生成反向节点。以上述 IR 为例,我们将图节点组合成一个 `prim::DifferentiableGraph_0`,用于具有 AD 公式的操作。对于尚未添加到 AD 公式中的操作,我们将在执行期间回退到 Autograd。
    • 优化反向路径很困难,而隐式广播语义使得自动微分的优化更加困难。PyTorch 使得编写张量操作变得很方便,无需担心形状,因为它会为您广播张量。对于性能而言,反向传递中的痛点是我们需要对这种可广播操作进行求和。这导致每个可广播操作的导数后面都有一个求和。由于我们目前无法融合归约操作,这导致 FusionGroup 分裂成多个小组,从而导致性能不佳。要解决这个问题,请参考 Thomas Viehmann 撰写的这篇精彩文章

    杂项优化

    • 除了上述步骤之外,我们还消除了 CUDA 内核启动和不必要的张量分配之间的开销。一个例子是当您进行张量设备查找时。这最初可能由于大量不必要的分配而导致性能不佳。当我们消除这些时,这导致内核启动之间的延迟从毫秒减少到纳秒。
    • 最后,自定义 LSTMCell 中可能应用了归一化,例如 LayerNorm。由于 LayerNorm 和其他归一化操作包含归约操作,因此很难完全融合它们。相反,我们自动将 LayerNorm 分解为统计计算(归约操作)+ 逐元素转换,然后将这些逐元素部分融合在一起。截至本文发布,我们的自动微分和图融合器基础设施存在一些限制,这使得当前仅支持推理模式。我们计划在未来的版本中添加反向支持。

    通过对操作融合、循环展开、批量矩阵乘法和一些杂项优化,我们可以从下图中看到我们的自定义 TorchScript LSTM 前向和后向性能的显著提升

    我们在这篇文章中还有许多没有涵盖的额外优化。除了本文中列出的那些,我们现在看到我们的自定义 LSTM 前向传播与 cuDNN 不相上下。我们还在努力进一步优化反向传播,并期望在未来的版本中看到改进。除了 TorchScript 提供的速度之外,我们还引入了一个更灵活的 API,使您能够手写更多自定义 RNN,这是 cuDNN 无法提供的。