跳转到主要内容
博客

通往 1.0 之路:生产就绪的 PyTorch

作者: 2018 年 5 月 2 日2024 年 11 月 16 日暂无评论

我们想向您预告 PyTorch 下一个版本 PyTorch 1.0 的路线图。在过去的一年里,0.2、0.3 和 0.4 版本将 PyTorch 从 [Torch+Chainer] 类似接口转变为更简洁的接口,增加了双反向传播、类似 NumPy 的函数、高级索引,并移除了 Variable 样板代码。目前,我们确信 API 已处于合理且稳定的状态,可以自信地发布 1.0 版本。

然而,1.0 不仅仅是关于接口的稳定性。

PyTorch 最大的优势之一是其一流的 Python 集成、命令式风格、简单的 API 和选项。这些是使 PyTorch 适用于研究和可编程性的方面。

它最大的缺点之一是生产支持。我们所说的生产支持是指为了在大规模下高效运行模型,人们必须对模型做无数的事情:

  • 导出到仅 C++ 运行时,用于大型项目
  • 优化 iPhone、Android、Qualcomm 和其他系统上的移动系统
  • 使用更高效的数据布局并执行内核融合以实现更快的推理(在规模上节省 10% 的速度或内存是巨大的胜利)
  • 量化推理(例如 8 位推理)

初创公司、大公司以及任何希望围绕 PyTorch 构建产品的人都要求提供生产支持。在 Facebook(PyTorch 最大的利益相关者),我们拥有 Caffe2,它一直是生产就绪平台,在我们的数据中心运行并出货到超过 10 亿部手机,涵盖八代 iPhone 和六代 Android CPU 架构。它在 Intel / ARM 上提供服务器优化的推理、TensorRT 支持以及生产所需的所有必要组件。考虑到所有这些价值都锁定在一个 PyTorch 团队密切合作的平台中,**我们决定将 PyTorch 和 Caffe2 结合起来,从而为 PyTorch 带来生产级就绪性**。

在不给我们的研究人员和最终用户增加可用性问题的情况下支持生产功能需要创造性的解决方案。

生产 ≠ 研究人员的痛苦

增加生产能力涉及增加 API 的复杂性和模型可配置选项的数量。人们需要配置内存布局(NCHW vs NHWC vs N,C/32,H,W,32,每种都提供不同的性能特性)、量化(8 位?3 位?)、低级内核的融合(您使用了 Conv + BatchNorm + ReLU,我们将它们融合到一个内核中)、单独的后端选项(某些层使用 MKLDNN 后端,其他层使用 NNPACK 后端)等。

PyTorch 的核心目标是为研究和可编程性提供一个出色的平台。因此,在添加所有这些优化时,我们一直坚持一个严格的设计约束,即绝不以可用性为代价。

为了实现这一点,我们引入了 torch.jit,这是一个即时 (JIT) 编译器,它在运行时接收您的 PyTorch 模型并将其重写以实现生产效率。JIT 编译器还可以将您的模型导出为基于 Caffe2 组件的纯 C++ 运行时。

在 1.0 版本中,您的代码将继续按原样工作,我们不会对现有 API 进行任何重大更改。

使您的模型达到生产就绪状态是一个可选的注解,它使用 torch.jit 编译器将您的模型导出到无 Python 环境中,并提高其性能。让我们详细介绍一下 JIT 编译器。

torch.jit:模型的 JIT 编译器

我们坚信,直接以地道的 Python 代码指定模型所带来的生产力是难以匹敌的。这使得 PyTorch 如此灵活,但这也意味着 PyTorch 几乎永远不知道您下一步将运行什么操作。然而,这对于导出/生产化和重量级自动性能优化来说是一个巨大的障碍,因为它们需要对计算在执行前将如何进行有全面的预先了解。

我们提供了两种可选的方式从您的代码中恢复此信息,一种基于跟踪原生 python 代码,另一种基于将 python 语言的子集编译为无 python 的中间表示。经过彻底讨论,我们得出结论,它们都将在不同的上下文中发挥作用,因此您可以自由地混合和匹配它们。

追踪模式

PyTorch 追踪器,torch.jit.trace,是一个函数,它记录代码区域中执行的所有原生 PyTorch 操作,以及它们之间的数据依赖关系。事实上,PyTorch 从 0.3 版本开始就有了追踪器,并一直用于通过 ONNX 导出模型。现在,不同之处在于您不再需要将追踪结果拿到其他地方运行——PyTorch 可以使用精心设计的高性能 C++ 运行时为您重新执行它。随着我们开发 PyTorch 1.0,这个运行时将集成 Caffe2 提供的所有优化和硬件集成。

这种方法最大的好处是它不关心您的 Python 代码结构——您可以追踪生成器或协程、模块或纯函数。由于我们只记录原生 PyTorch 运算符,这些细节对记录的追踪没有影响。然而,这种行为是一把双刃剑。例如,如果您的模型中有一个循环,它将在追踪中展开,为循环运行的次数插入循环体的副本。这为零成本抽象提供了机会(例如,您可以遍历模块,实际的追踪将是无循环开销的!),但另一方面,这也会影响数据依赖循环(例如,处理不同长度的序列),有效地将单个长度硬编码到追踪中。

对于不包含循环和 if 语句的网络,追踪是非侵入性的,并且足够健壮,可以处理各种编码风格。此代码示例说明了追踪的样子:

# This will run your nn.Module or regular Python function with the example
# input that you provided. The returned callable can be used to re-execute
# all operations that happened during the example run, but it will no longer
# use the Python interpreter.
from torch.jit import trace
traced_model = trace(model, example_input=input)
traced_fn = trace(fn, example_input=input)

# The training loop doesn't change. Traced model behaves exactly like an
# nn.Module, except that you can't edit what it does or change its attributes.
# Think of it as a "frozen module".
for input, target in data_loader:
    loss = loss_fn(traced_model(input), target)

脚本模式

追踪模式是最大限度减少对代码影响的好方法,但我们也对那些根本性地利用控制流的模型(如 RNN)感到非常兴奋。我们的解决方案是脚本模式。

在这种情况下,您编写一个普通的 Python 函数,只是您不能再使用某些更复杂的语言特性。一旦您分离出所需的功能,您可以通过使用 @script 装饰器对其进行装饰来告诉我们您希望该函数被编译。此注解将您的 python 函数直接转换为我们的高性能 C++ 运行时。这使我们能够恢复所有 PyTorch 操作以及循环和条件。它们将被嵌入到该函数的内部表示中,并且每次运行该函数时都会被考虑在内。

from torch.jit import script

@script
def rnn_loop(x):
    hidden = None
    for x_t in x.split(1):
        x, hidden = model(x, hidden)
    return x

优化与导出

无论您使用追踪还是 @script,结果都是模型的无 Python 表示,可用于优化模型或将模型从 Python 导出以在生产环境中使用。

将模型的更大片段提取到中间表示中,可以进行复杂的全程序优化,并将计算卸载到专门的 AI 加速器,这些加速器在计算图上运行。我们已经开始开发这些优化,包括融合 GPU 操作的通道,以提高小型 RNN 模型的性能。

它还允许我们使用 Caffe2 中现有​​的高性能后端来高效运行模型。此外,@script 函数(和模块!)可以完全导出到 ONNX,并保留其动态特性,这样您就可以使用 Caffe2 的模型执行器或将模型传输到任何其他支持 ONNX 的框架,轻松地在无 Python 环境中运行它们。

可用性

我们非常注重保持当前的可用性水平,我们知道代码不直接在 Python 中执行会导致更难调试,但这是我们经常考虑的问题,我们正在确保您不会被锁定在一种完全不同的编程语言中。

首先,我们遵循“按需付费”原则——如果您不需要优化或导出模型,则无需使用这些新功能,也不会看到任何缺点。此外,可以逐步使用追踪或 @script 模块/函数。例如,允许所有这些行为:您可以追踪模型的一部分,并在更大的未追踪模型中使用追踪。您可以追踪模型 90% 的部分,并对实际包含某些控制流的子模块使用 @script。您可以使用 @script 编写函数,并让它调用原生 Python 函数。如果 @script 函数中出现不正确的情况,您可以删除注解,代码将在原生 Python 中执行,在那里可以使用您喜欢的工具和方法轻松调试。将追踪和 @script 视为使用 MyPy 或 TypeScript 的类型注解——每个附加注解都可以逐步测试,并且在您需要优化或生产化之前,都不需要它们。

最重要的是,这些模式将内置于 PyTorch 的核心,以便它们与您现有代码的混合和匹配可以无缝发生。

注意:这些组件的 JIT 名称有点用词不当,是出于历史原因。PyTorch 中的追踪/函数执行最初是一个生成融合 CUDA 内核的优化 JIT 编译器,但后来发展到包含优化、@script 和导出。当它准备发布时,我们可能会将此功能重命名为混合前端,但我们希望在此处将其按代码中的名称呈现,以便您在开发过程中能够跟进。

其他变更和改进

生产支持是 1.0 的重要特性,但我们将继续按照标准发布流程优化和修复 PyTorch 的其他部分。

在后端方面,PyTorch 将会有一些变化,这可能会影响用户编写的 C 和 C++ 扩展。我们正在替换(或重构)后端 ATen 库,以整合 Caffe2 的功能和优化。

结语

我们计划在夏季的某个时候发布 1.0 版本。您可以在拉取请求页面上关注我们的进展。

您可以从 Caffe2 项目的角度阅读此内容:https://caffe2.ai/blog/2018/05/02/Caffe2_PyTorch_1_0.html