博客

PyTorch 基于跟踪的选择性构建

作者 2022年10月17日2024年11月15日暂无评论

简介

TL;DR(摘要):在移动设备、SBC(单板计算机)和 IoT 设备上运行 PyTorch 具有一定挑战性。编译后的 PyTorch 库非常庞大,且包含了许多端侧应用场景中并不需要的依赖项。

为了在端侧运行特定的模型集,我们实际上只需要 PyTorch 库中一小部分功能。我们发现,使用通过选择性构建(selective build)生成的 PyTorch 运行时,可以将二进制文件大小减少高达 90%(针对 Linux x86-64 构建的 CPU 和 QuantizedCPU 后端)。在本篇博文中,我们将分享使用选择性构建生成模型专属最小化运行时的经验,并向您展示具体的操作方法。

为什么这对应用开发者很重要?

使用选择性构建生成的 PyTorch 运行时可以将 AI 应用的大小减少 30 MB 以上——对于典型的移动应用来说,这是一个显著的缩减!让移动应用变得更轻量有很多好处:它们可以在更多种类的设备上运行,消耗更少的移动数据,并且在用户设备上下载和更新的速度更快。

开发者体验如何?

此方法可以与任何现有的 PyTorch 移动端部署工作流程无缝衔接。您只需要将通用的 PyTorch 运行时库替换为您为应用中所需特定模型定制的运行时即可。此过程的一般步骤如下:

  1. 仪表化模式(instrumentation mode)下构建 PyTorch 运行时(这被称为 PyTorch 的仪表化构建)。这将记录所使用的算子、内核和功能。
  2. 使用提供的 model_tracer 二进制文件通过此仪表化构建运行您的模型。这将生成一个包含您模型所用所有功能的 YAML 文件。这些功能将被保留在最小化运行时中。
  3. 使用此 YAML 文件作为输入构建 PyTorch。这就是选择性构建技术,它能显著减小最终 PyTorch 二进制文件的大小。
  4. 使用这个经过选择性构建的 PyTorch 库来减小您的移动应用体积!

以特殊的“仪表化”模式构建 PyTorch 运行时(通过传递 TRACING_BASED=1 构建选项)会生成 PyTorch 的仪表化构建运行时,以及一个 model_tracer 二进制文件。使用此构建运行模型,使我们能够追踪模型所使用的 PyTorch 部分。

图 1:PyTorch 的仪表化构建

# Clone the PyTorch repo
git clone https://github.com/pytorch/pytorch.git
cd pytorch

# Build the model_tracer
USE_NUMPY=0 USE_DISTRIBUTED=0 USE_CUDA=0 TRACING_BASED=1 \
  python setup.py develop

现在,该仪表化构建被用于通过代表性输入运行模型推理。model_tracer 二进制文件会观察仪表化构建在推理过程中被激活的部分,并将其转储到 YAML 文件中。

图 2:在仪表化构建上运行模型后生成的 YAML 文件

# Generate YAML file
./build/bin/model_tracer \
  --model_input_path /tmp/path_to_model.ptl \
  --build_yaml_path /tmp/selected_ops.yaml

现在我们再次构建 PyTorch 运行时,但这次使用追踪器生成的 YAML 文件。运行时现在仅包含该模型所需的部分。这在下图中被称为“选择性构建的 PyTorch 运行时”

# Clean out cached configuration
make clean

# Build PyTorch using Selected Operators (from the YAML file)
# using the host toolchain, and use this generated library
BUILD_PYTORCH_MOBILE_WITH_HOST_TOOLCHAIN=1 \
USE_LIGHTWEIGHT_DISPATCH=0 \
BUILD_LITE_INTERPRETER=1 \
SELECTED_OP_LIST=/tmp/selected_ops.yaml \
TRACING_BASED=1 \
  ./scripts/build_mobile.sh

图 3:PyTorch 的选择性构建以及在选择性构建的 PyTorch 运行时上的模型执行

代码示例

我们整理了一个笔记本(notebook),通过一个简单的 PyTorch 模型展示了上述过程的代码实现。

如需在 Android/iOS 上进行部署的更深入实践教程,本教程应该会有所帮助。

技术常见问题解答

为什么 PyTorch 的选择性构建需要追踪(Tracing)?

在 PyTorch 中,CPU 内核可以通过 PyTorch 分发器(Dispatcher)调用其他算子。仅包含模型直接调用的根算子集合是不够的,因为在底层可能还会有许多算子被间接调用。通过代表性输入运行模型并观察实际调用的算子列表(即“追踪”)是确定 PyTorch 使用部分的最高效准确方法。

此外,内核应处理哪些数据类型(dtypes)等因素也是取决于模型实际输入的运行时特性。因此,追踪机制非常适用于此目的。

通过基于追踪的选择性构建,可以选择(包含或剔除)哪些功能?

在基于追踪的选择性构建过程中,可以为 PyTorch 运行时选择以下功能:

  1. CPU/QuantizedCPU PyTorch ATen 算子内核:如果目标为选择性构建运行时的模型不需要某个 PyTorch 算子,则在运行时中会省略该 CPU 内核的注册。这由 Torchgen 代码生成器控制。
  2. 基础算子(Primary Operators):这由名为 TORCH_SELECTIVE_SCHEMA 的宏控制(通过模板选择性构建),它根据生成的头文件中的信息选择或取消选择基础算子。
  3. 处理 CPU 内核中特定数据类型的代码:这是通过在 AT_PRIVATE_CHECK_SELECTIVE_BUILD 宏生成的 switch case 语句中生成异常抛出操作来实现的。
  4. 扩展 PyTorch 的自定义 C++ 类注册:这由 TORCH_SELECTIVE_CLASS 宏控制,该宏可在注册自定义 C++ 类时使用。辅助工具 torch::selective_class<> 应与 TORCH_SELECTIVE_CLASS 宏结合使用。

构建过程中使用的 YAML 文件结构是什么样的?

追踪后生成的 YAML 文件如下例所示。它编码了上述“可选择”构建功能的所有元素。

include_all_non_op_selectives: false
build_features: []
operators:
    aten::add.Tensor:
        is_used_for_training: false
        is_root_operator: true
        include_all_overloads: false
    aten::len.t:
        is_used_for_training: false
        is_root_operator: true
        include_all_overloads: false
kernel_metadata:
    _local_scalar_dense_cpu:
    - Float
    add_stub:
    - Float
    copy_:
    - Bool
    - Byte
    mul_cpu:
    - Float
custom_classes: []

代码究竟是如何从生成的二进制文件中剔除的?

根据具体情况,主要有两种技术用于提示编译器和链接器识别未使用和不可达的代码。这些代码随后会被编译器或链接器作为不可达代码进行清理。

[1] 链接器移除未引用的函数

当被链接的已编译目标文件中存在未从任何可见函数间接引用的函数时,链接器将移除它(如果提供了正确的构建标志)。选择性构建系统在两种场景中利用了这一点。

分发器中的内核注册

如果不需要某个算子的内核,则它不会向分发器注册。未注册的内核意味着该函数不可达,链接器会将其移除。

模板选择性构建

这里的基本思路是使用类模板特化来选择一个类,该类要么捕获对函数的引用,要么不捕获(取决于是否使用),这样链接器就可以清除未引用的函数。

例如,在下面的代码中,没有对函数“fn2”的引用,因此它将被链接器清理,因为它在任何地方都没有被引用。

#include <vector>
#include <cstdio>

template <typename T, bool>
struct FunctionSelector {
    T fn_;
    FunctionSelector(T fn): fn_(fn) {}
    T get() { return this->fn_; }
};

// The "false" specialization of this class does NOT retain the argument passed
// to the class constructor, which means that the function pointer passed in
// is considered to be unreferenced in the program (unless it is referenced
// elsewhere).
template <typename T>
struct FunctionSelector<T, false> {
    FunctionSelector(T) {}
};

template <typename T>
FunctionSelector<T, true> make_function_selector_true(T fn) {
    return FunctionSelector<T, true>(fn);
}

template <typename T>
FunctionSelector<T, false> make_function_selector_false(T fn) {
    return FunctionSelector<T, false>(fn);
}

typedef void(*fn_ptr_type)();

std::vector<fn_ptr_type> fns;

template <typename T>
void add_fn(FunctionSelector<T, true> fs) {
    fns.push_back(fs.get());
}

template <typename T>
void add_fn(FunctionSelector<T, false>) {
    // Do nothing.
}

// fn1 will be kept by the linker since it is added to the vector "fns" at
// runtime.
void fn1() {
    printf("fn1\n");
}

// fn2 will be removed by the linker since it isn't referenced at all.
void fn2() {
    printf("fn2\n");
}

int main() {
    add_fn(make_function_selector_true(fn1));
    add_fn(make_function_selector_false(fn2));
}

[2] 编译器执行死代码消除

C++ 编译器可以通过静态分析代码的控制流来检测死代码(不可达代码)。例如,如果存在一个位于无条件异常抛出之后的代码路径,那么其后的所有代码都将被标记为死代码,编译器不会将其转换为目标代码。通常,编译器需要使用 -fdce 标志来消除死代码。

在下面的例子中,您可以看到左侧(红色框中)的 C++ 代码在右侧没有任何对应的生成目标代码。

图 4:C++ 编译器的死代码消除

这一特性被应用于 PyTorch 内核实现的函数体中,这些实现包含大量重复代码以处理张量的多种数据类型。dtype 是张量存储元素的基础数据类型,可以是 float、double、int64、bool、int8 等……

几乎每个 PyTorch CPU 内核都使用 AT_DISPATCH_ALL_TYPES* 形式的宏,用于替换内核需要处理的每种 dtype 的专用代码。例如:

AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
    kBool, kHalf, kBFloat16, dtype, "copy_kernel", [&] {
  cpu_kernel_vec(
      iter,
      [=](scalar_t a) -> scalar_t { return a; },
      [=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a; });
});

AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3 在内部有一个 switch-case 语句,看起来就像上面图 4 中的代码。追踪过程记录了内核标签“copy_kernel”触发的数据类型,构建过程处理这些标签并在所有处理非该内核标签所需数据类型的 case 语句中插入 throw 语句。

这就是 PyTorch 基于追踪的选择性构建中实现 dtype 选择性的方式。

结论

基于追踪的选择性构建是一种实用且可扩展的方法,用于仅选择应用中已使用的部分,从而保留静态分析无法检测到的代码。这些代码通常在本质上极度依赖数据或输入。

本文详细介绍了基于追踪的选择性构建如何在底层工作,以及与其实现相关的技术细节。这些技术也可以应用于其他能够从减小二进制体积中获益的应用和场景。