博客

使用 TorchScript 优化 CUDA 循环神经网络

作者: 2019年5月1日2024年11月16日暂无评论

本周,我们正式发布了 PyTorch 1.1,这是 PyTorch 1.0 的一个大型功能更新。我们新增的功能之一是利用 TorchScript(PyTorch JIT)更好地支持快速、自定义的循环神经网络(fastrnns)(https://pytorch.ac.cn/docs/stable/jit.html)。

RNN 是一类非常流行的模型,在各种不同形态和规模的 NLP 任务中展现出了出色的性能。PyTorch 实现了其中一些最常用的模型,如 Elman RNNGRULSTM,以及多层和双向变体。

然而,许多用户希望根据最新文献的思路来实现自定义的 RNN。将 层归一化 (Layer Normalization) 应用于 LSTM 就是一个典型的用例。由于 PyTorch 的 CUDA LSTM 实现使用了融合内核 (fused kernel),因此很难在其中插入归一化层,甚至很难修改底层的 LSTM 实现。许多用户转向使用标准 PyTorch 算子编写自定义实现,但此类代码开销巨大:大多数 PyTorch 操作会在 GPU 上至少启动一个内核,而 RNN 由于其循环特性通常会运行大量操作。不过,我们可以利用 TorchScript 来融合操作并自动优化代码,从而在 GPU 上启动更少、更优化的内核。

我们的目标是让用户能够在 TorchScript 中编写快速的自定义 RNN,而无需编写专门的 CUDA 内核即可达到相近的性能。在本文中,我们将提供一个教程,介绍如何使用 TorchScript 编写您自己的快速 RNN。为了更好地理解 TorchScript 所应用的优化,我们将研究这些优化在标准 LSTM 实现上是如何工作的,但其中的大多数优化也适用于通用 RNN。

编写自定义 RNN

首先,您可以使用此文件作为模板来编写您自己的自定义 RNN。

我们一直在不断改进基础设施以提高性能。如果您希望获得 TorchScript 目前提供的速度/优化(如算子融合、批量矩阵乘法等),请遵循以下指南。下一节将深入解释这些优化。

  1. 如果自定义操作都是逐元素 (element-wise) 的,那再好不过了,因为您可以自动获得 PyTorch JIT 算子融合带来的好处!
  2. 如果您有更复杂的操作(例如归约操作与逐元素操作混合),请考虑将归约操作和逐元素操作分开归组,以便将逐元素操作融合到一个单一的融合组 (fusion group) 中。
  3. 如果您想了解自定义 RNN 中哪些内容被融合了,可以使用 graph_for 来查看该操作的优化图。以 LSTMCell 为例:
# get inputs and states for LSTMCell

 inputs = get_lstm_inputs()

 # instantiate a ScriptModule

 cell = LSTMCell(input_size, hidden_size)

 # print the optimized graph using graph_for

 out = cell(inputs)
 print(cell.graph_for(inputs))

    这将为提供的特定输入生成优化的 TorchScript 图(又称 PyTorch JIT IR)。

    graph(%x : Float(*, *),
             %hx : Float(*, *),
             %cx : Float(*, *),
             %w_ih : Float(*, *),
             %w_hh : Float(*, *),
             %b_ih : Float(*),
             %b_hh : Float(*)):
         %hy : Float(*, *), %cy : Float(*, *) = prim::DifferentiableGraph_0(%cx, %b_hh, %b_ih, %hx, %w_hh, %x, %w_ih)
         %30 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%hy, %cy)
         return (%30)
         with prim::DifferentiableGraph_0 = graph(%13 : Float(*, *),
             %29 : Float(*),
             %33 : Float(*),
             %40 : Float(*, *),
             %43 : Float(*, *),
             %45 : Float(*, *),
             %48 : Float(*, *)):
         %49 : Float(*, *) = aten::t(%48)
         %47 : Float(*, *) = aten::mm(%45, %49)
         %44 : Float(*, *) = aten::t(%43)
         %42 : Float(*, *) = aten::mm(%40, %44)
         ...some broadcast sizes operations...
         %hy : Float(*, *), %287 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%13, %346, %345, %344, %343)
         ...some broadcast sizes operations...
         return (%hy, %cy, %49, %44, %196, %199, %340, %192, %325, %185, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %395, %396, %287)
         with prim::FusionGroup_0 = graph(%13 : Float(*, *),
             %71 : Tensor,
             %76 : Tensor,
             %81 : Tensor,
             %86 : Tensor):
         ...some chunks, constants, and add operations...
         %ingate.1 : Float(*, *) = aten::sigmoid(%38)
         %forgetgate.1 : Float(*, *) = aten::sigmoid(%34)
         %cellgate.1 : Float(*, *) = aten::tanh(%30)
         %outgate.1 : Float(*, *) = aten::sigmoid(%26)
         %14 : Float(*, *) = aten::mul(%forgetgate.1, %13)
         %11 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
         %cy : Float(*, *) = aten::add(%14, %11, %69)
         %4 : Float(*, *) = aten::tanh(%cy)
         %hy : Float(*, *) = aten::mul(%outgate.1, %4)
         return (%hy, %4, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1)

    从上图中我们可以看到,它有一个 prim::FusionGroup_0 子图,它融合了 LSTMCell 中的所有逐元素操作(转置和矩阵乘法不是逐元素操作)。某些图节点起初可能难以理解,但我们将在优化部分进行解释。为了保持简洁,我们在本文中省略了一些仅用于确保正确性的冗长算子。

    变长序列的最佳实践

    TorchScript 不支持 PackedSequence。通常,在处理变长序列时,最好的做法是将它们填充 (pad) 成单个张量,并将该张量发送到 TorchScript LSTM 中。以下是一个示例:

    sequences = [...] # List[Tensor], each Tensor is T' x C
    padded = torch.utils.rnn.pad_sequence(sequences)
    lengths = [seq.size(0) for seq in sequences]
    padded  # T x N x C, where N is batch size and T is the max of all T'
    
    model = LSTM(...)
    output, hiddens = model(padded)
    output  # T x N x C
    

    当然,output 在填充区域可能会包含一些垃圾数据;请使用 lengths 来跟踪您不需要的部分。

    优化

    现在,我们将解释 PyTorch JIT 为加速自定义 RNN 所执行的优化。我们将使用 TorchScript 中的一个简单自定义 LSTM 模型来说明这些优化,但其中许多优化是通用的,也适用于其他 RNN。

    为了说明我们所做的优化以及我们如何从中受益,我们将运行一个用 TorchScript 编写的简单自定义 LSTM 模型(您可以参考 custom_lstm.py 中的代码或下方的代码片段)并对我们的改进进行计时。

    我们在配备了 2 个 Intel Xeon 处理器和 1 个 Nvidia P100 的机器上设置了环境,安装了 cuDNN v7.3 和 CUDA 9.2。LSTM 模型的基本设置如下:

    input_size = 512
    hidden_size = 512
    mini_batch = 64
    numLayers = 1
    seq_length = 100 
    

    PyTorch JIT 所做的最重要的事情是将 Python 程序编译为 PyTorch JIT IR,这是一种用于对程序图结构进行建模的中间表示。该 IR 可以从全程序优化、硬件加速中受益,并总体上有潜力提供巨大的计算增益。在此示例中,我们仅使用 JIT 提供的编译器优化传递来运行初始的 TorchScript 模型,包括公共子表达式消除、常量合并、常量传播、死代码消除和一些窥孔优化。在预热后,我们运行模型训练 100 次并计算平均训练时间。模型前向传播时间的初始结果约为 27ms,反向传播时间约为 64ms,这与 PyTorch cuDNN LSTM 提供的性能还有一定差距。接下来,我们将解释我们在提高训练或推理性能方面所做的主要优化,从 LSTMCell 和 LSTMLayer 开始,以及一些杂项优化。

    LSTM Cell (前向)

    LSTM 中的几乎所有计算都发生在 LSTMCell 中,因此查看它包含的计算以及如何提高它们的速度非常重要。以下是 TorchScript 中 LSTMCell 的一个示例实现:

    class LSTMCell(jit.ScriptModule):
        def __init__(self, input_size, hidden_size):
            super(LSTMCell, self).__init__()
            self.input_size = input_size
            self.hidden_size = hidden_size
            self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
            self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
            self.bias_ih = Parameter(torch.randn(4 * hidden_size))
            self.bias_hh = Parameter(torch.randn(4 * hidden_size))
    
        @jit.script_method
        def forward(self, input, state):
            # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
            hx, cx = state
            gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih +
                     torch.mm(hx, self.weight_hh.t()) + self.bias_hh)
            ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
    
            ingate = torch.sigmoid(ingate)
            forgetgate = torch.sigmoid(forgetgate)
            cellgate = torch.tanh(cellgate)
            outgate = torch.sigmoid(outgate)
    
            cy = (forgetgate * cx) + (ingate * cellgate)
            hy = outgate * torch.tanh(cy)
    
            return hy, (hy, cy)
    

    TorchScript 生成的这种图表示 (IR) 实现了多项优化和可扩展计算。除了我们可以做的典型编译器优化(CSE、常量传播等)外,我们还可以运行其他 IR 转换来让代码运行得更快。

    • 逐元素算子融合。PyTorch JIT 会自动融合逐元素算子,因此当您拥有相邻的逐元素算子时,JIT 会自动将所有这些操作归组为一个 FusionGroup,然后该 FusionGroup 可以通过单个 GPU/CPU 内核启动并在一次传递中执行。这避免了每次操作昂贵的内存读写。
    • 重新排序分块 (chunk) 和点对点 (pointwise) 操作以实现更多融合。LSTM 单元将门相加(一个点对点操作),然后将这些门分块为四部分:ifco 门。然后,它像上面那样对 ifco 门执行点对点操作。这实际上导致了两个融合组:一个用于分块前的逐元素操作,另一个用于分块后的逐元素操作。这里有趣的是,点对点操作与 torch.chunk 是可交换的:我们无需在某些输入张量上执行点对点操作并对输出进行分块,而是可以先对输入张量进行分块,然后在输出张量上执行相同的点对点操作。通过将分块操作移至第一个融合组之前,我们可以将第一个和第二个融合组合并为一个大组。
    • 在 CPU 上创建张量非常昂贵,但目前正在进行改进以加快此过程。目前,LSTMCell 运行三个 CUDA 内核:两个 gemm 内核和一个用于单个点对点组的内核。我们注意到的一点是,第二个 gemm 完成与单个点对点组开始之间存在很大的间隔。这段时间 GPU 处于空闲状态,没有执行任何操作。深入研究后,我们发现问题在于 torch.chunk 会构造新的张量,而张量构造的速度不够快。我们没有构造新的 Tensor 对象,而是教导融合编译器如何操作数据指针和步长,以便在将其发送到融合内核之前执行 torch.chunk,从而缩短了第二个 gemm 和逐元素融合组启动之间的空闲时间。这使我们的 LSTM 前向传播速度提高了约 1.2 倍。

    通过上述技巧,我们能够将几乎所有的 LSTMCell 前向图(除了两个 gemm 内核)融合到一个单一的融合组中,这对应于上述 IR 图中的 prim::FusionGroup_0。然后它将被启动到一个单一的融合内核中执行。通过这些优化,模型性能显著提高,平均前向传播时间从 27ms 缩短到 10ms(约 1.7 倍加速),平均反向传播时间从 64ms 缩短到 27ms(1.37 倍加速)。

    LSTM Layer (前向)

    class LSTMLayer(jit.ScriptModule):
        def __init__(self, cell, *cell_args):
            super(LSTMLayer, self).__init__()
            self.cell = cell(*cell_args)
    
        @jit.script_method
        def forward(self, input, state):
            # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
            inputs = input.unbind(0)
            outputs = torch.jit.annotate(List[Tensor], [])
            for i in range(len(inputs)):
                out, state = self.cell(inputs[i], state)
                outputs += [out]
            return torch.stack(outputs), state
    

    我们对为 TorchScript LSTM 生成的 IR 进行了几项优化以提高性能,部分优化示例如下:

    • 循环展开:我们自动展开代码中的循环(对于大循环,我们展开其中的一小部分),这使我们能够对 for 循环控制流进行进一步优化。例如,融合器可以将循环体迭代过程中的操作融合在一起,这对于像 LSTM 这样控制流密集的模型来说,会带来很好的性能提升。
    • 批量矩阵乘法:对于输入是预乘的 RNN(即模型对相同的 LHS 或 RHS 有大量矩阵乘法),我们可以有效地将这些操作批量处理为单个矩阵乘法,同时对输出进行分块以实现等效语义。

    通过应用这些技术,我们的前向传播时间额外缩短了 1.6ms 至 8.4ms(1.2 倍加速),反向传播时间缩短了 7ms 至约 20ms(1.35 倍加速)。

    LSTM Layer (反向)

    • “树”批量矩阵乘法:通常在 LSTM 反向传播图中,单个权重被多次重复使用,形成一棵树,其中叶子是矩阵乘法,节点是加法。这些节点可以通过在不同维度上连接 LHS 和 RHS 来合并,然后作为单个矩阵乘法计算。等价公式可表示如下:L1*R1 + L2*R2 = torch.cat((L1,L2), dim=1) * torch.cat((R1,R2), dim=0)。
    • 自动求导 (Autograd) 是 PyTorch 成为优雅 ML 框架的关键组件。因此,我们将其引入了 PyTorch JIT,但使用了一种在 IR 级别工作的新型自动微分 (AD) 机制。JIT 自动微分会将前向图切分为符号可微分的子图,并为这些子图生成反向节点。以引用的 IR 为例,我们将具有 AD 公式操作的图节点组合成一个单一的 prim::DifferentiableGraph_0。对于尚未加入 AD 公式的操作,我们将在执行期间回退到 Autograd。
    • 优化反向路径很困难,且隐式广播语义使自动微分的优化变得更加复杂。PyTorch 通过为您自动广播张量,使编写无需考虑形状的张量操作变得非常方便。为了追求性能,反向传播的难点在于我们需要对此类可广播操作进行求和。这导致每个可广播操作的导数后面都跟着一个求和操作。由于我们目前无法融合归约操作,这会导致 FusionGroups 分裂成多个小组,从而导致性能下降。要处理此问题,请参阅 Thomas Viehmann 撰写的这篇精彩文章

    杂项优化

    • 除了上述步骤外,我们还消除了 CUDA 内核启动之间的开销和不必要的张量分配。一个例子是当您执行张量设备查找时。这在初期会导致性能低下,且伴随大量不必要的分配。当我们移除这些内容后,内核启动之间的耗时从毫秒级减少到了纳秒级。
    • 最后,自定义 LSTMCell(如 LayerNorm)中可能应用了归一化。由于 LayerNorm 和其他归一化算子包含归约操作,因此很难将其整体融合。相反,我们将 LayerNorm 自动分解为统计计算(归约操作)+ 逐元素转换,然后将这些逐元素部分融合在一起。截至本文发表时,我们的自动微分和图融合基础设施存在一些限制,目前的版本仅支持推理模式。我们计划在未来的版本中添加反向传播支持。

    通过上述关于算子融合、循环展开、批量矩阵乘法和一些杂项优化,我们可以从下图中清楚地看到自定义 TorchScript LSTM 在前向和反向传播性能上的提升:

    本文未涵盖许多其他优化。除了本文中介绍的这些之外,我们现在看到我们的自定义 LSTM 前向传播已经与 cuDNN 持平。我们也在致力于进一步优化反向传播,并预计在未来的版本中看到改进。除了 TorchScript 提供的速度优势外,我们还引入了更灵活的 API,使您能够手写更多 cuDNN 无法提供的自定义 RNN。