• 文档 >
  • ExecuTorch 中大型语言模型简介
快捷方式

ExecuTorch 中大型语言模型简介

欢迎使用 LLM 手册!本手册旨在提供一个实际示例,用于利用 ExecuTorch 集成您自己的大型语言模型 (LLM)。我们的主要目标是提供关于如何将我们的系统与您自己的 LLM 集成的清晰简洁的指南。

请注意,此项目旨在作为演示,而不是具有最佳性能的完整功能示例。因此,某些组件(例如采样器、标记器等)以其最基本版本提供,仅用于演示目的。因此,模型产生的结果可能会有所不同,并且可能并不总是最佳的。

我们鼓励用户将此项目作为起点并根据其特定需求进行调整,这包括创建您自己的标记器、采样器、加速后端和其他组件版本。我们希望此项目能成为您使用 LLM 和 ExecuTorch 之旅的有用指南。

有关以最佳性能部署 Llama,请参阅Llama 指南

目录

  1. 先决条件

  2. Hello World 示例

  3. 量化

  4. 使用移动加速

  5. 调试和分析

  6. 如何使用自定义内核

  7. 如何构建移动应用程序

先决条件

要遵循本指南,您需要克隆 ExecuTorch 存储库并安装依赖项。ExecuTorch 建议使用 Python 3.10 并使用 Conda 管理您的环境。虽然 Conda 不是必需的,但请注意,根据您的环境,您可能需要将 python/pip 的使用替换为 python3/pip3。

有关安装 miniconda 的说明可以在此处找到

# Create a directory for this example.
mkdir et-nanogpt
cd et-nanogpt

# Clone the ExecuTorch repository and submodules.
mkdir third-party
git clone -b release/0.4 https://github.com/pytorch/executorch.git third-party/executorch
cd third-party/executorch
git submodule update --init

# Create a conda environment and install requirements.
conda create -yn executorch python=3.10.0
conda activate executorch
./install_requirements.sh

cd ../..

有关安装 pyenv-virtualenv 的说明可以在此处找到

重要的是,如果通过 brew 安装 pyenv,它不会自动在终端中启用 pyenv,从而导致错误。运行以下命令以启用。请参阅上面 pyenv-virtualenv 安装指南,了解如何将其添加到您的 .bashrc 或 .zshrc 文件中,以避免手动运行这些命令。

eval "$(pyenv init -)"
eval "$(pyenv virtualenv-init -)"
# Create a directory for this example.
mkdir et-nanogpt
cd et-nanogpt

pyenv install -s 3.10
pyenv virtualenv 3.10 executorch
pyenv activate executorch

# Clone the ExecuTorch repository and submodules.
mkdir third-party
git clone -b release/0.4 https://github.com/pytorch/executorch.git third-party/executorch
cd third-party/executorch
git submodule update --init

# Install requirements.
PYTHON_EXECUTABLE=python ./install_requirements.sh

cd ../..

有关更多信息,请参阅设置 ExecuTorch

在本地运行大型语言模型

此示例使用 Karpathy 的nanoGPT,它是 GPT-2 124M 的最小实现。本指南适用于其他语言模型,因为 ExecuTorch 与模型无关。

使用 ExecuTorch 运行模型有两个步骤

  1. 导出模型。此步骤将其预处理为适合运行时执行的格式。

  2. 在运行时,加载模型文件并使用 ExecuTorch 运行时运行。


导出步骤提前发生,通常作为应用程序构建的一部分或在模型更改时发生。生成的 .pte 文件与应用程序一起分发。在运行时,应用程序加载 .pte 文件并将其传递给 ExecuTorch 运行时。

步骤 1. 导出到 ExecuTorch

导出采用 PyTorch 模型并将其转换为可以在消费类设备上高效运行的格式。

对于此示例,您将需要 nanoGPT 模型和相应的标记器词汇表。

curl https://raw.githubusercontent.com/karpathy/nanoGPT/master/model.py -O
curl https://huggingface.co/openai-community/gpt2/resolve/main/vocab.json -O
wget https://raw.githubusercontent.com/karpathy/nanoGPT/master/model.py
wget https://huggingface.co/openai-community/gpt2/resolve/main/vocab.json

要将模型转换为针对独立执行优化的格式,需要两个步骤。首先,使用 PyTorch export 函数将 PyTorch 模型转换为中间的、与平台无关的中间表示。然后使用 ExecuTorch to_edgeto_executorch 方法准备模型以进行设备端执行。这将创建一个 .pte 文件,该文件可以在运行时由桌面或移动应用程序加载。

创建一个名为 export_nanogpt.py 的文件,其中包含以下内容

# export_nanogpt.py

import torch

from executorch.exir import EdgeCompileConfig, to_edge
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch._export import capture_pre_autograd_graph
from torch.export import export

from model import GPT

# Load the model.
model = GPT.from_pretrained('gpt2')

# Create example inputs. This is used in the export process to provide
# hints on the expected shape of the model input.
example_inputs = (torch.randint(0, 100, (1, model.config.block_size), dtype=torch.long), )

# Set up dynamic shape configuration. This allows the sizes of the input tensors
# to differ from the sizes of the tensors in `example_inputs` during runtime, as
# long as they adhere to the rules specified in the dynamic shape configuration.
# Here we set the range of 0th model input's 1st dimension as
# [0, model.config.block_size].
# See https://pytorch.ac.cn/executorch/main/concepts.html#dynamic-shapes
# for details about creating dynamic shapes.
dynamic_shape = (
    {1: torch.export.Dim("token_dim", max=model.config.block_size)},
)

# Trace the model, converting it to a portable intermediate representation.
# The torch.no_grad() call tells PyTorch to exclude training-specific logic.
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
    m = capture_pre_autograd_graph(model, example_inputs, dynamic_shapes=dynamic_shape)
    traced_model = export(m, example_inputs, dynamic_shapes=dynamic_shape)

# Convert the model into a runnable ExecuTorch program.
edge_config = EdgeCompileConfig(_check_ir_validity=False)
edge_manager = to_edge(traced_model,  compile_config=edge_config)
et_program = edge_manager.to_executorch()

# Save the ExecuTorch program to a file.
with open("nanogpt.pte", "wb") as file:
    file.write(et_program.buffer)

要导出,请使用 python export_nanogpt.py(或 python3,根据您的环境而定)运行脚本。它将在当前目录中生成一个 nanogpt.pte 文件。

有关更多信息,请参阅导出到 ExecuTorchtorch.export

步骤 2. 调用运行时

ExecuTorch 提供了一组运行时 API 和类型来加载和运行模型。

创建一个名为 main.cpp 的文件,其中包含以下内容

// main.cpp

#include <cstdint>

#include "basic_sampler.h"
#include "basic_tokenizer.h"

#include <executorch/extension/module/module.h>
#include <executorch/extension/tensor/tensor.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/result.h>

using executorch::aten::ScalarType;
using executorch::aten::Tensor;
using executorch::extension::from_blob;
using executorch::extension::Module;
using executorch::runtime::EValue;
using executorch::runtime::Result;

模型的输入和输出采用张量的形式。张量可以被认为是多维数组。ExecuTorch 的 EValue 类提供了围绕张量和其他 ExecuTorch 数据类型的包装器。

由于 LLM 每次生成一个 token,因此驱动程序代码需要重复调用模型,逐个 token 地构建输出。每个生成的 token 都会作为下一次运行的输入。

// main.cpp

// The value of the gpt2 `<|endoftext|>` token.
#define ENDOFTEXT_TOKEN 50256

std::string generate(
    Module& llm_model,
    std::string& prompt,
    BasicTokenizer& tokenizer,
    BasicSampler& sampler,
    size_t max_input_length,
    size_t max_output_length) {
  // Convert the input text into a list of integers (tokens) that represents it,
  // using the string-to-token mapping that the model was trained on. Each token
  // is an integer that represents a word or part of a word.
  std::vector<int64_t> input_tokens = tokenizer.encode(prompt);
  std::vector<int64_t> output_tokens;

  for (auto i = 0u; i < max_output_length; i++) {
    // Convert the input_tokens from a vector of int64_t to EValue. EValue is a
    // unified data type in the ExecuTorch runtime.
    auto inputs = from_blob(
        input_tokens.data(),
        {1, static_cast<int>(input_tokens.size())},
        ScalarType::Long);

    // Run the model. It will return a tensor of logits (log-probabilities).
    auto logits_evalue = llm_model.forward(inputs);

    // Convert the output logits from EValue to std::vector, which is what the
    // sampler expects.
    Tensor logits_tensor = logits_evalue.get()[0].toTensor();
    std::vector<float> logits(
        logits_tensor.data_ptr<float>(),
        logits_tensor.data_ptr<float>() + logits_tensor.numel());

    // Sample the next token from the logits.
    int64_t next_token = sampler.sample(logits);

    // Break if we reached the end of the text.
    if (next_token == ENDOFTEXT_TOKEN) {
      break;
    }

    // Add the next token to the output.
    output_tokens.push_back(next_token);

    std::cout << tokenizer.decode({next_token});
    std::cout.flush();

    // Update next input.
    input_tokens.push_back(next_token);
    if (input_tokens.size() > max_input_length) {
      input_tokens.erase(input_tokens.begin());
    }
  }

  std::cout << std::endl;

  // Convert the output tokens into a human-readable string.
  std::string output_string = tokenizer.decode(output_tokens);
  return output_string;
}

The Module 类负责加载 .pte 文件并准备执行。

分词器负责将提示的易于人类阅读的字符串表示形式转换为模型期望的数字形式。为此,分词器将短子字符串与给定的 token ID 关联。可以将 token 视为表示单词或单词的一部分,尽管在实践中,它们可能是字符的任意序列。

分词器从文件中加载词汇表,该文件包含每个 token ID与其表示的文本之间的映射。调用 tokenizer.encode()tokenizer.decode() 在字符串和 token 表示之间进行转换。

采样器负责根据模型输出的对数概率(logit)选择下一个 token。LLM 为每个可能的下一个 token 返回一个 logit 值。采样器根据某种策略选择要使用的 token。这里使用的最简单的方法是取具有最高 logit 值的 token。

采样器可以提供可配置的选项,例如对输出选择进行可配置的随机量、对重复 token 的惩罚以及优先或降低特定 token 优先级的偏差。

// main.cpp

int main() {
  // Set up the prompt. This provides the seed text for the model to elaborate.
  std::cout << "Enter model prompt: ";
  std::string prompt;
  std::getline(std::cin, prompt);

  // The tokenizer is used to convert between tokens (used by the model) and
  // human-readable strings.
  BasicTokenizer tokenizer("vocab.json");

  // The sampler is used to sample the next token from the logits.
  BasicSampler sampler = BasicSampler();

  // Load the exported nanoGPT program, which was generated via the previous
  // steps.
  Module model("nanogpt.pte", Module::LoadMode::MmapUseMlockIgnoreErrors);

  const auto max_input_tokens = 1024;
  const auto max_output_tokens = 30;
  std::cout << prompt;
  generate(
      model, prompt, tokenizer, sampler, max_input_tokens, max_output_tokens);
}

最后,将以下文件下载到与 main.cpp 相同的目录中

curl -O https://raw.githubusercontent.com/pytorch/executorch/main/examples/llm_manual/basic_sampler.h
curl -O https://raw.githubusercontent.com/pytorch/executorch/main/examples/llm_manual/basic_tokenizer.h

要了解更多信息,请参阅 运行时 API 教程

构建和运行

ExecuTorch 使用 CMake 构建系统。要编译并链接到 ExecuTorch 运行时,请通过 add_directory 包含 ExecuTorch 项目,并链接到 executorch 和其他依赖项。

创建一个名为 CMakeLists.txt 的文件,内容如下

# CMakeLists.txt

cmake_minimum_required(VERSION 3.19)
project(nanogpt_runner)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED True)

# Set options for executorch build.
option(EXECUTORCH_ENABLE_LOGGING "" ON)
option(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER "" ON)
option(EXECUTORCH_BUILD_EXTENSION_MODULE "" ON)
option(EXECUTORCH_BUILD_EXTENSION_TENSOR "" ON)
option(EXECUTORCH_BUILD_KERNELS_OPTIMIZED "" ON)

# Include the executorch subdirectory.
add_subdirectory(
  ${CMAKE_CURRENT_SOURCE_DIR}/third-party/executorch
  ${CMAKE_BINARY_DIR}/executorch
)

add_executable(nanogpt_runner main.cpp)
target_link_libraries(
  nanogpt_runner
  PRIVATE executorch
          extension_module_static # Provides the Module class
          extension_tensor # Provides the TensorPtr class
          optimized_native_cpu_ops_lib # Provides baseline cross-platform
                                       # kernels
)

此时,工作目录应包含以下文件

  • CMakeLists.txt

  • main.cpp

  • basic_tokenizer.h

  • basic_sampler.h

  • export_nanogpt.py

  • model.py

  • vocab.json

  • nanogpt.pte

如果所有这些都存在,您现在可以构建和运行

(rm -rf cmake-out && mkdir cmake-out && cd cmake-out && cmake ..)
cmake --build cmake-out -j10
./cmake-out/nanogpt_runner

您应该会看到以下消息

Enter model prompt:

为模型输入一些种子文本并按 Enter 键。这里我们使用“Hello world!”作为示例提示

Enter model prompt: Hello world!
Hello world!

I'm not sure if you've heard of the "Curse of the Dragon" or not, but it's a very popular game in

此时,它可能会运行非常缓慢。这是因为 ExecuTorch 尚未被告知要针对特定硬件进行优化(委托),并且因为它正在以 32 位浮点数执行所有计算(没有量化)。

委托

虽然 ExecuTorch 为所有运算符提供了可移植的跨平台实现,但它也为许多不同的目标提供了专门的后端。这些包括但不限于:通过 XNNPACK 后端进行 x86 和 ARM CPU 加速,通过 Core ML 后端和 Metal Performance Shader (MPS) 后端进行 Apple 加速,以及通过 Vulkan 后端进行 GPU 加速。

由于优化特定于给定的后端,因此每个 pte 文件都特定于导出目标的后端。为了支持多个设备,例如 Android 的 XNNPACK 加速和 iOS 的 Core ML,请为每个后端导出单独的 PTE 文件。

为了在导出时委托到后端,ExecuTorch 在 EdgeProgramManager 对象中提供了 to_backend() 函数,该函数接受一个特定于后端的划分器对象。划分器负责查找可以由目标后端加速的计算图的部分,并且 to_backend() 函数会将匹配的部分委托给给定的后端进行加速和优化。计算图中未委托的任何部分都将由 ExecuTorch 运算符实现执行。

要将导出的模型委托给特定后端,我们需要首先导入其划分器以及来自 ExecuTorch 代码库的边缘编译配置,然后在 EdgeProgramManager 对象的 to_edge 函数创建的实例上使用划分器的实例调用 to_backend

以下是如何将 nanoGPT 委托给 XNNPACK 的示例(例如,如果您要部署到 Android 手机上)

# export_nanogpt.py

# Load partitioner for Xnnpack backend
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

# Model to be delegated to specific backend should use specific edge compile config
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
from executorch.exir import EdgeCompileConfig, to_edge

import torch
from torch.export import export
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch._export import capture_pre_autograd_graph

from model import GPT

# Load the nanoGPT model.
model = GPT.from_pretrained('gpt2')

# Create example inputs. This is used in the export process to provide
# hints on the expected shape of the model input.
example_inputs = (
        torch.randint(0, 100, (1, model.config.block_size - 1), dtype=torch.long),
    )

# Set up dynamic shape configuration. This allows the sizes of the input tensors
# to differ from the sizes of the tensors in `example_inputs` during runtime, as
# long as they adhere to the rules specified in the dynamic shape configuration.
# Here we set the range of 0th model input's 1st dimension as
# [0, model.config.block_size].
# See https://pytorch.ac.cn/executorch/main/concepts.html#dynamic-shapes
# for details about creating dynamic shapes.
dynamic_shape = (
    {1: torch.export.Dim("token_dim", max=model.config.block_size - 1)},
)

# Trace the model, converting it to a portable intermediate representation.
# The torch.no_grad() call tells PyTorch to exclude training-specific logic.
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
    m = capture_pre_autograd_graph(model, example_inputs, dynamic_shapes=dynamic_shape)
    traced_model = export(m, example_inputs, dynamic_shapes=dynamic_shape)

# Convert the model into a runnable ExecuTorch program.
# To be further lowered to Xnnpack backend, `traced_model` needs xnnpack-specific edge compile config
edge_config = get_xnnpack_edge_compile_config()
edge_manager = to_edge(traced_model, compile_config=edge_config)

# Delegate exported model to Xnnpack backend by invoking `to_backend` function with Xnnpack partitioner.
edge_manager = edge_manager.to_backend(XnnpackPartitioner())
et_program = edge_manager.to_executorch()

# Save the Xnnpack-delegated ExecuTorch program to a file.
with open("nanogpt.pte", "wb") as file:
    file.write(et_program.buffer)


此外,更新 CMakeLists.txt 以构建并将 XNNPACK 后端链接到 ExecuTorch 运行程序。

cmake_minimum_required(VERSION 3.19)
project(nanogpt_runner)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED True)

# Set options for executorch build.
option(EXECUTORCH_ENABLE_LOGGING "" ON)
option(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER "" ON)
option(EXECUTORCH_BUILD_EXTENSION_MODULE "" ON)
option(EXECUTORCH_BUILD_EXTENSION_TENSOR "" ON)
option(EXECUTORCH_BUILD_KERNELS_OPTIMIZED "" ON)
option(EXECUTORCH_BUILD_XNNPACK "" ON) # Build with Xnnpack backend

# Include the executorch subdirectory.
add_subdirectory(
  ${CMAKE_CURRENT_SOURCE_DIR}/third-party/executorch
  ${CMAKE_BINARY_DIR}/executorch
)

add_executable(nanogpt_runner main.cpp)
target_link_libraries(
  nanogpt_runner
  PRIVATE executorch
          extension_module_static # Provides the Module class
          extension_tensor # Provides the TensorPtr class
          optimized_native_cpu_ops_lib # Provides baseline cross-platform
                                       # kernels
          xnnpack_backend # Provides the XNNPACK CPU acceleration backend
)

保持代码的其余部分不变。有关更多详细信息,请参阅 导出到 ExecuTorch调用运行时 以获取更多详细信息

此时,工作目录应包含以下文件

  • CMakeLists.txt

  • main.cpp

  • basic_tokenizer.h

  • basic_sampler.h

  • export_nanogpt.py

  • model.py

  • vocab.json

如果所有这些都存在,您现在可以导出 Xnnpack 委托的 pte 模型

python export_nanogpt.py

它将在相同的工作目录下生成 nanogpt.pte

然后我们可以通过以下方式构建和运行模型:

(rm -rf cmake-out && mkdir cmake-out && cd cmake-out && cmake ..)
cmake --build cmake-out -j10
./cmake-out/nanogpt_runner

您应该会看到以下消息

Enter model prompt:

为模型输入一些种子文本并按 Enter 键。这里我们使用“Hello world!”作为示例提示

Enter model prompt: Hello world!
Hello world!

I'm not sure if you've heard of the "Curse of the Dragon" or not, but it's a very popular game in

与非委托模型相比,委托模型的速度应该明显更快。

有关后端委托的更多信息,请参阅 ExecuTorch 指南,了解 XNNPACK 后端Core ML 后端高通 AI 引擎直接后端

量化

量化指的是一组使用较低精度类型运行计算和存储张量的技术。与 32 位浮点数相比,使用 8 位整数可以显着提高速度并减少内存使用量。量化模型的方法有很多,在所需的预处理量、使用的数据类型以及对模型准确性和性能的影响方面各不相同。

由于移动设备上的计算和内存都非常有限,因此需要某种形式的量化才能在消费电子产品上部署大型模型。特别是,大型语言模型(如 Llama2)可能需要将模型权重量化为 4 位或更低。

利用量化需要在导出之前转换模型。PyTorch 为此目的提供了 pt2e(PyTorch 2 Export)API。此示例针对使用 XNNPACK 委托进行 CPU 加速。因此,它需要使用特定于 XNNPACK 的量化器。针对不同的后端将需要使用相应的量化器。

要使用 XNNPACK 委托的 8 位整数动态量化,请调用 prepare_pt2e,通过使用代表性输入运行来校准模型,然后调用 convert_pt2e。这会更新计算图以在可用时使用量化运算符。

# export_nanogpt.py

from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
    DuplicateDynamicQuantChainPass,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
# Use dynamic, per-channel quantization.
xnnpack_quant_config = get_symmetric_quantization_config(
    is_per_channel=True, is_dynamic=True
)
xnnpack_quantizer = XNNPACKQuantizer()
xnnpack_quantizer.set_global(xnnpack_quant_config)

m = capture_pre_autograd_graph(model, example_inputs)

# Annotate the model for quantization. This prepares the model for calibration.
m = prepare_pt2e(m, xnnpack_quantizer)

# Calibrate the model using representative inputs. This allows the quantization
# logic to determine the expected range of values in each tensor.
m(*example_inputs)

# Perform the actual quantization.
m = convert_pt2e(m, fold_quantize=False)
DuplicateDynamicQuantChainPass()(m)

traced_model = export(m, example_inputs)

此外,添加或更新 to_backend() 调用以使用 XnnpackPartitioner。这指示 ExecuTorch 通过 XNNPACK 后端优化模型以进行 CPU 执行。

from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XnnpackPartitioner,
)
edge_manager = to_edge(traced_model, compile_config=edge_config)
edge_manager = edge_manager.to_backend(XnnpackPartitioner()) # Lower to XNNPACK.
et_program = edge_manager.to_executorch()

最后,确保运行程序在 CMakeLists.txt 中链接到 xnnpack_backend 目标。

add_executable(nanogpt_runner main.cpp)
target_link_libraries(
    nanogpt_runner
    PRIVATE
    executorch
    extension_module_static # Provides the Module class
    optimized_native_cpu_ops_lib # Provides baseline cross-platform kernels
    xnnpack_backend) # Provides the XNNPACK CPU acceleration backend

有关更多信息,请参阅 ExecuTorch 中的量化

性能分析和调试

通过调用 to_backend() 降低模型后,您可能希望查看委托了什么内容以及未委托什么内容。ExecuTorch 提供实用程序方法来深入了解委托情况。您可以使用此信息来了解底层计算并诊断潜在的性能问题。模型作者可以使用此信息以与目标后端兼容的方式构建模型。

可视化委托

get_delegation_info() 方法提供了在 to_backend() 调用后模型发生情况的摘要

from executorch.devtools.backend_debug import get_delegation_info
from tabulate import tabulate

# ... After call to to_backend(), but before to_executorch()
graph_module = edge_manager.exported_program().graph_module
delegation_info = get_delegation_info(graph_module)
print(delegation_info.get_summary())
df = delegation_info.get_operator_delegation_dataframe()
print(tabulate(df, headers="keys", tablefmt="fancy_grid"))

对于针对 XNNPACK 后端的 nanoGPT,您可能会看到以下内容(请注意,以下数字仅用于说明目的,实际值可能会有所不同)

Total  delegated  subgraphs:  145
Number  of  delegated  nodes:  350
Number  of  non-delegated  nodes:  760

op_type

# in_delegated_graphs

# in_non_delegated_graphs

0

aten__softmax_default

12

0

1

aten_add_tensor

37

0

2

aten_addmm_default

48

0

3

aten_any_dim

0

12

25

aten_view_copy_default

96

122

30

Total

350

760

从表中可以看出,运算符 aten_view_copy_default 在委托图中出现了 96 次,在非委托图中出现了 122 次。要查看更详细的视图,请使用 format_delegated_graph() 方法获取整个图的格式化字符串打印输出,或使用 print_delegated_graph() 直接打印

from executorch.exir.backend.utils import format_delegated_graph
graph_module = edge_manager.exported_program().graph_module
print(format_delegated_graph(graph_module))

对于大型模型,这可能会生成大量输出。考虑使用“Control+F”或“Command+F”查找您感兴趣的运算符(例如,“aten_view_copy_default”)。观察哪些实例不在降低的图下。

在下面 nanoGPT 输出片段中,观察到 transformer 模块已委托给 XNNPACK,而 where 运算符则没有。

%aten_where_self_22 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.where.self](args = (%aten_logical_not_default_33, %scalar_tensor_23, %scalar_tensor_22), kwargs = {})
%lowered_module_144 : [num_users=1] = get_attr[target=lowered_module_144]
backend_id: XnnpackBackend
lowered graph():
    %p_transformer_h_0_attn_c_attn_weight : [num_users=1] = placeholder[target=p_transformer_h_0_attn_c_attn_weight]
    %p_transformer_h_0_attn_c_attn_bias : [num_users=1] = placeholder[target=p_transformer_h_0_attn_c_attn_bias]
    %getitem : [num_users=1] = placeholder[target=getitem]
    %sym_size : [num_users=2] = placeholder[target=sym_size]
    %aten_view_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.view_copy.default](args = (%getitem, [%sym_size, 768]), kwargs = {})
    %aten_permute_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.permute_copy.default](args = (%p_transformer_h_0_attn_c_attn_weight, [1, 0]), kwargs = {})
    %aten_addmm_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.addmm.default](args = (%p_transformer_h_0_attn_c_attn_bias, %aten_view_copy_default, %aten_permute_copy_default), kwargs = {})
    %aten_view_copy_default_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.view_copy.default](args = (%aten_addmm_default, [1, %sym_size, 2304]), kwargs = {})
    return [aten_view_copy_default_1]

性能分析

通过 ExecuTorch 开发者工具,用户能够分析模型执行情况,为模型中的每个运算符提供时间信息。

先决条件

ETRecord 生成(可选)

ETRecord 是在导出时生成的工件,其中包含模型图和源级元数据,将 ExecuTorch 程序链接到原始 PyTorch 模型。您可以查看没有 ETRecord 的所有性能分析事件,但使用 ETRecord 时,您还可以将每个事件链接到正在执行的运算符类型、模块层次结构以及原始 PyTorch 源代码的堆栈跟踪。有关更多信息,请参阅 ETRecord 文档

在导出脚本中,在调用 to_edge()to_executorch() 之后,使用来自 to_edge()EdgeProgramManager 和来自 to_executorch()ExecuTorchProgramManager 调用 generate_etrecord()。请确保复制 EdgeProgramManager,因为对 to_backend() 的调用会就地修改图。

# export_nanogpt.py

import copy
from executorch.devtools import generate_etrecord

# Make the deep copy immediately after to to_edge()
edge_manager_copy = copy.deepcopy(edge_manager)

# ...
# Generate ETRecord right after to_executorch()
etrecord_path = "etrecord.bin"
generate_etrecord(etrecord_path, edge_manager_copy, et_program)

运行导出脚本,ETRecord 将生成,名为 etrecord.bin

ETDump 生成

ETDump 是在运行时生成的工件,其中包含模型执行的跟踪。有关更多信息,请参阅 ETDump 文档

在您的代码中包含 ETDump 头文件。

// main.cpp

#include <executorch/devtools/etdump/etdump_flatcc.h>

创建 ETDumpGen 类的实例并将其传递给 Module 构造函数。

std::unique_ptr<ETDumpGen> etdump_gen_ = std::make_unique<ETDumpGen>();
Module model("nanogpt.pte", Module::LoadMode::MmapUseMlockIgnoreErrors, std::move(etdump_gen_));

调用 generate() 后,将 ETDump 保存到文件中。如果需要,您可以在单个跟踪中捕获多个模型运行。

ETDumpGen* etdump_gen = static_cast<ETDumpGen*>(model.event_tracer());

ET_LOG(Info, "ETDump size: %zu blocks", etdump_gen->get_num_blocks());
etdump_result result = etdump_gen->get_etdump_data();
if (result.buf != nullptr && result.size > 0) {
    // On a device with a file system, users can just write it to a file.
    FILE* f = fopen("etdump.etdp", "w+");
    fwrite((uint8_t*)result.buf, 1, result.size, f);
    fclose(f);
    free(result.buf);
}

此外,更新 CMakeLists.txt 以使用开发者工具进行构建,并启用事件跟踪并记录到 ETDump 中

option(EXECUTORCH_ENABLE_EVENT_TRACER "" ON)
option(EXECUTORCH_BUILD_DEVTOOLS "" ON)

# ...

target_link_libraries(
    # ... omit existing ones
    etdump) # Provides event tracing and logging

target_compile_options(executorch PUBLIC -DET_EVENT_TRACER_ENABLED)
target_compile_options(portable_ops_lib PUBLIC -DET_EVENT_TRACER_ENABLED)

构建并运行 runner,您将看到生成一个名为“etdump.etdp”的文件。(请注意,这次我们使用发布模式构建以解决 flatccrt 构建限制。)

(rm -rf cmake-out && mkdir cmake-out && cd cmake-out && cmake -DCMAKE_BUILD_TYPE=Release ..)
cmake --build cmake-out -j10
./cmake-out/nanogpt_runner

使用 Inspector API 进行分析

收集完调试工件 ETDump(以及可选的 ETRecord)后,您可以使用 Inspector API 查看性能信息。

from executorch.devtools import Inspector

inspector = Inspector(etdump_path="etdump.etdp")
# If you also generated an ETRecord, then pass that in as well: `inspector = Inspector(etdump_path="etdump.etdp", etrecord="etrecord.bin")`

with open("inspector_out.txt", "w") as file:
    inspector.print_data_tabular(file)

这会以表格格式将性能数据打印到“inspector_out.txt”中,每一行表示一个性能分析事件。前几行如下所示: 查看完整尺寸

要详细了解 Inspector 及其提供的丰富功能,请参阅 Inspector API 参考

自定义内核

使用 ExecuTorch 自定义算子 API,自定义算子和内核作者可以轻松地将他们的内核引入 PyTorch/ExecuTorch。

在 ExecuTorch 中使用自定义内核有三个步骤

  1. 使用 ExecuTorch 类型编写自定义内核。

  2. 将自定义内核编译并链接到 AOT Python 环境以及运行时二进制文件。

  3. 源到源转换以用自定义操作替换操作符。

编写自定义内核

为函数变体(用于 AOT 编译)和输出变体(用于 ExecuTorch 运行时)定义自定义操作符模式。模式需要遵循 PyTorch ATen 约定(请参阅 native_functions.yaml)。

custom_linear(Tensor weight, Tensor input, Tensor(?) bias) -> Tensor

custom_linear.out(Tensor weight, Tensor input, Tensor(?) bias, *, Tensor(a!) out) -> Tensor(a!)

根据上面定义的模式编写自定义内核。使用 EXECUTORCH_LIBRARY 宏使内核可用于 ExecuTorch 运行时。

// custom_linear.h / custom_linear.cpp
#include <executorch/runtime/kernel/kernel_includes.h>

Tensor& custom_linear_out(const Tensor& weight, const Tensor& input, optional<Tensor> bias, Tensor& out) {
    // calculation
    return out;
}

// Register as myop::custom_linear.out
EXECUTORCH_LIBRARY(myop, "custom_linear.out", custom_linear_out);

要使此操作符在 PyTorch 中可用,您可以在 ExecuTorch 自定义内核周围定义一个包装器。请注意,ExecuTorch 实现使用 ExecuTorch 张量类型,而 PyTorch 包装器使用 ATen 张量。

// custom_linear_pytorch.cpp

#include "custom_linear.h"
#include <torch/library.h>

at::Tensor custom_linear(const at::Tensor& weight, const at::Tensor& input, std::optional<at::Tensor> bias) {

    // initialize out
    at::Tensor out = at::empty({weight.size(1), input.size(1)});

    // wrap kernel in custom_linear.cpp into ATen kernel
    WRAP_TO_ATEN(custom_linear_out, 3)(weight, input, bias, out);

    return out;
}

// Register the operator with PyTorch.
TORCH_LIBRARY(myop,  m) {
    m.def("custom_linear(Tensor weight, Tensor input, Tensor(?) bias) -> Tensor", custom_linear);
    m.def("custom_linear.out(Tensor weight, Tensor input, Tensor(?) bias, *, Tensor(a!) out) -> Tensor(a!)", WRAP_TO_ATEN(custom_linear_out, 3));
}

在模型中使用自定义操作符

可以在 PyTorch 模型中显式使用自定义操作符,或者您可以编写转换以用自定义变体替换核心操作符的实例。对于此示例,您可以查找 torch.nn.Linear 的所有实例,并将其替换为 CustomLinear

def  replace_linear_with_custom_linear(module):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(
                module,
                name,
                CustomLinear(child.in_features,  child.out_features, child.bias),
        )
        else:
            replace_linear_with_custom_linear(child)

其余步骤与正常流程相同。现在,您可以在渴望模式下运行此模块,并将其导出到 ExecuTorch。

如何构建移动应用程序

请参阅有关在 iOS 和 Android 上使用 ExecuTorch 构建和运行 LLMs 的说明。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源