我们希望在此预览一下 PyTorch 1.0 的路线图,这是 PyTorch 的下一个版本。在过去的一年中,通过 0.2、0.3 和 0.4 版本的迭代,PyTorch 已从类似 [Torch+Chainer] 的界面演变得更加简洁,增加了双向反向传播(double-backwards)、类 numpy 函数、高级索引,并移除了 Variable 的样板代码。目前,我们确信 API 已处于一个合理且稳定的状态,可以充满信心地发布 1.0 版本。
然而,1.0 不仅仅关乎接口的稳定性。
PyTorch 最大的优势之一是其一流的 Python 集成、命令式风格、简单的 API 以及丰富的选项。这些特点使 PyTorch 非常适合研究和快速迭代。
而它最大的缺点之一是生产支持不足。我们所说的生产支持,是指为了在大规模环境下高效运行模型而必须进行的无数工作。
- 导出到纯 C++ 运行时以用于大型项目
- 针对 iPhone、Android、高通及其他系统进行移动端优化
- 使用更高效的数据布局并执行算子融合以实现更快的推理(在大规模场景下,节省 10% 的速度或内存是一项巨大的胜利)
- 量化推理(例如 8-bit 推理)
初创公司、大型企业以及任何希望基于 PyTorch 构建产品的人都提出了对生产支持的需求。在 Facebook(PyTorch 最大的利益相关者),我们拥有 Caffe2,它一直是生产就绪的平台,运行在我们的数据中心,并部署在超过 10 亿台手机上,涵盖了八代 iPhone 和六代 Android CPU 架构。它在 Intel / ARM 上具备服务器优化的推理能力、TensorRT 支持以及生产所需的一切必要组件。考虑到所有这些价值都锁定在一个与 PyTorch 团队密切合作的平台上,我们决定将 PyTorch 与 Caffe2 结合,这为 PyTorch 带来了生产级别的就绪能力。
在不给研究人员和最终用户增加可用性问题的前提下支持生产功能,需要创造性的解决方案。
生产 != 给研究人员带来痛苦
增加生产能力意味着增加 API 的复杂性和模型的可配置选项。用户需要配置内存布局(NCHW 与 NHWC 与 N,C/32,H,W,32,每种都有不同的性能特征)、量化(8-bit?3-bit?)、底层算子融合(如果你使用了 Conv + BatchNorm + ReLU,我们将它们融合成一个单一的算子)、分离的后端选项(某些层使用 MKLDNN 后端,其他层使用 NNPACK 后端)等等。
PyTorch 的核心目标是提供一个出色的研究和快速迭代平台。因此,在添加所有这些优化的同时,我们一直遵循一个严格的设计准则:决不以牺牲可用性为代价。
为了实现这一目标,我们引入了 torch.jit,这是一个即时(JIT)编译器,它可以在运行时获取你的 PyTorch 模型并重写它们,以实现生产级的效率。JIT 编译器还可以将你的模型导出,以便在基于 Caffe2 组件的纯 C++ 运行时中运行。
在 1.0 版本中,你的代码将继续按原样工作,我们不会对现有 API 进行任何重大更改。
使模型变为生产就绪是一个可选的注解过程,它利用 torch.jit 编译器将你的模型导出到无 Python 的环境中,并提升其性能。让我们详细了解一下 JIT 编译器。
torch.jit:模型的 JIT 编译器
我们坚信,直接将模型描述为地道的 Python 代码所带来的生产力是很难被超越的。这正是 PyTorch 如此灵活的原因,但也意味着 PyTorch 几乎永远不知道你下一步将运行什么操作。然而,这对导出/生产化和重量级的自动性能优化构成了巨大的障碍,因为它们需要预先全面了解计算过程,甚至在执行之前就需要知道计算的样子。
我们提供了两种可选的方式来从你的代码中获取这些信息,一种基于追踪原生 Python 代码,另一种基于编译一个被注解的 Python 子集,将其转换为一种无 Python 的中间表示。经过深入讨论,我们认为它们在不同的场景下都很有用,因此你将能够自由地混合使用它们。
追踪模式(Tracing Mode)
PyTorch 追踪器 torch.jit.trace 是一个记录代码区域中执行的所有原生 PyTorch 操作及其数据依赖关系的函数。事实上,PyTorch 从 0.3 版本开始就有了追踪器,它一直用于通过 ONNX 导出模型。现在的变化是,你不再一定需要将追踪结果拿到别处运行——PyTorch 可以使用精心设计的高性能 C++ 运行时为你重新执行它。随着我们开发 PyTorch 1.0,该运行时将整合 Caffe2 提供的所有优化和硬件集成。
这种方法最大的好处是它并不关心你的 Python 代码是如何组织的——你可以追踪生成器或协程、模块或纯函数。由于我们只记录原生的 PyTorch 算子,这些细节对记录的追踪结果没有影响。然而,这种行为是一把双刃剑。例如,如果你的模型中有一个循环,它会在追踪中被展开,循环体会被复制,次数等同于循环运行的次数。这为零成本抽象提供了机会(例如,你可以循环遍历模块,而实际的追踪过程将没有循环开销!),但另一方面,这也会影响数据依赖循环(例如处理长度可变的序列),有效地将单一长度硬编码到追踪结果中。
对于不包含循环和 if 语句的网络,追踪是非侵入性的,并且足够稳健,可以处理各种编码风格。这个代码示例展示了追踪的样子。
# This will run your nn.Module or regular Python function with the example
# input that you provided. The returned callable can be used to re-execute
# all operations that happened during the example run, but it will no longer
# use the Python interpreter.
from torch.jit import trace
traced_model = trace(model, example_input=input)
traced_fn = trace(fn, example_input=input)
# The training loop doesn't change. Traced model behaves exactly like an
# nn.Module, except that you can't edit what it does or change its attributes.
# Think of it as a "frozen module".
for input, target in data_loader:
loss = loss_fn(traced_model(input), target)
脚本模式(Script Mode)
追踪模式是最大限度减少对代码影响的好方法,但我们也对那些从根本上利用控制流(例如 RNN)的模型感到非常兴奋。我们对此的解决方案是脚本模式。
在这种模式下,你编写常规的 Python 函数,只是不能使用某些更复杂的语言特性。一旦你隔离了所需的功能,只需通过使用 @script 装饰器进行装饰,告诉我们你希望该函数被编译即可。这个注解会将你的 Python 函数直接转换为我们的高性能 C++ 运行时。这使我们能够恢复所有的 PyTorch 操作以及循环和条件判断。它们将被嵌入到该函数的内部表示中,并在每次运行该函数时被记录。
from torch.jit import script
@script
def rnn_loop(x):
hidden = None
for x_t in x.split(1):
x, hidden = model(x, hidden)
return x
优化与导出
无论你使用追踪还是 @script,结果都是你的模型的无 Python 表示,可用于优化模型或将模型从 Python 环境中导出以用于生产环境。
将模型的更大片段提取为中间表示,使得进行复杂的全程序优化以及将计算卸载到专门的 AI 加速器(这些加速器在计算图上运行)成为可能。我们已经开始开发这些优化,包括融合 GPU 操作以提高小型 RNN 模型性能的阶段。
它还允许我们利用 Caffe2 目前提供的高性能后端来高效运行模型。此外,@script 函数(和模块!)可以完全导出为 ONNX,并保留其动态特性,这样你就可以使用 Caffe2 的模型执行器,或通过将模型传输到任何其他支持 ONNX 的框架,在无 Python 的环境中轻松运行它们。
可用性
我们非常关心维持当前水平的可用性,我们也知道代码不在 Python 中直接执行会导致调试变得困难,但这是我们深思熟虑的问题,我们正在确保你不会被锁定在一种完全不同的编程语言中。
首先,我们遵循“按需付费”原则——如果你不需要优化或导出模型,你就不必使用这些新功能,也不会遇到任何负面影响。此外,追踪或 @script 模块/函数的使用可以循序渐进。例如,所有这些行为都是允许的:你可以追踪模型的一部分,并在更大的非追踪模型中使用该追踪。你可以对 90% 的模型使用追踪,并对其中包含控制流的子模块使用 @script。你可以使用 @script 编写一个函数,并让它调用一个原生 Python 函数。如果 @script 函数中出现异常,你可以移除注解,代码将在原生 Python 中执行,从而可以使用你最喜欢的工具和方法轻松调试。将追踪和 @script 视为类似 MyPy 或 TypeScript 的类型注解——每个额外的注解都可以逐步测试,在你想进行优化或生产化之前,它们都不是必须的。
最重要的是,这些模式将被构建到 PyTorch 的核心中,以便它们可以与你现有的代码无缝地混合使用。
注意:这些组件使用“JIT”这个名字其实不太准确,这是历史原因造成的。PyTorch 中的追踪/函数执行最初是一个优化型 JIT 编译器,用于生成融合的 CUDA 内核,但后来发展到涵盖了优化、@script 和导出。当它准备好发布时,我们很可能会将此功能重命名为“混合前端”(hybrid frontend),但我们在这里使用它在代码中的原始命名,以便在你跟进我们的开发时能够对应上。
其他更改与改进
生产支持是 1.0 版本的核心功能,但作为标准发布流程的一部分,我们将继续优化和修复 PyTorch 的其他部分。
在后端方面,PyTorch 将会发生一些变化,这可能会影响用户编写的 C 和 C++ 扩展。我们正在替换(或重构)后端 ATen 库,以整合来自 Caffe2 的功能和优化。
最后的话
我们的目标是在今年夏季的某个时候发布 1.0 版本。你可以通过 Pull Requests 页面关注我们的进展。
你可以从 Caffe2 项目的角度阅读相关内容:https://caffe2.ai/blog/2018/05/02/Caffe2_PyTorch_1_0.html