• 文档 >
  • ExecuTorch 中的 LLM 简介
快捷键

ExecuTorch 中的 LLM 简介

欢迎来到 LLM 手册!本手册旨在提供一个实际示例,以利用 ExecuTorch 引导您使用自己的大型语言模型 (LLM)。我们的主要目标是提供一个清晰简洁的指南,说明如何将我们的系统与您自己的 LLM 集成。

请注意,本项目旨在作为演示,而不是作为具有最佳性能的完整功能示例。因此,某些组件(如采样器、分词器等)仅以其最基本版本提供,仅用于演示目的。因此,模型产生的结果可能会有所不同,并且可能并不总是最佳的。

我们鼓励用户将本项目用作起点,并根据其特定需求进行调整,包括创建您自己的分词器、采样器、加速后端和其他组件的版本。我们希望本项目能为您的 LLM 和 ExecuTorch 之旅提供有益的指导。

有关以最佳性能部署 Llama 的信息,请参阅Llama 指南

目录

  1. 先决条件

  2. Hello World 示例

  3. 量化

  4. 使用移动加速

  5. 调试和性能分析

  6. 如何使用自定义内核

  7. 如何构建移动应用程序

先决条件

要遵循本指南,您需要克隆 ExecuTorch 存储库并安装依赖项。ExecuTorch 建议使用 Python 3.10 和 Conda 来管理您的环境。Conda 不是必需的,但请注意,您可能需要根据您的环境将 python/pip 替换为 python3/pip3。

有关安装 miniconda 的说明,请在此处找到

# Create a directory for this example.
mkdir et-nanogpt
cd et-nanogpt

# Clone the ExecuTorch repository and submodules.
mkdir third-party
git clone -b release/0.4 https://github.com/pytorch/executorch.git third-party/executorch
cd third-party/executorch
git submodule update --init

# Create a conda environment and install requirements.
conda create -yn executorch python=3.10.0
conda activate executorch
./install_requirements.sh

cd ../..

有关安装 pyenv-virtualenv 的说明,请在此处找到

重要的是,如果通过 brew 安装 pyenv,则它不会自动在终端中启用 pyenv,从而导致错误。运行以下命令以启用。请参阅上面的 pyenv-virtualenv 安装指南,了解如何将其添加到您的 .bashrc 或 .zshrc 中,以避免需要手动运行这些命令。

eval "$(pyenv init -)"
eval "$(pyenv virtualenv-init -)"
# Create a directory for this example.
mkdir et-nanogpt
cd et-nanogpt

pyenv install -s 3.10
pyenv virtualenv 3.10 executorch
pyenv activate executorch

# Clone the ExecuTorch repository and submodules.
mkdir third-party
git clone -b release/0.4 https://github.com/pytorch/executorch.git third-party/executorch
cd third-party/executorch
git submodule update --init

# Install requirements.
PYTHON_EXECUTABLE=python ./install_requirements.sh

cd ../..

有关更多信息,请参阅设置 ExecuTorch

在本地运行大型语言模型

本示例使用 Karpathy 的 nanoGPT,它是 GPT-2 124M 的最小实现。本指南适用于其他语言模型,因为 ExecuTorch 是模型不变的。

使用 ExecuTorch 运行模型有两个步骤

  1. 导出模型。此步骤将其预处理为适合运行时执行的格式。

  2. 在运行时,加载模型文件并使用 ExecuTorch 运行时运行。


导出步骤提前完成,通常作为应用程序构建的一部分或在模型更改时完成。生成的 .pte 文件随应用程序一起分发。在运行时,应用程序加载 .pte 文件并将其传递给 ExecuTorch 运行时。

步骤 1. 导出到 ExecuTorch

导出采用 PyTorch 模型并将其转换为可以在消费类设备上高效运行的格式。

对于本示例,您将需要 nanoGPT 模型和相应的分词器词汇表。

curl https://raw.githubusercontent.com/karpathy/nanoGPT/master/model.py -O
curl https://huggingface.co/openai-community/gpt2/resolve/main/vocab.json -O
wget https://raw.githubusercontent.com/karpathy/nanoGPT/master/model.py
wget https://huggingface.co/openai-community/gpt2/resolve/main/vocab.json

要将模型转换为针对独立执行优化的格式,需要执行两个步骤。首先,使用 PyTorch export 函数将 PyTorch 模型转换为平台独立的中间表示。然后使用 ExecuTorch to_edgeto_executorch 方法为设备端执行准备模型。这将创建一个 .pte 文件,桌面或移动应用程序可以在运行时加载该文件。

创建一个名为 export_nanogpt.py 的文件,内容如下

# export_nanogpt.py

import torch

from executorch.exir import EdgeCompileConfig, to_edge
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.export import export, export_for_training

from model import GPT

# Load the model.
model = GPT.from_pretrained('gpt2')

# Create example inputs. This is used in the export process to provide
# hints on the expected shape of the model input.
example_inputs = (torch.randint(0, 100, (1, model.config.block_size), dtype=torch.long), )

# Set up dynamic shape configuration. This allows the sizes of the input tensors
# to differ from the sizes of the tensors in `example_inputs` during runtime, as
# long as they adhere to the rules specified in the dynamic shape configuration.
# Here we set the range of 0th model input's 1st dimension as
# [0, model.config.block_size].
# See https://pytorch.ac.cn/executorch/main/concepts.html#dynamic-shapes
# for details about creating dynamic shapes.
dynamic_shape = (
    {1: torch.export.Dim("token_dim", max=model.config.block_size)},
)

# Trace the model, converting it to a portable intermediate representation.
# The torch.no_grad() call tells PyTorch to exclude training-specific logic.
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
    m = export_for_training(model, example_inputs, dynamic_shapes=dynamic_shape).module()
    traced_model = export(m, example_inputs, dynamic_shapes=dynamic_shape)

# Convert the model into a runnable ExecuTorch program.
edge_config = EdgeCompileConfig(_check_ir_validity=False)
edge_manager = to_edge(traced_model,  compile_config=edge_config)
et_program = edge_manager.to_executorch()

# Save the ExecuTorch program to a file.
with open("nanogpt.pte", "wb") as file:
    file.write(et_program.buffer)

要导出,请使用 python export_nanogpt.py (或 python3,根据您的环境而定)运行脚本。它将在当前目录中生成一个 nanogpt.pte 文件。

有关更多信息,请参阅导出到 ExecuTorchtorch.export

步骤 2. 调用运行时

ExecuTorch 提供了一组运行时 API 和类型来加载和运行模型。

创建一个名为 main.cpp 的文件,内容如下

// main.cpp

#include <cstdint>

#include "basic_sampler.h"
#include "basic_tokenizer.h"

#include <executorch/extension/module/module.h>
#include <executorch/extension/tensor/tensor.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/result.h>

using executorch::aten::ScalarType;
using executorch::aten::Tensor;
using executorch::extension::from_blob;
using executorch::extension::Module;
using executorch::runtime::EValue;
using executorch::runtime::Result;

模型输入和输出采用张量的形式。张量可以被认为是多维数组。ExecuTorch EValue 类提供了张量和其他 ExecuTorch 数据类型的包装器。

由于 LLM 一次生成一个 token,因此驱动程序代码需要重复调用模型,逐个 token 构建输出。每个生成的 token 都作为下一次运行的输入传递。

// main.cpp

// The value of the gpt2 `<|endoftext|>` token.
#define ENDOFTEXT_TOKEN 50256

std::string generate(
    Module& llm_model,
    std::string& prompt,
    BasicTokenizer& tokenizer,
    BasicSampler& sampler,
    size_t max_input_length,
    size_t max_output_length) {
  // Convert the input text into a list of integers (tokens) that represents it,
  // using the string-to-token mapping that the model was trained on. Each token
  // is an integer that represents a word or part of a word.
  std::vector<int64_t> input_tokens = tokenizer.encode(prompt);
  std::vector<int64_t> output_tokens;

  for (auto i = 0u; i < max_output_length; i++) {
    // Convert the input_tokens from a vector of int64_t to EValue. EValue is a
    // unified data type in the ExecuTorch runtime.
    auto inputs = from_blob(
        input_tokens.data(),
        {1, static_cast<int>(input_tokens.size())},
        ScalarType::Long);

    // Run the model. It will return a tensor of logits (log-probabilities).
    auto logits_evalue = llm_model.forward(inputs);

    // Convert the output logits from EValue to std::vector, which is what the
    // sampler expects.
    Tensor logits_tensor = logits_evalue.get()[0].toTensor();
    std::vector<float> logits(
        logits_tensor.data_ptr<float>(),
        logits_tensor.data_ptr<float>() + logits_tensor.numel());

    // Sample the next token from the logits.
    int64_t next_token = sampler.sample(logits);

    // Break if we reached the end of the text.
    if (next_token == ENDOFTEXT_TOKEN) {
      break;
    }

    // Add the next token to the output.
    output_tokens.push_back(next_token);

    std::cout << tokenizer.decode({next_token});
    std::cout.flush();

    // Update next input.
    input_tokens.push_back(next_token);
    if (input_tokens.size() > max_input_length) {
      input_tokens.erase(input_tokens.begin());
    }
  }

  std::cout << std::endl;

  // Convert the output tokens into a human-readable string.
  std::string output_string = tokenizer.decode(output_tokens);
  return output_string;
}

Module 类处理加载 .pte 文件并准备执行。

分词器负责将人类可读的提示字符串表示形式转换为模型期望的数值形式。为此,分词器将短子字符串与给定的 token ID 相关联。token 可以被认为是表示单词或单词的一部分,但在实践中,它们可能是任意字符序列。

分词器从文件中加载词汇表,其中包含每个 token ID 与其表示的文本之间的映射。调用 tokenizer.encode()tokenizer.decode() 以在字符串和 token 表示形式之间进行转换。

采样器负责根据模型输出的 logits 或对数概率选择下一个 token。LLM 为每个可能的下一个 token 返回一个 logit 值。采样器根据某种策略选择要使用的 token。此处使用的最简单的方法是采用具有最高 logit 值的 token。

采样器可以提供可配置的选项,例如输出选择的可配置随机量、重复 token 的惩罚以及优先或降低特定 token 优先级的偏差。

// main.cpp

int main() {
  // Set up the prompt. This provides the seed text for the model to elaborate.
  std::cout << "Enter model prompt: ";
  std::string prompt;
  std::getline(std::cin, prompt);

  // The tokenizer is used to convert between tokens (used by the model) and
  // human-readable strings.
  BasicTokenizer tokenizer("vocab.json");

  // The sampler is used to sample the next token from the logits.
  BasicSampler sampler = BasicSampler();

  // Load the exported nanoGPT program, which was generated via the previous
  // steps.
  Module model("nanogpt.pte", Module::LoadMode::MmapUseMlockIgnoreErrors);

  const auto max_input_tokens = 1024;
  const auto max_output_tokens = 30;
  std::cout << prompt;
  generate(
      model, prompt, tokenizer, sampler, max_input_tokens, max_output_tokens);
}

最后,将以下文件下载到与 main.cpp 相同的目录中

curl -O https://raw.githubusercontent.com/pytorch/executorch/main/examples/llm_manual/basic_sampler.h
curl -O https://raw.githubusercontent.com/pytorch/executorch/main/examples/llm_manual/basic_tokenizer.h

要了解更多信息,请参阅运行时 API 教程

构建和运行

ExecuTorch 使用 CMake 构建系统。要编译并链接到 ExecuTorch 运行时,请通过 add_directory 包含 ExecuTorch 项目,并链接到 executorch 和其他依赖项。

创建一个名为 CMakeLists.txt 的文件,内容如下

# CMakeLists.txt

cmake_minimum_required(VERSION 3.19)
project(nanogpt_runner)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED True)

# Set options for executorch build.
option(EXECUTORCH_ENABLE_LOGGING "" ON)
option(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER "" ON)
option(EXECUTORCH_BUILD_EXTENSION_MODULE "" ON)
option(EXECUTORCH_BUILD_EXTENSION_TENSOR "" ON)
option(EXECUTORCH_BUILD_KERNELS_OPTIMIZED "" ON)

# Include the executorch subdirectory.
add_subdirectory(
  ${CMAKE_CURRENT_SOURCE_DIR}/third-party/executorch
  ${CMAKE_BINARY_DIR}/executorch
)

add_executable(nanogpt_runner main.cpp)
target_link_libraries(
  nanogpt_runner
  PRIVATE executorch
          extension_module_static # Provides the Module class
          extension_tensor # Provides the TensorPtr class
          optimized_native_cpu_ops_lib # Provides baseline cross-platform
                                       # kernels
)

此时,工作目录应包含以下文件

  • CMakeLists.txt

  • main.cpp

  • basic_tokenizer.h

  • basic_sampler.h

  • export_nanogpt.py

  • model.py

  • vocab.json

  • nanogpt.pte

如果所有这些文件都存在,您现在可以构建和运行

./install_requirements.sh --clean
(mkdir cmake-out && cd cmake-out && cmake ..)
cmake --build cmake-out -j10
./cmake-out/nanogpt_runner

您应该看到消息

Enter model prompt:

为模型键入一些种子文本,然后按 Enter 键。这里我们使用“Hello world!”作为示例提示

Enter model prompt: Hello world!
Hello world!

I'm not sure if you've heard of the "Curse of the Dragon" or not, but it's a very popular game in

此时,它很可能运行非常缓慢。这是因为 ExecuTorch 尚未被告知要针对特定硬件进行优化(委托),并且因为它正在以 32 位浮点数(无量化)进行所有计算。

委托

虽然 ExecuTorch 为所有运算符提供了可移植的跨平台实现,但它也为许多不同的目标提供了专门的后端。这些后端包括但不限于:通过 XNNPACK 后端实现的 x86 和 ARM CPU 加速、通过 Core ML 后端和 Metal Performance Shader (MPS) 后端实现的 Apple 加速以及通过 Vulkan 后端实现的 GPU 加速。

由于优化特定于给定的后端,因此每个 pte 文件都特定于导出时面向的后端。为了支持多个设备(例如 Android 的 XNNPACK 加速和 iOS 的 Core ML),请为每个后端导出一个单独的 PTE 文件。

要在导出时委托给后端,ExecuTorch 在 EdgeProgramManager 对象中提供了 to_backend() 函数,该函数接受后端特定的分区器对象。分区器负责查找可以由目标后端加速的计算图部分,并且 to_backend() 函数会将匹配的部分委托给给定的后端以进行加速和优化。未委托的计算图的任何部分将由 ExecuTorch 运算符实现执行。

要将导出的模型委托给特定的后端,我们需要首先从 ExecuTorch 代码库导入其分区器以及边缘编译配置,然后使用 EdgeProgramManager 对象 to_edge 函数创建的分区器实例调用 to_backend

以下是如何将 nanoGPT 委托给 XNNPACK 的示例(如果您要部署到 Android 手机,例如)

# export_nanogpt.py

# Load partitioner for Xnnpack backend
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

# Model to be delegated to specific backend should use specific edge compile config
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
from executorch.exir import EdgeCompileConfig, to_edge

import torch
from torch.export import export
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.export import export_for_training

from model import GPT

# Load the nanoGPT model.
model = GPT.from_pretrained('gpt2')

# Create example inputs. This is used in the export process to provide
# hints on the expected shape of the model input.
example_inputs = (
        torch.randint(0, 100, (1, model.config.block_size - 1), dtype=torch.long),
    )

# Set up dynamic shape configuration. This allows the sizes of the input tensors
# to differ from the sizes of the tensors in `example_inputs` during runtime, as
# long as they adhere to the rules specified in the dynamic shape configuration.
# Here we set the range of 0th model input's 1st dimension as
# [0, model.config.block_size].
# See https://pytorch.ac.cn/executorch/main/concepts.html#dynamic-shapes
# for details about creating dynamic shapes.
dynamic_shape = (
    {1: torch.export.Dim("token_dim", max=model.config.block_size - 1)},
)

# Trace the model, converting it to a portable intermediate representation.
# The torch.no_grad() call tells PyTorch to exclude training-specific logic.
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
    m = export_for_training(model, example_inputs, dynamic_shapes=dynamic_shape).module()
    traced_model = export(m, example_inputs, dynamic_shapes=dynamic_shape)

# Convert the model into a runnable ExecuTorch program.
# To be further lowered to Xnnpack backend, `traced_model` needs xnnpack-specific edge compile config
edge_config = get_xnnpack_edge_compile_config()
edge_manager = to_edge(traced_model, compile_config=edge_config)

# Delegate exported model to Xnnpack backend by invoking `to_backend` function with Xnnpack partitioner.
edge_manager = edge_manager.to_backend(XnnpackPartitioner())
et_program = edge_manager.to_executorch()

# Save the Xnnpack-delegated ExecuTorch program to a file.
with open("nanogpt.pte", "wb") as file:
    file.write(et_program.buffer)


此外,更新 CMakeLists.txt 以构建 XNNPACK 后端并将其链接到 ExecuTorch 运行器。

cmake_minimum_required(VERSION 3.19)
project(nanogpt_runner)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED True)

# Set options for executorch build.
option(EXECUTORCH_ENABLE_LOGGING "" ON)
option(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER "" ON)
option(EXECUTORCH_BUILD_EXTENSION_MODULE "" ON)
option(EXECUTORCH_BUILD_EXTENSION_TENSOR "" ON)
option(EXECUTORCH_BUILD_KERNELS_OPTIMIZED "" ON)
option(EXECUTORCH_BUILD_XNNPACK "" ON) # Build with Xnnpack backend

# Include the executorch subdirectory.
add_subdirectory(
  ${CMAKE_CURRENT_SOURCE_DIR}/third-party/executorch
  ${CMAKE_BINARY_DIR}/executorch
)

add_executable(nanogpt_runner main.cpp)
target_link_libraries(
  nanogpt_runner
  PRIVATE executorch
          extension_module_static # Provides the Module class
          extension_tensor # Provides the TensorPtr class
          optimized_native_cpu_ops_lib # Provides baseline cross-platform
                                       # kernels
          xnnpack_backend # Provides the XNNPACK CPU acceleration backend
)

保持其余代码不变。有关更多详细信息,请参阅导出到 ExecuTorch调用运行时 以获取更多详细信息

此时,工作目录应包含以下文件

  • CMakeLists.txt

  • main.cpp

  • basic_tokenizer.h

  • basic_sampler.h

  • export_nanogpt.py

  • model.py

  • vocab.json

如果所有这些文件都存在,您现在可以导出 Xnnpack 委托的 pte 模型

python export_nanogpt.py

它将在相同的工作目录下生成 nanogpt.pte

然后我们可以通过以下方式构建和运行模型

(rm -rf cmake-out && mkdir cmake-out && cd cmake-out && cmake ..)
cmake --build cmake-out -j10
./cmake-out/nanogpt_runner

您应该看到消息

Enter model prompt:

为模型键入一些种子文本,然后按 Enter 键。这里我们使用“Hello world!”作为示例提示

Enter model prompt: Hello world!
Hello world!

I'm not sure if you've heard of the "Curse of the Dragon" or not, but it's a very popular game in

与非委托模型相比,委托模型应该明显更快。

有关后端委托的更多信息,请参阅 ExecuTorch 关于 XNNPACK 后端Core ML 后端Qualcomm AI Engine Direct 后端 的指南。

量化

量化是指使用较低精度类型运行计算和存储张量的一组技术。与 32 位浮点数相比,使用 8 位整数可以显着提高速度并减少内存使用量。量化模型的方法有很多种,在所需的预处理量、使用的数据类型以及对模型精度和性能的影响方面各不相同。

由于移动设备上的计算和内存高度受限,因此某种形式的量化对于在消费电子产品上发布大型模型是必要的。特别是,大型语言模型(如 Llama2)可能需要将模型权重量化为 4 位或更少。

利用量化需要在导出前转换模型。PyTorch 为此目的提供了 pt2e (PyTorch 2 Export) API。本示例针对使用 XNNPACK 委托的 CPU 加速。因此,它需要使用特定于 XNNPACK 的量化器。针对不同的后端将需要使用相应的量化器。

要将 8 位整数动态量化与 XNNPACK 委托一起使用,请调用 prepare_pt2e,通过使用代表性输入运行来校准模型,然后调用 convert_pt2e。这将更新计算图以在使用量化运算符(如果可用)的位置使用。

# export_nanogpt.py

from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
    DuplicateDynamicQuantChainPass,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
# Use dynamic, per-channel quantization.
xnnpack_quant_config = get_symmetric_quantization_config(
    is_per_channel=True, is_dynamic=True
)
xnnpack_quantizer = XNNPACKQuantizer()
xnnpack_quantizer.set_global(xnnpack_quant_config)

m = export_for_training(model, example_inputs).module()

# Annotate the model for quantization. This prepares the model for calibration.
m = prepare_pt2e(m, xnnpack_quantizer)

# Calibrate the model using representative inputs. This allows the quantization
# logic to determine the expected range of values in each tensor.
m(*example_inputs)

# Perform the actual quantization.
m = convert_pt2e(m, fold_quantize=False)
DuplicateDynamicQuantChainPass()(m)

traced_model = export(m, example_inputs)

此外,添加或更新 to_backend() 调用以使用 XnnpackPartitioner。这指示 ExecuTorch 通过 XNNPACK 后端优化模型以进行 CPU 执行。

from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XnnpackPartitioner,
)
edge_manager = to_edge(traced_model, compile_config=edge_config)
edge_manager = edge_manager.to_backend(XnnpackPartitioner()) # Lower to XNNPACK.
et_program = edge_manager.to_executorch()

最后,确保运行器链接到 CMakeLists.txt 中的 xnnpack_backend 目标。

add_executable(nanogpt_runner main.cpp)
target_link_libraries(
    nanogpt_runner
    PRIVATE
    executorch
    extension_module_static # Provides the Module class
    optimized_native_cpu_ops_lib # Provides baseline cross-platform kernels
    xnnpack_backend) # Provides the XNNPACK CPU acceleration backend

有关更多信息,请参阅ExecuTorch 中的量化

性能分析和调试

通过调用 to_backend() 降低模型后,您可能想查看哪些内容被委托了,哪些内容没有被委托。ExecuTorch 提供了实用程序方法来深入了解委托。您可以使用此信息来了解底层计算并诊断潜在的性能问题。模型作者可以使用此信息来构建模型,使其与目标后端兼容。

可视化委托

get_delegation_info() 方法提供了 to_backend() 调用后模型发生情况的摘要

from executorch.devtools.backend_debug import get_delegation_info
from tabulate import tabulate

# ... After call to to_backend(), but before to_executorch()
graph_module = edge_manager.exported_program().graph_module
delegation_info = get_delegation_info(graph_module)
print(delegation_info.get_summary())
df = delegation_info.get_operator_delegation_dataframe()
print(tabulate(df, headers="keys", tablefmt="fancy_grid"))

对于针对 XNNPACK 后端的 nanoGPT,您可能会看到以下内容(请注意,以下数字仅用于说明目的,实际值可能会有所不同)

Total  delegated  subgraphs:  145
Number  of  delegated  nodes:  350
Number  of  non-delegated  nodes:  760

op_type

# in_delegated_graphs

# in_non_delegated_graphs

0

aten__softmax_default

12

0

1

aten_add_tensor

37

0

2

aten_addmm_default

48

0

3

aten_any_dim

0

12

25

aten_view_copy_default

96

122

30

总计

350

760

从表中可以看出,运算符 aten_view_copy_default 在委托图中出现 96 次,在非委托图中出现 122 次。要查看更详细的视图,请使用 format_delegated_graph() 方法获取整个图的格式化打印字符串,或使用 print_delegated_graph() 直接打印

from executorch.exir.backend.utils import format_delegated_graph
graph_module = edge_manager.exported_program().graph_module
print(format_delegated_graph(graph_module))

对于大型模型,这可能会生成大量输出。考虑使用“Control+F”或“Command+F”来定位您感兴趣的运算符(例如“aten_view_copy_default”)。观察哪些实例不在降级图下。

在下面 nanoGPT 输出的片段中,观察到 transformer 模块已委托给 XNNPACK,而 where 运算符未委托。

%aten_where_self_22 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.where.self](args = (%aten_logical_not_default_33, %scalar_tensor_23, %scalar_tensor_22), kwargs = {})
%lowered_module_144 : [num_users=1] = get_attr[target=lowered_module_144]
backend_id: XnnpackBackend
lowered graph():
    %p_transformer_h_0_attn_c_attn_weight : [num_users=1] = placeholder[target=p_transformer_h_0_attn_c_attn_weight]
    %p_transformer_h_0_attn_c_attn_bias : [num_users=1] = placeholder[target=p_transformer_h_0_attn_c_attn_bias]
    %getitem : [num_users=1] = placeholder[target=getitem]
    %sym_size : [num_users=2] = placeholder[target=sym_size]
    %aten_view_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.view_copy.default](args = (%getitem, [%sym_size, 768]), kwargs = {})
    %aten_permute_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.permute_copy.default](args = (%p_transformer_h_0_attn_c_attn_weight, [1, 0]), kwargs = {})
    %aten_addmm_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.addmm.default](args = (%p_transformer_h_0_attn_c_attn_bias, %aten_view_copy_default, %aten_permute_copy_default), kwargs = {})
    %aten_view_copy_default_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.view_copy.default](args = (%aten_addmm_default, [1, %sym_size, 2304]), kwargs = {})
    return [aten_view_copy_default_1]

性能分析

通过 ExecuTorch 开发者工具,用户可以分析模型执行情况,从而获得模型中每个运算符的计时信息。

先决条件

ETRecord 生成(可选)

ETRecord 是在导出时生成的工件,其中包含模型图和源代码级元数据,这些元数据将 ExecuTorch 程序链接到原始 PyTorch 模型。您可以在没有 ETRecord 的情况下查看所有性能分析事件,但使用 ETRecord,您还可以将每个事件链接到正在执行的运算符类型、模块层次结构和原始 PyTorch 源代码的堆栈跟踪。有关更多信息,请参阅ETRecord 文档

在您的导出脚本中,在调用 to_edge()to_executorch() 后,使用来自 to_edge()EdgeProgramManager 和来自 to_executorch()ExecuTorchProgramManager 调用 generate_etrecord()。确保复制 EdgeProgramManager,因为对 to_backend() 的调用会就地改变图。

# export_nanogpt.py

import copy
from executorch.devtools import generate_etrecord

# Make the deep copy immediately after to to_edge()
edge_manager_copy = copy.deepcopy(edge_manager)

# ...
# Generate ETRecord right after to_executorch()
etrecord_path = "etrecord.bin"
generate_etrecord(etrecord_path, edge_manager_copy, et_program)

运行导出脚本,ETRecord 将生成为 etrecord.bin

ETDump 生成

ETDump 是在运行时生成的工件,其中包含模型执行的跟踪。有关更多信息,请参阅ETDump 文档

在您的代码中包含 ETDump 标头。

// main.cpp

#include <executorch/devtools/etdump/etdump_flatcc.h>

创建 ETDumpGen 类的实例并将其传递给 Module 构造函数。

std::unique_ptr<ETDumpGen> etdump_gen_ = std::make_unique<ETDumpGen>();
Module model("nanogpt.pte", Module::LoadMode::MmapUseMlockIgnoreErrors, std::move(etdump_gen_));

在调用 generate() 后,将 ETDump 保存到文件。如果需要,您可以在单个跟踪中捕获多个模型运行。

ETDumpGen* etdump_gen = static_cast<ETDumpGen*>(model.event_tracer());

ET_LOG(Info, "ETDump size: %zu blocks", etdump_gen->get_num_blocks());
etdump_result result = etdump_gen->get_etdump_data();
if (result.buf != nullptr && result.size > 0) {
    // On a device with a file system, users can just write it to a file.
    FILE* f = fopen("etdump.etdp", "w+");
    fwrite((uint8_t*)result.buf, 1, result.size, f);
    fclose(f);
    free(result.buf);
}

此外,更新 CMakeLists.txt 以使用开发者工具构建,并启用要跟踪和记录到 ETDump 中的事件

option(EXECUTORCH_ENABLE_EVENT_TRACER "" ON)
option(EXECUTORCH_BUILD_DEVTOOLS "" ON)

# ...

target_link_libraries(
    # ... omit existing ones
    etdump) # Provides event tracing and logging

target_compile_options(executorch PUBLIC -DET_EVENT_TRACER_ENABLED)
target_compile_options(portable_ops_lib PUBLIC -DET_EVENT_TRACER_ENABLED)

构建并运行运行器,您将看到生成名为“etdump.etdp”的文件。(请注意,这次我们在发布模式下构建以绕过 flatccrt 构建限制。)

(rm -rf cmake-out && mkdir cmake-out && cd cmake-out && cmake -DCMAKE_BUILD_TYPE=Release ..)
cmake --build cmake-out -j10
./cmake-out/nanogpt_runner

使用 Inspector API 进行分析

收集调试工件 ETDump(以及可选的 ETRecord)后,您可以使用 Inspector API 查看性能信息。

from executorch.devtools import Inspector

inspector = Inspector(etdump_path="etdump.etdp")
# If you also generated an ETRecord, then pass that in as well: `inspector = Inspector(etdump_path="etdump.etdp", etrecord="etrecord.bin")`

with open("inspector_out.txt", "w") as file:
    inspector.print_data_tabular(file)

这会将性能数据以表格格式打印在“inspector_out.txt”中,每行都是一个性能分析事件。顶行如下所示: 查看完整尺寸

要了解有关 Inspector 及其提供的丰富功能的更多信息,请参阅Inspector API 参考

自定义内核

借助 ExecuTorch 自定义运算符 API,自定义运算符和内核作者可以轻松地将其内核引入 PyTorch/ExecuTorch。

在 ExecuTorch 中使用自定义内核有三个步骤

  1. 使用 ExecuTorch 类型编写自定义内核。

  2. 编译自定义内核并将其链接到 AOT Python 环境以及运行时二进制文件。

  3. 源到源转换,以将运算符与自定义运算符交换。

编写自定义内核

为功能变体(用于 AOT 编译)和 out 变体(用于 ExecuTorch 运行时)定义您的自定义运算符架构。该架构需要遵循 PyTorch ATen 约定(请参阅native_functions.yaml)。

custom_linear(Tensor weight, Tensor input, Tensor(?) bias) -> Tensor

custom_linear.out(Tensor weight, Tensor input, Tensor(?) bias, *, Tensor(a!) out) -> Tensor(a!)

根据上面定义的架构编写您的自定义内核。使用 EXECUTORCH_LIBRARY 宏使内核可用于 ExecuTorch 运行时。

// custom_linear.h / custom_linear.cpp
#include <executorch/runtime/kernel/kernel_includes.h>

Tensor& custom_linear_out(const Tensor& weight, const Tensor& input, optional<Tensor> bias, Tensor& out) {
    // calculation
    return out;
}

// Register as myop::custom_linear.out
EXECUTORCH_LIBRARY(myop, "custom_linear.out", custom_linear_out);

要使此运算符在 PyTorch 中可用,您可以围绕 ExecuTorch 自定义内核定义一个包装器。请注意,ExecuTorch 实现使用 ExecuTorch 张量类型,而 PyTorch 包装器使用 ATen 张量。

// custom_linear_pytorch.cpp

#include "custom_linear.h"
#include <torch/library.h>

at::Tensor custom_linear(const at::Tensor& weight, const at::Tensor& input, std::optional<at::Tensor> bias) {

    // initialize out
    at::Tensor out = at::empty({weight.size(1), input.size(1)});

    // wrap kernel in custom_linear.cpp into ATen kernel
    WRAP_TO_ATEN(custom_linear_out, 3)(weight, input, bias, out);

    return out;
}

// Register the operator with PyTorch.
TORCH_LIBRARY(myop,  m) {
    m.def("custom_linear(Tensor weight, Tensor input, Tensor(?) bias) -> Tensor", custom_linear);
    m.def("custom_linear.out(Tensor weight, Tensor input, Tensor(?) bias, *, Tensor(a!) out) -> Tensor(a!)", WRAP_TO_ATEN(custom_linear_out, 3));
}

在模型中使用自定义运算符

自定义运算符可以显式地在 PyTorch 模型中使用,或者您可以编写一个转换来将核心运算符的实例替换为自定义变体。对于本示例,您可以找到 torch.nn.Linear 的所有实例,并将其替换为 CustomLinear

def  replace_linear_with_custom_linear(module):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(
                module,
                name,
                CustomLinear(child.in_features,  child.out_features, child.bias),
        )
        else:
            replace_linear_with_custom_linear(child)

剩余步骤与正常流程相同。现在您可以在 eager 模式下运行此模块,也可以将其导出到 ExecuTorch。

如何构建移动应用

请参阅关于在 iOS 和 Android 上使用 ExecuTorch 构建和运行 LLM 的说明。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源