• 文档 >
  • 通过 ExecuTorch 开始使用 LLM
快捷方式

通过 ExecuTorch 开始使用 LLM

欢迎使用 LLM 手册!本手册旨在提供一个实用示例,以利用 ExecuTorch 将您自己的大型语言模型 (LLM) 集成到系统中。我们的主要目标是提供一个清晰简洁的指南,说明如何将我们的系统与您自己的 LLM 集成。

请注意,此项目旨在作为演示,而不是作为具有最佳性能的完全功能示例。因此,某些组件(如采样器、标记化器等)以其最精简的版本提供,仅用于演示目的。因此,模型产生的结果可能会有所不同,并且可能并不总是最佳的。

我们鼓励用户将此项目用作起点,并根据其特定需求进行调整,其中包括创建自己的标记化器、采样器、加速后端和其他组件版本。我们希望此项目能成为您使用 LLM 和 ExecuTorch 旅程中的有用指南。

目录

  1. 先决条件

  2. Hello World 示例

  3. 量化

  4. 使用移动加速

  5. 调试和分析

  6. 如何使用自定义内核

  7. 如何构建移动应用程序

先决条件

要遵循本指南,您需要克隆 ExecuTorch 存储库并安装依赖项。ExecuTorch 推荐使用 Python 3.10 和 Conda 来管理您的环境。不需要 Conda,但请注意,您可能需要根据您的环境将使用 python/pip 替换为 python3/pip3。

可以在此处找到有关安装 miniconda 的说明。

# Create a directory for this example.
mkdir et-nanogpt
cd et-nanogpt

# Clone the ExecuTorch repository and submodules.
mkdir third-party
git clone -b release/0.2 https://github.com/pytorch/executorch.git third-party/executorch
cd third-party/executorch
git submodule update --init

# Create a conda environment and install requirements.
conda create -yn executorch python=3.10.0
conda activate executorch
pip install cmake zstd
./install_requirements.sh

cd ../..

可以在此处找到有关安装 pyenv-virtualenv 的说明。

重要的是,如果通过 brew 安装 pyenv,它不会自动在终端中启用 pyenv,从而导致错误。运行以下命令以启用。请参阅上述 pyenv-virtualenv 安装指南,了解如何将其添加到您的 .bashrc 或 .zshrc 中,以避免需要手动运行这些命令。

eval "$(pyenv init -)"
eval "$(pyenv virtualenv-init -)"
# Create a directory for this example.
mkdir et-nanogpt
cd et-nanogpt

pyenv install -s 3.10
pyenv virtualenv 3.10 executorch
pyenv activate executorch

# Clone the ExecuTorch repository and submodules.
mkdir third-party
git clone -b release/0.2 https://github.com/pytorch/executorch.git third-party/executorch
cd third-party/executorch
git submodule update --init

# Install requirements.
pip install cmake zstd
PYTHON_EXECUTABLE=python ./install_requirements.sh

cd ../..

有关更多信息,请参阅设置 ExecuTorch

在本地运行大型语言模型

此示例使用 Karpathy 的nanoGPT,它是 GPT-2 124M 的最小实现。本指南适用于其他语言模型,因为 ExecuTorch 与模型无关。

使用 ExecuTorch 运行模型有两个步骤

  1. 导出模型。此步骤将其预处理为适合运行时执行的格式。

  2. 在运行时,加载模型文件并使用 ExecuTorch 运行时运行。


导出步骤在之前进行,通常作为应用程序构建的一部分或在模型更改时进行。生成的 .pte 文件与应用程序一起分发。在运行时,应用程序加载 .pte 文件并将其传递给 ExecuTorch 运行时。

步骤 1. 导出到 ExecuTorch

导出需要一个 PyTorch 模型,并将其转换为可以在消费设备上高效运行的格式。

对于此示例,您需要 nanoGPT 模型和相应的标记器词汇表。

curl https://raw.githubusercontent.com/karpathy/nanoGPT/master/model.py -O
curl https://huggingface.co/openai-community/gpt2/resolve/main/vocab.json -O
wget https://raw.githubusercontent.com/karpathy/nanoGPT/master/model.py
wget https://huggingface.co/openai-community/gpt2/resolve/main/vocab.json

要将模型转换为针对独立执行进行优化的格式,有两个步骤。首先,使用 PyTorch export 函数将 PyTorch 模型转换为中间的、与平台无关的中间表示。然后使用 ExecuTorch to_edgeto_executorch 方法准备模型以便在设备上执行。这会创建一个 .pte 文件,该文件可以在运行时由桌面或移动应用程序加载。

创建一个名为 export_nanogpt.py 的文件,其中包含以下内容

# export_nanogpt.py

import torch

from executorch.exir import EdgeCompileConfig, to_edge
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch._export import capture_pre_autograd_graph
from torch.export import export

from model import GPT

# Load the model.
model = GPT.from_pretrained('gpt2')

# Create example inputs. This is used in the export process to provide
# hints on the expected shape of the model input.
example_inputs = (torch.randint(0, 100, (1, model.config.block_size), dtype=torch.long), )

# Set up dynamic shape configuration. This allows the sizes of the input tensors
# to differ from the sizes of the tensors in `example_inputs` during runtime, as
# long as they adhere to the rules specified in the dynamic shape configuration.
# Here we set the range of 0th model input's 1st dimension as
# [0, model.config.block_size].
# See https://pytorch.ac.cn/executorch/main/concepts.html#dynamic-shapes
# for details about creating dynamic shapes.
dynamic_shape = (
    {1: torch.export.Dim("token_dim", max=model.config.block_size)},
)

# Trace the model, converting it to a portable intermediate representation.
# The torch.no_grad() call tells PyTorch to exclude training-specific logic.
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
    m = capture_pre_autograd_graph(model, example_inputs, dynamic_shapes=dynamic_shape)
    traced_model = export(m, example_inputs, dynamic_shapes=dynamic_shape)

# Convert the model into a runnable ExecuTorch program.
edge_config = EdgeCompileConfig(_check_ir_validity=False)
edge_manager = to_edge(traced_model,  compile_config=edge_config)
et_program = edge_manager.to_executorch()

# Save the ExecuTorch program to a file.
with open("nanogpt.pte", "wb") as file:
    file.write(et_program.buffer)

要导出,请使用 python export_nanogpt.py(或 python3,具体取决于您的环境)运行脚本。它将在当前目录中生成一个 nanogpt.pte 文件。

有关更多信息,请参阅 导出到 ExecuTorchtorch.export

步骤 2. 调用运行时

ExecuTorch 提供了一组运行时 API 和类型来加载和运行模型。

创建一个名为 main.cpp 的文件,其中包含以下内容

// main.cpp

#include <cstdint>
#include <functional>
#include <memory>
#include <unordered_map>

#include "basic_tokenizer.h"
#include "basic_sampler.h"
#include "managed_tensor.h"

#include <executorch/extension/module/module.h>
#include <executorch/extension/evalue_util/print_evalue.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>

using namespace torch::executor;

using SizesType = exec_aten::SizesType;
using DimOrderType = exec_aten::DimOrderType;
using StridesType = exec_aten::StridesType;

模型输入和输出采用张量的形式。可以将张量视为多维数组。ExecuTorch EValue 类提供了一个包装器,用于包装张量和其他 ExecuTorch 数据类型。

由于 LLM 一次生成一个标记,因此驱动程序代码需要反复调用模型,逐个标记构建输出标记。每个生成的标记作为下一个运行的输入传递。

// main.cpp

// The value of the gpt2 `<|endoftext|>` token.
#define ENDOFTEXT_TOKEN 50256

std::string generate(
    Module& llm_model,
    std::string& prompt,
    BasicTokenizer& tokenizer,
    BasicSampler& sampler,
    size_t max_input_length,
    size_t max_output_length) {

    // Convert the input text into a list of integers (tokens) that represents
    // it, using the string-to-token mapping that the model was trained on.
    // Each token is an integer that represents a word or part of a word.
    std::vector<int64_t> input_tokens = tokenizer.encode(prompt);
    std::vector<int64_t> output_tokens;

    for (auto i = 0u; i < max_output_length; i++) {
        // Convert the input_tokens from a vector of int64_t to EValue.
        // EValue is a unified data type in the ExecuTorch runtime.
        ManagedTensor tensor_tokens(
            input_tokens.data(),
            {1, static_cast<int>(input_tokens.size())},
            ScalarType::Long);
        std::vector<EValue> inputs = {tensor_tokens.get_tensor()};

        // Run the model. It will return a tensor of logits (log-probabilities).
        Result<std::vector<EValue>> logits_evalue = llm_model.forward(inputs);

        // Convert the output logits from EValue to std::vector, which is what
        // the sampler expects.
        Tensor logits_tensor = logits_evalue.get()[0].toTensor();
        std::vector<float> logits(logits_tensor.data_ptr<float>(),
            logits_tensor.data_ptr<float>() + logits_tensor.numel());

        // Sample the next token from the logits.
        int64_t next_token = sampler.sample(logits);

        // Break if we reached the end of the text.
        if (next_token == ENDOFTEXT_TOKEN) {
            break;
        }

        // Add the next token to the output.
        output_tokens.push_back(next_token);

        std::cout << tokenizer.decode({ next_token });
        std::cout.flush();

        // Update next input.
        input_tokens.push_back(next_token);
        if (input_tokens.size() > max_input_length) {
            input_tokens.erase(input_tokens.begin());
        }
    }

    std::cout << std::endl;

    // Convert the output tokens into a human-readable string.
    std::string output_string = tokenizer.decode(output_tokens);
    return output_string;
}

Module 类负责加载 .pte 文件并准备执行。

标记器负责将提示的人类可读字符串表示形式转换为模型期望的数字形式。为此,标记器将短子字符串与给定的标记 ID 关联起来。标记可以被认为表示单词或单词的一部分,尽管在实践中,它们可以是任意字符序列。

标记器从文件中加载词汇表,其中包含每个标记 ID 与其表示的文本之间的映射。调用 tokenizer.encode()tokenizer.decode() 在字符串和标记表示之间进行转换。

采样器负责根据模型输出的 logit 或对数概率选择下一个标记。LLM 为每个可能的下一个标记返回一个 logit 值。采样器根据某些策略选择要使用的标记。这里使用的最简单方法是取 logit 值最高的标记。

采样器可以提供可配置选项,例如可配置的输出选择随机性、重复标记的惩罚以及优先或取消优先特定标记的偏差。

// main.cpp

int main() {
    // Set up the prompt. This provides the seed text for the model to elaborate.
    std::cout << "Enter model prompt: ";
    std::string prompt;
    std::getline(std::cin, prompt);

    // The tokenizer is used to convert between tokens (used by the model) and
    // human-readable strings.
    BasicTokenizer tokenizer("vocab.json");

    // The sampler is used to sample the next token from the logits.
    BasicSampler sampler = BasicSampler();

    // Load the exported nanoGPT program, which was generated via the previous steps.
    Module model("nanogpt.pte", torch::executor::Module::MlockConfig::UseMlockIgnoreErrors);

    const auto max_input_tokens = 1024;
    const auto max_output_tokens = 30;
    std::cout << prompt;
    generate(model, prompt, tokenizer, sampler, max_input_tokens, max_output_tokens);
}

最后,将以下文件下载到与 main.h 相同的目录中

curl -O https://raw.githubusercontent.com/pytorch/executorch/main/examples/llm_manual/basic_sampler.h
curl -O https://raw.githubusercontent.com/pytorch/executorch/main/examples/llm_manual/basic_tokenizer.h
curl -O https://raw.githubusercontent.com/pytorch/executorch/main/examples/llm_manual/managed_tensor.h

要了解更多信息,请参阅 在 C++ 中运行 ExecuTorch 模型ExecuTorch 运行时 API 参考

构建和运行

ExecuTorch 使用 CMake 构建系统。要编译并链接到 ExecuTorch 运行时,请通过 add_directory 包含 ExecuTorch 项目,并链接到 executorch 和其他依赖项。

创建一个名为 CMakeLists.txt 的文件,其内容如下

# CMakeLists.txt

cmake_minimum_required(VERSION 3.19)
project(nanogpt_runner)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED True)

# Set options for executorch build.
option(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER "" ON)
option(EXECUTORCH_BUILD_EXTENSION_MODULE "" ON)
option(EXECUTORCH_BUILD_OPTIMIZED "" ON)

# Include the executorch subdirectory.
add_subdirectory(
    ${CMAKE_CURRENT_SOURCE_DIR}/third-party/executorch
    ${CMAKE_BINARY_DIR}/third-party/executorch)

add_executable(nanogpt_runner main.cpp)
target_link_libraries(
    nanogpt_runner
    PRIVATE
    executorch
    extension_module_static # Provides the Module class
    optimized_native_cpu_ops_lib) # Provides baseline cross-platform kernels

此时,工作目录应包含以下文件

  • CMakeLists.txt

  • main.cpp

  • basic_tokenizer.h

  • basic_sampler.h

  • managed_tensor.h

  • export_nanogpt.py

  • model.py

  • vocab.json

  • nanogpt.pte

如果所有这些都存在,你现在可以构建并运行

(rm -rf cmake-out && mkdir cmake-out && cd cmake-out && cmake ..)
cmake --build cmake-out -j10
./cmake-out/nanogpt_runner

你应该看到消息

Enter model prompt:

为模型输入一些种子文本并按回车键。这里我们使用“Hello world!”作为示例提示

Enter model prompt: Hello world!
Hello world!

I'm not sure if you've heard of the "Curse of the Dragon" or not, but it's a very popular game in

此时,它可能运行得非常慢。这是因为 ExecuTorch 尚未被告知针对特定硬件进行优化(委托),并且因为它正在以 32 位浮点数(无量化)进行所有计算。

委托

虽然 ExecuTorch 为所有运算符提供可移植的跨平台实现,但它还为许多不同的目标提供专门的后端。其中包括但不限于,通过 XNNPACK 后端进行 x86 和 ARM CPU 加速,通过 Core ML 后端和 Metal Performance Shader (MPS) 后端进行 Apple 加速,以及通过 Vulkan 后端进行 GPU 加速。

由于优化特定于给定的后端,因此每个 pte 文件特定于目标导出后端。要支持多个设备(例如,适用于 Android 的 XNNPACK 加速和适用于 iOS 的 Core ML),请为每个后端导出一个单独的 PTE 文件。

要在导出时委托给后端,ExecuTorch 在 EdgeProgramManager 对象中提供了 to_backend() 函数,该函数采用特定于后端的分配器对象。分配器负责查找可以由目标后端加速的计算图部分,而 to_backend() 函数将委托匹配的部分给定的后端以进行加速和优化。计算图的任何未委托部分都将由 ExecuTorch 运算符实现执行。

要将导出的模型委托给特定后端,我们需要首先从 ExecuTorch 代码库导入其分配器和边缘编译配置,然后在 EdgeProgramManager 对象 to_edge 函数上使用分配器实例调用 to_backend

以下是如何将 nanoGPT 委托给 XNNPACK 的示例(例如,如果你要部署到 Android 手机)

# export_nanogpt.py

# Load partitioner for Xnnpack backend
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

# Model to be delegated to specific backend should use specific edge compile config
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
from executorch.exir import EdgeCompileConfig, to_edge

import torch
from torch.export import export
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch._export import capture_pre_autograd_graph

from model import GPT

# Load the nanoGPT model.
model = GPT.from_pretrained('gpt2')

# Create example inputs. This is used in the export process to provide
# hints on the expected shape of the model input.
example_inputs = (
        torch.randint(0, 100, (1, model.config.block_size - 1), dtype=torch.long),
    )

# Set up dynamic shape configuration. This allows the sizes of the input tensors
# to differ from the sizes of the tensors in `example_inputs` during runtime, as
# long as they adhere to the rules specified in the dynamic shape configuration.
# Here we set the range of 0th model input's 1st dimension as
# [0, model.config.block_size].
# See https://pytorch.ac.cn/executorch/main/concepts.html#dynamic-shapes
# for details about creating dynamic shapes.
dynamic_shape = (
    {1: torch.export.Dim("token_dim", max=model.config.block_size - 1)},
)

# Trace the model, converting it to a portable intermediate representation.
# The torch.no_grad() call tells PyTorch to exclude training-specific logic.
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
    m = capture_pre_autograd_graph(model, example_inputs, dynamic_shapes=dynamic_shape)
    traced_model = export(m, example_inputs, dynamic_shapes=dynamic_shape)

# Convert the model into a runnable ExecuTorch program.
# To be further lowered to Xnnpack backend, `traced_model` needs xnnpack-specific edge compile config
edge_config = get_xnnpack_edge_compile_config()
edge_manager = to_edge(traced_model, compile_config=edge_config)

# Delegate exported model to Xnnpack backend by invoking `to_backend` function with Xnnpack partitioner.
edge_manager = edge_manager.to_backend(XnnpackPartitioner())
et_program = edge_manager.to_executorch()

# Save the Xnnpack-delegated ExecuTorch program to a file.
with open("nanogpt.pte", "wb") as file:
    file.write(et_program.buffer)


此外,更新 CMakeLists.txt 以构建和链接 XNNPACK 后端到 ExecuTorch 运行器。

cmake_minimum_required(VERSION 3.19)
project(nanogpt_runner)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED True)

# Set options for executorch build.
option(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER "" ON)
option(EXECUTORCH_BUILD_EXTENSION_MODULE "" ON)
option(EXECUTORCH_BUILD_OPTIMIZED "" ON)
option(EXECUTORCH_BUILD_XNNPACK "" ON) # Build with Xnnpack backend

# Include the executorch subdirectory.
add_subdirectory(
    ${CMAKE_CURRENT_SOURCE_DIR}/third-party/executorch
    ${CMAKE_BINARY_DIR}/executorch)

# include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src)

add_executable(nanogpt_runner main.cpp)
target_link_libraries(
    nanogpt_runner
    PRIVATE
    executorch
    extension_module_static # Provides the Module class
    optimized_native_cpu_ops_lib # Provides baseline cross-platform kernels
    xnnpack_backend) # Provides the XNNPACK CPU acceleration backend

保持其余代码不变。有关更多详细信息,请参阅 导出到 ExecuTorch调用运行时 以获取更多详细信息

此时,工作目录应包含以下文件

  • CMakeLists.txt

  • main.cpp

  • basic_tokenizer.h

  • basic_sampler.h

  • managed_tensor.h

  • export_nanogpt.py

  • model.py

  • vocab.json

如果所有这些都存在,你现在可以导出 Xnnpack 委托的 pte 模型

python export_nanogpt.py

它将在同一工作目录下生成 nanogpt.pte

然后,我们可以通过以下方式构建和运行模型

(rm -rf cmake-out && mkdir cmake-out && cd cmake-out && cmake ..)
cmake --build cmake-out -j10
./cmake-out/nanogpt_runner

你应该看到消息

Enter model prompt:

为模型输入一些种子文本并按回车键。这里我们使用“Hello world!”作为示例提示

Enter model prompt: Hello world!
Hello world!

I'm not sure if you've heard of the "Curse of the Dragon" or not, but it's a very popular game in

与非委托模型相比,委托模型应该明显更快。

有关后端委托的更多信息,请参阅 ExecuTorch 指南,了解 XNNPACK 后端Core ML 后端

量化

量化是指使用较低精度类型进行计算和存储张量的技术集合。与 32 位浮点数相比,使用 8 位整数可以显着提高速度并减少内存使用量。有许多量化模型的方法,它们在所需的预处理量、使用的数据类型以及对模型准确性和性能的影响方面各不相同。

由于移动设备上的计算和内存受到高度限制,因此在消费电子产品上发布大型模型需要某种形式的量化。特别是,大型语言模型(例如 Llama2)可能需要将模型权重量化为 4 位或更少。

利用量化需要在导出之前转换模型。PyTorch 提供了 pt2e(PyTorch 2 Export)API 以用于此目的。此示例使用 XNNPACK 委托针对 CPU 加速。因此,它需要使用特定于 XNNPACK 的量化器。针对不同的后端需要使用相应的量化器。

要使用 XNNPACK 委托进行 8 位整数动态量化,请调用 prepare_pt2e,通过使用代表性输入运行来校准模型,然后调用 convert_pt2e。这会更新计算图以在可用时使用量化算子。

# export_nanogpt.py

from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
    DuplicateDynamicQuantChainPass,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
# Use dynamic, per-channel quantization.
xnnpack_quant_config = get_symmetric_quantization_config(
    is_per_channel=True, is_dynamic=True
)
xnnpack_quantizer = XNNPACKQuantizer()
xnnpack_quantizer.set_global(xnnpack_quant_config)

m = capture_pre_autograd_graph(model, example_inputs)

# Annotate the model for quantization. This prepares the model for calibration.
m = prepare_pt2e(m, xnnpack_quantizer)

# Calibrate the model using representative inputs. This allows the quantization
# logic to determine the expected range of values in each tensor.
m(*example_inputs)

# Perform the actual quantization.
m = convert_pt2e(m, fold_quantize=False)
DuplicateDynamicQuantChainPass()(m)

traced_model = export(m, example_inputs)

此外,添加或更新 to_backend() 调用以使用 XnnpackPartitioner。这指示 ExecuTorch 通过 XNNPACK 后端优化模型以进行 CPU 执行。

from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XnnpackPartitioner,
)
edge_manager = to_edge(traced_model, compile_config=edge_config)
edge_manager = edge_manager.to_backend(XnnpackPartitioner()) # Lower to XNNPACK.
et_program = edge_manager.to_executorch()

最后,确保运行器在 CMakeLists.txt 中链接到 xnnpack_backend 目标。

add_executable(nanogpt_runner main.cpp)
target_link_libraries(
    nanogpt_runner
    PRIVATE
    executorch
    extension_module_static # Provides the Module class
    optimized_native_cpu_ops_lib # Provides baseline cross-platform kernels
    xnnpack_backend) # Provides the XNNPACK CPU acceleration backend

有关更多信息,请参阅 ExecuTorch 中的量化

分析和调试

通过调用 to_backend() 降低模型后,你可能希望了解哪些内容已委托,哪些内容未委托。ExecuTorch 提供实用方法来深入了解委托。你可以使用此信息来深入了解底层计算并诊断潜在的性能问题。模型作者可以使用此信息以与目标后端兼容的方式构建模型。

可视化委托

get_delegation_info() 方法提供 to_backend() 调用后模型发生情况的摘要

from executorch.exir.backend.utils import get_delegation_info
from tabulate import tabulate

# ... After call to to_backend(), but before to_executorch()
graph_module = edge_manager.exported_program().graph_module
delegation_info = get_delegation_info(graph_module)
print(delegation_info.get_summary())
df = delegation_info.get_operator_delegation_dataframe()
print(tabulate(df, headers="keys", tablefmt="fancy_grid"))

对于针对 XNNPACK 后端的 nanoGPT,你可能会看到以下内容

Total  delegated  subgraphs:  86
Number  of  delegated  nodes:  473
Number  of  non-delegated  nodes:  430

op_type

# in_delegated_graphs

# in_non_delegated_graphs

0

aten__softmax_default

12

0

1

aten_add_tensor

37

0

2

aten_addmm_default

48

0

3

aten_arange_start_step

0

25

23

aten_view_copy_default

170

48

26

Total

473

430

从表格中,运算符 aten_view_copy_default 在委托图中出现 170 次,在非委托图中出现 48 次。要查看更详细的视图,请使用 print_delegated_graph() 方法显示整个图的打印输出。

from executorch.exir.backend.utils import print_delegated_graph
graph_module = edge_manager.exported_program().graph_module
print(print_delegated_graph(graph_module))

对于大型模型,这可能会生成大量的输出。考虑使用“Control+F”或“Command+F”来找到你感兴趣的运算符(例如“aten_view_copy_default”)。观察哪些实例不在降低的图中。

在下面 nanoGPT 的输出片段中,观察到嵌入和加法运算符已委托给 XNNPACK,而减法运算符则没有。

%aten_unsqueeze_copy_default_22 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.unsqueeze_copy.default](args = (%aten_arange_start_step_23, -2), kwargs = {})
  %aten_unsqueeze_copy_default_23 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.unsqueeze_copy.default](args = (%aten_arange_start_step_24, -1), kwargs = {})
  %lowered_module_0 : [num_users=1] = get_attr[target=lowered_module_0]
    backend_id: XnnpackBackend
    lowered graph():
      %aten_embedding_default : [num_users=1] = placeholder[target=aten_embedding_default]
      %aten_embedding_default_1 : [num_users=1] = placeholder[target=aten_embedding_default_1]
      %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%aten_embedding_default, %aten_embedding_default_1), kwargs = {})
      return (aten_add_tensor,)
  %executorch_call_delegate : [num_users=1] = call_function[target=torch.ops.higher_order.executorch_call_delegate](args = (%lowered_module_0, %aten_embedding_default, %aten_embedding_default_1), kwargs = {})
  %aten_sub_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.sub.Tensor](args = (%aten_unsqueeze_copy_default, %aten_unsqueeze_copy_default_1), kwargs = {})

性能分析

通过 ExecuTorch SDK,用户能够分析模型执行,并提供模型中每个运算符的时间信息。

先决条件

ETRecord 生成(可选)

ETRecord 是在导出时生成的工件,其中包含模型图和源级元数据,将 ExecuTorch 程序链接到原始 PyTorch 模型。即使没有 ETRecord,您也可以查看所有分析事件,但有了 ETRecord,您还可以将每个事件链接到正在执行的操作员类型、模块层次结构和原始 PyTorch 源代码的堆栈跟踪。有关更多信息,请参阅ETRecord 文档

在您的导出脚本中,在调用to_edge()to_executorch()后,使用to_edge()中的EdgeProgramManagerto_executorch()中的ExecuTorchProgramManager调用generate_etrecord()。确保复制EdgeProgramManager,因为对to_backend()的调用会就地改变图。

import copy
from executorch.sdk import generate_etrecord

# Make the deep copy immediately after to to_edge()
edge_manager_copy = copy.deepcopy(edge_manager)

# ...
# Generate ETRecord right after to_executorch()
etrecord_path = "etrecord.bin"
generate_etrecord(etrecord_path, edge_manager_copy, et_program)

运行导出脚本,ETRecord 将作为etrecord.bin生成。

ETDump 生成

ETDump 是在运行时生成的工件,其中包含模型执行的跟踪。有关更多信息,请参阅ETDump 文档

在您的代码中包含 ETDump 头文件。

// main.cpp

#include <executorch/sdk/etdump/etdump_flatcc.h>

创建 ETDumpGen 类的实例并将其传递给 Module 构造函数。

std::unique_ptr<torch::executor::ETDumpGen> etdump_gen_ = std::make_unique<torch::executor::ETDumpGen>();
Module model("nanogpt.pte", torch::executor::Module::MlockConfig::UseMlockIgnoreErrors, std::move(etdump_gen_));

在调用generate()后,将 ETDump 保存到文件中。如果需要,您可以在单个跟踪中捕获多个模型运行。

torch::executor::ETDumpGen* etdump_gen =
    static_cast<torch::executor::ETDumpGen*>(model.event_tracer());

ET_LOG(Info, "ETDump size: %zu blocks", etdump_gen->get_num_blocks());
etdump_result result = etdump_gen->get_etdump_data();
if (result.buf != nullptr && result.size > 0) {
    // On a device with a file system, users can just write it to a file.
    FILE* f = fopen("etdump.etdp", "w+");
    fwrite((uint8_t*)result.buf, 1, result.size, f);
    fclose(f);
    free(result.buf);
}

此外,更新 CMakeLists.txt 以使用 SDK 构建并启用将事件跟踪并记录到 ETDump 中

option(EXECUTORCH_BUILD_SDK "" ON)

# ...

target_link_libraries(
    nanogpt_runner
    PRIVATE
    executorch
    extension_module_static # Provides the Module class
    optimized_native_cpu_ops_lib # Provides baseline cross-platform kernels
    xnnpack_backend # Provides the XNNPACK CPU acceleration backend
    etdump) # Provides event tracing and logging

target_compile_options(executorch PUBLIC -DET_EVENT_TRACER_ENABLED)
target_compile_options(portable_ops_lib PUBLIC -DET_EVENT_TRACER_ENABLED)

运行运行程序,您将看到生成的“etdump.etdp”。

使用 Inspector API 分析

收集调试工件 ETDump(以及可选的 ETRecord)后,您可以使用 Inspector API 查看性能信息。

from executorch.sdk import Inspector

inspector = Inspector(etdump_path="etdump.etdp")
# If you also generated an ETRecord, then pass that in as well: `inspector = Inspector(etdump_path="etdump.etdp", etrecord="etrecord.bin")`

with open("inspector_out.txt", "w") as file:
    inspector.print_data_tabular(file)

这会以表格格式在“inspector_out.txt”中打印性能数据,其中每一行都是一个分析事件。前几行如下所示: 以全尺寸查看

如需详细了解 Inspector 及其提供的丰富功能,请参阅 Inspector API 参考

自定义内核

借助 ExecuTorch 自定义算子 API,自定义算子和内核的作者可以轻松地将自己的内核引入 PyTorch/ExecuTorch。

在 ExecuTorch 中使用自定义内核有三个步骤

  1. 使用 ExecuTorch 类型编写自定义内核。

  2. 将自定义内核编译并链接到 AOT Python 环境和运行时二进制文件。

  3. 源到源转换,用自定义 op 替换算子。

编写自定义内核

为函数变体(用于 AOT 编译)和 out 变体(用于 ExecuTorch 运行时)定义自定义算子架构。该架构需要遵循 PyTorch ATen 约定(请参阅 native_functions.yaml)。

custom_linear(Tensor weight, Tensor input, Tensor(?) bias) -> Tensor

custom_linear.out(Tensor weight, Tensor input, Tensor(?) bias, *, Tensor(a!) out) -> Tensor(a!)

根据上面定义的架构编写自定义内核。使用 EXECUTORCH_LIBRARY 宏使内核可用于 ExecuTorch 运行时。

// custom_linear.h / custom_linear.cpp
#include <executorch/runtime/kernel/kernel_includes.h>

Tensor& custom_linear_out(const Tensor& weight, const Tensor& input, optional<Tensor> bias, Tensor& out) {
    // calculation
    return out;
}

// Register as myop::custom_linear.out
EXECUTORCH_LIBRARY(myop, "custom_linear.out", custom_linear_out);

要使此算子在 PyTorch 中可用,您可以在 ExecuTorch 自定义内核周围定义一个包装器。请注意,ExecuTorch 实现使用 ExecuTorch 张量类型,而 PyTorch 包装器使用 ATen 张量。

// custom_linear_pytorch.cpp

#include "custom_linear.h"
#include <torch/library.h>

at::Tensor custom_linear(const at::Tensor& weight, const at::Tensor& input, std::optional<at::Tensor> bias) {

    // initialize out
    at::Tensor out = at::empty({weight.size(1), input.size(1)});

    // wrap kernel in custom_linear.cpp into ATen kernel
    WRAP_TO_ATEN(custom_linear_out, 3)(weight, input, bias, out);

    return out;
}

// Register the operator with PyTorch.
TORCH_LIBRARY(myop,  m) {
    m.def("custom_linear(Tensor weight, Tensor input, Tensor(?) bias) -> Tensor", custom_linear);
    m.def("custom_linear.out(Tensor weight, Tensor input, Tensor(?) bias, *, Tensor(a!) out) -> Tensor(a!)", WRAP_TO_ATEN(custom_linear_out, 3));
}

在模型中使用自定义算子

可以在 PyTorch 模型中显式使用自定义算子,或者编写转换以使用自定义变体替换核心算子的实例。对于此示例,您可以查找 torch.nn.Linear 的所有实例,并将它们替换为 CustomLinear

def  replace_linear_with_custom_linear(module):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(
                module,
                name,
                CustomLinear(child.in_features,  child.out_features, child.bias),
        )
        else:
            replace_linear_with_custom_linear(child)

其余步骤与正常流程相同。现在,您可以在急切模式下运行此模块,也可以导出到 ExecuTorch。

如何构建移动应用程序

请参阅有关在 iOS 和 Android 上使用 ExecuTorch 构建和运行 LLM 的说明。

文档

访问 PyTorch 的综合开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源