作者:Dhruv Matani, Suraj Subramanian

引言

TL;DR:在移动设备、SBC(单板计算机)和物联网设备上运行 PyTorch 可能具有挑战性。编译后的 PyTorch 库非常庞大,包含设备端使用场景可能不需要的依赖项。

要在设备上运行特定的模型集,我们实际上只需要 PyTorch 库中一小部分功能。我们发现,使用选择性构建生成的 PyTorch 运行时可以将二进制文件大小减少高达 90%(对于 Linux x86-64 构建上的 CPU 和 QuantizedCPU 后端)。在本博客中,我们将分享我们使用选择性构建生成模型特定最小运行时的经验,并向您展示如何做到这一点。

为什么这对应用开发者很重要?

使用选择性构建生成的 PyTorch 运行时可以将 AI 应用的大小减少 30 多 MB - 这对于典型的移动应用来说是显著的减少!使移动应用更轻量化有很多好处 - 它们可以在更广泛的设备上运行,消耗更少的蜂窝数据,并且可以在用户设备上更快地下载和更新。

开发者体验是怎样的?

此方法可以与任何现有的 PyTorch Mobile 部署工作流程无缝协作。您只需要将通用 PyTorch 运行时库替换为您希望在应用中使用的特定模型定制的运行时。此过程的一般步骤是

  1. 检测模式下构建 PyTorch 运行时(这称为 PyTorch 的检测构建)。这将记录使用的运算符、内核和功能。
  2. 使用提供的 model_tracer 二进制文件在此检测构建上运行您的模型。这将生成一个 YAML 文件,其中存储了您的模型使用的所有功能。这些功能将保留在最小运行时中。
  3. 使用此 YAML 文件作为输入构建 PyTorch。这就是选择性构建技术,它可以极大地减小最终 PyTorch 二进制文件的大小。
  4. 使用这个选择性构建的 PyTorch 库来减小您的移动应用的大小!

以特殊的“检测”模式构建 PyTorch 运行时(通过传递 TRACING_BASED=1 构建选项)会生成 PyTorch 的检测构建运行时,以及一个 model_tracer 二进制文件。使用此构建运行模型可以追踪模型使用的 PyTorch 部分。

图 1:PyTorch 的检测构建

# Clone the PyTorch repo
git clone https://github.com/pytorch/pytorch.git
cd pytorch

# Build the model_tracer
USE_NUMPY=0 USE_DISTRIBUTED=0 USE_CUDA=0 TRACING_BASED=1 \
  python setup.py develop

现在,此检测构建用于使用代表性输入运行模型推理。model_tracer 二进制文件会观察推理运行期间激活的检测构建部分,并将其转储到 YAML 文件中。

图 2:在检测构建上运行模型生成的 YAML 文件

# Generate YAML file
./build/bin/model_tracer \
  --model_input_path /tmp/path_to_model.ptl \
  --build_yaml_path /tmp/selected_ops.yaml

现在我们再次构建 PyTorch 运行时,但这次使用跟踪器生成的 YAML 文件。运行时现在仅包含此模型所需的那些部分。这在下图中称为“选择性构建的 PyTorch 运行时”

# Clean out cached configuration
make clean

# Build PyTorch using Selected Operators (from the YAML file)
# using the host toolchain, and use this generated library
BUILD_PYTORCH_MOBILE_WITH_HOST_TOOLCHAIN=1 \
USE_LIGHTWEIGHT_DISPATCH=0 \
BUILD_LITE_INTERPRETER=1 \
SELECTED_OP_LIST=/tmp/selected_ops.yaml \
TRACING_BASED=1 \
  ./scripts/build_mobile.sh

图 3:PyTorch 的选择性构建以及在选择性构建的 PyTorch 运行时上执行模型

给我看代码!

我们整理了一个 notebook 来演示使用一个简单的 PyTorch 模型,上述过程在代码中是什么样子。

对于更实际地在 Android/iOS 上部署此方法的教程,本教程应该会有帮助。

技术常见问题

为什么 PyTorch 的选择性构建需要追踪?

在 PyTorch 中,CPU 内核可以通过 PyTorch Dispatcher 调用其他运算符。仅仅包含模型直接调用的根运算符集是不够的,因为在底层可能会传递性地调用更多运算符。在代表性输入上运行模型并观察实际调用的运算符列表(即“追踪”)是确定 PyTorch 使用哪些部分的最准确方法。

此外,诸如内核应处理哪些 dtypes 等因素也是运行时功能,这些功能取决于提供给模型的实际输入。因此,追踪机制非常适合此目的。

使用基于追踪的选择性构建可以选择(纳入或排除)哪些功能?

在基于追踪的选择性构建过程中,可以为 PyTorch 运行时选择以下功能

  1. CPU/QuantizedCPU 内核,用于 PyTorch 的 ATen 运算符:如果一个 PyTorch 运算符不是选择性构建运行时所针对的模型所需的,那么该 CPU 内核的注册将在运行时中被忽略。这通过 Torchgen 代码生成器控制。
  2. Primary Operators:这由名为 TORCH_SELECTIVE_SCHEMA 的宏控制(通过模板化选择性构建),它根据生成的头文件中的信息选择或取消选择一个 primary operator。
  3. 处理 特定 dtypes 的 CPU 内核代码:这是通过在宏 AT_PRIVATE_CHECK_SELECTIVE_BUILD 生成的 switch case 中特定 case 语句生成异常抛出来实现的。
  4. 注册扩展 PyTorch 的自定义 C++ 类:这由宏 TORCH_SELECTIVE_CLASS 控制,该宏可在注册自定义 C++ 类时使用。torch::selective_class_<> 助手应与宏 TORCH_SELECTIVE_CLASS 结合使用。

构建过程中使用的 YAML 文件结构是什么样的?

追踪后生成的 YAML 文件如下例所示。它编码了上述指定的所有“可选择”构建功能元素。

include_all_non_op_selectives: false
build_features: []
operators:
    aten::add.Tensor:
        is_used_for_training: false
        is_root_operator: true
        include_all_overloads: false
    aten::len.t:
        is_used_for_training: false
        is_root_operator: true
        include_all_overloads: false
kernel_metadata:
    _local_scalar_dense_cpu:
    - Float
    add_stub:
    - Float
    copy_:
    - Bool
    - Byte
    mul_cpu:
    - Float
custom_classes: []

代码究竟是如何从生成的二进制文件中移除的?

根据具体情况,主要有两种技术用于向编译器和链接器提示未使用和不可达的代码。然后,这些代码被编译器或链接器作为不可达代码清除。

[1] 由链接器移除的未引用函数

当编译后的目标文件中存在一个未被任何可见函数传递引用的函数时,链接器会移除它(如果提供了正确的构建标志)。选择性构建系统在两种场景下利用了这一点。

Dispatcher 中的内核注册

如果不需要某个运算符的内核,则不会将其注册到 dispatcher。未注册的内核意味着该函数不可达,它将被链接器移除。

模板化选择性构建

这里的总体思想是,使用类模板特化来选择一个类,该类根据函数是否被使用来捕获对函数的引用或不捕获,然后链接器可以清理掉未引用的函数。

例如,在下面的代码中,没有对函数“fn2”的引用,因此它将被链接器清理,因为它在任何地方都没有被引用。

#include <vector>
#include <cstdio>

template <typename T, bool>
struct FunctionSelector {
    T fn_;
    FunctionSelector(T fn): fn_(fn) {}
    T get() { return this->fn_; }
};

// The "false" specialization of this class does NOT retain the argument passed
// to the class constructor, which means that the function pointer passed in
// is considered to be unreferenced in the program (unless it is referenced
// elsewhere).
template <typename T>
struct FunctionSelector<T, false> {
    FunctionSelector(T) {}
};

template <typename T>
FunctionSelector<T, true> make_function_selector_true(T fn) {
    return FunctionSelector<T, true>(fn);
}

template <typename T>
FunctionSelector<T, false> make_function_selector_false(T fn) {
    return FunctionSelector<T, false>(fn);
}

typedef void(*fn_ptr_type)();

std::vector<fn_ptr_type> fns;

template <typename T>
void add_fn(FunctionSelector<T, true> fs) {
    fns.push_back(fs.get());
}

template <typename T>
void add_fn(FunctionSelector<T, false>) {
    // Do nothing.
}

// fn1 will be kept by the linker since it is added to the vector "fns" at
// runtime.
void fn1() {
    printf("fn1\n");
}

// fn2 will be removed by the linker since it isn't referenced at all.
void fn2() {
    printf("fn2\n");
}

int main() {
    add_fn(make_function_selector_true(fn1));
    add_fn(make_function_selector_false(fn2));
}

[2] 由编译器消除的死代码

C++ 编译器可以通过静态分析代码的控制流来检测死(不可达)代码。例如,如果在无条件异常抛出之后存在代码路径,那么其后的所有代码都将被标记为死代码,并且不会被编译器转换为目标代码。通常,编译器需要使用 -fdce 标志来消除死代码。

在下面的示例中,您可以看到左侧(红色框中)的 C++ 代码在右侧没有相应的生成目标代码。

图 4:C++ 编译器消除死代码

此属性被用于 PyTorch 内核实现的函数体中,这些函数体包含大量重复代码来处理 Tensor 的多种 dtype。一个 dtype 是 Tensor 存储元素的基础数据类型。它可以是 float, double, int64, bool, int8 等等…

几乎所有 PyTorch CPU 内核都使用形式为 AT_DISPATCH_ALL_TYPES* 的宏,该宏用于替换针对内核需要处理的每种 dtype 特化的某些代码。例如

AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
    kBool, kHalf, kBFloat16, dtype, "copy_kernel", [&] {
  cpu_kernel_vec(
      iter,
      [=](scalar_t a) -> scalar_t { return a; },
      [=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a; });
});

AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3 内部有一个类似于上图 Figure-4 中代码的 switch-case 语句。追踪过程记录了针对内核标签“copy_kernel”触发的 dtypes,并且构建过程会处理这些标签,并在处理此内核标签不需要的 dtype 的每个 case 语句中插入 throw 语句。

这就是 PyTorch 基于追踪的选择性构建中实现 dtype 选择性的方式。

结论

基于追踪的选择性构建是一种实用且可扩展的方法,用于仅选择应用程序中已使用的部分,以保留静态分析无法检测到的代码。这些代码本质上通常高度依赖于数据/输入。

本文提供了关于基于追踪的选择性构建如何在底层工作以及其实现相关技术细节的详细见解。这些技术也可以应用于其他可以从减小二进制文件大小中获益的应用程序和情况。