ExecuTorch 中的 LLMs 介绍¶
欢迎阅读 LLM 手册!本手册旨在提供一个实用示例,展示如何利用 ExecuTorch 来加载您自己的大型语言模型 (LLMs)。我们的主要目标是提供清晰简洁的指南,说明如何将我们的系统与您自己的 LLMs 集成。
请注意,本项目旨在作为演示,而非具有最佳性能的完整功能示例。因此,诸如采样器、分词器等某些组件仅以最简版本提供,纯粹用于演示目的。因此,模型产生的结果可能会有所不同,并且并非总是最优的。
我们鼓励用户将本项目作为起点,并根据其特定需求进行调整,包括创建您自己的分词器、采样器、加速后端及其他组件版本。我们希望本项目能为您的 LLMs 和 ExecuTorch 之旅提供有益指导。
要以最佳性能部署 Llama,请参阅Llama 指南。
目录¶
前提条件
Hello World 示例
量化
使用移动端加速
调试和性能分析
如何使用自定义内核
如何构建移动应用
前提条件¶
要遵循本指南,您需要克隆 ExecuTorch 仓库并安装依赖项。ExecuTorch 推荐使用 Python 3.10 并使用 Conda 管理您的环境。虽然不强制要求使用 Conda,但请注意,根据您的环境,您可能需要将 python/pip 替换为 python3/pip3。
可以在此处找到安装 miniconda 的说明。
# Create a directory for this example.
mkdir et-nanogpt
cd et-nanogpt
# Clone the ExecuTorch repository.
mkdir third-party
git clone -b release/0.6 https://github.com/pytorch/executorch.git third-party/executorch && cd third-party/executorch
# Create either a Python virtual environment:
python3 -m venv .venv && source .venv/bin/activate && pip install --upgrade pip
# Or a Conda environment:
conda create -yn executorch python=3.10.0 && conda activate executorch
# Install requirements
./install_executorch.sh
cd ../..
可以在此处找到安装 pyenv-virtualenv 的说明。
重要的是,如果通过 brew 安装 pyenv,它不会自动在终端中启用 pyenv,从而导致错误。运行以下命令启用。请参阅上面的 pyenv-virtualenv 安装指南,了解如何将其添加到您的 .bashrc 或 .zshrc 中,以避免手动运行这些命令。
eval "$(pyenv init -)"
eval "$(pyenv virtualenv-init -)"
# Create a directory for this example.
mkdir et-nanogpt
cd et-nanogpt
pyenv install -s 3.10
pyenv virtualenv 3.10 executorch
pyenv activate executorch
# Clone the ExecuTorch repository.
git clone -b release/0.6 https://github.com/pytorch/executorch.git third-party/executorch && cd third-party/executorch
# Install requirements.
PYTHON_EXECUTABLE=python ./install_executorch.sh
cd ../..
更多信息请参阅设置 ExecuTorch。
在本地运行大型语言模型¶
本示例使用 Karpathy 的 nanoGPT,它是 GPT-2 124M 的一个最小实现。本指南适用于其他语言模型,因为 ExecuTorch 不依赖于具体的模型。
使用 ExecuTorch 运行模型有两个步骤
导出模型。此步骤将模型预处理为适合运行时执行的格式。
在运行时,加载模型文件并使用 ExecuTorch 运行时执行。
导出步骤提前进行,通常作为应用构建的一部分或在模型更改时进行。生成的 .pte 文件随应用一起分发。在运行时,应用加载 .pte 文件并将其传递给 ExecuTorch 运行时。
步骤 1. 导出到 ExecuTorch¶
导出过程将 PyTorch 模型转换为可在消费设备上高效运行的格式。
对于此示例,您需要 nanoGPT 模型和相应的分词器词汇表。
curl https://raw.githubusercontent.com/karpathy/nanoGPT/master/model.py -O
curl https://huggingface.co/openai-community/gpt2/resolve/main/vocab.json -O
wget https://raw.githubusercontent.com/karpathy/nanoGPT/master/model.py
wget https://huggingface.co/openai-community/gpt2/resolve/main/vocab.json
要将模型转换为针对独立执行优化的格式,需要执行两个步骤。首先,使用 PyTorch 的 export
函数将 PyTorch 模型转换为中间的、平台无关的中间表示。然后使用 ExecuTorch 的 to_edge
和 to_executorch
方法准备模型以供设备端执行。这将创建一个 .pte 文件,该文件可以在运行时由桌面或移动应用加载。
创建一个名为 export_nanogpt.py 的文件,内容如下
# export_nanogpt.py
import torch
from executorch.exir import EdgeCompileConfig, to_edge
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.export import export, export_for_training
from model import GPT
# Load the model.
model = GPT.from_pretrained('gpt2')
# Create example inputs. This is used in the export process to provide
# hints on the expected shape of the model input.
example_inputs = (torch.randint(0, 100, (1, model.config.block_size), dtype=torch.long), )
# Set up dynamic shape configuration. This allows the sizes of the input tensors
# to differ from the sizes of the tensors in `example_inputs` during runtime, as
# long as they adhere to the rules specified in the dynamic shape configuration.
# Here we set the range of 0th model input's 1st dimension as
# [0, model.config.block_size].
# See https://pytorch.ac.cn/executorch/0.6/concepts#dynamic-shapes
# for details about creating dynamic shapes.
dynamic_shape = (
{1: torch.export.Dim("token_dim", max=model.config.block_size)},
)
# Trace the model, converting it to a portable intermediate representation.
# The torch.no_grad() call tells PyTorch to exclude training-specific logic.
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
m = export_for_training(model, example_inputs, dynamic_shapes=dynamic_shape).module()
traced_model = export(m, example_inputs, dynamic_shapes=dynamic_shape)
# Convert the model into a runnable ExecuTorch program.
edge_config = EdgeCompileConfig(_check_ir_validity=False)
edge_manager = to_edge(traced_model, compile_config=edge_config)
et_program = edge_manager.to_executorch()
# Save the ExecuTorch program to a file.
with open("nanogpt.pte", "wb") as file:
file.write(et_program.buffer)
要进行导出,使用 python export_nanogpt.py
(或 python3,取决于您的环境)运行脚本。这将在当前目录中生成一个 nanogpt.pte
文件。
更多信息请参阅导出到 ExecuTorch 和 torch.export。
步骤 2. 调用运行时¶
ExecuTorch 提供了一套运行时 API 和类型来加载和运行模型。
创建一个名为 main.cpp 的文件,内容如下
// main.cpp
#include <cstdint>
#include "basic_sampler.h"
#include "basic_tokenizer.h"
#include <executorch/extension/module/module.h>
#include <executorch/extension/tensor/tensor.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/result.h>
using executorch::aten::ScalarType;
using executorch::aten::Tensor;
using executorch::extension::from_blob;
using executorch::extension::Module;
using executorch::runtime::EValue;
using executorch::runtime::Result;
模型的输入和输出采用张量形式。张量可以被视为一个多维数组。ExecuTorch 的 EValue
类提供了对张量和其他 ExecuTorch 数据类型的包装。
由于 LLM 一次生成一个 Token,驱动代码需要重复调用模型,逐个 Token 构建输出。每个生成的 Token 都作为下一次运行的输入被传递。
// main.cpp
// The value of the gpt2 `<|endoftext|>` token.
#define ENDOFTEXT_TOKEN 50256
std::string generate(
Module& llm_model,
std::string& prompt,
BasicTokenizer& tokenizer,
BasicSampler& sampler,
size_t max_input_length,
size_t max_output_length) {
// Convert the input text into a list of integers (tokens) that represents it,
// using the string-to-token mapping that the model was trained on. Each token
// is an integer that represents a word or part of a word.
std::vector<int64_t> input_tokens = tokenizer.encode(prompt);
std::vector<int64_t> output_tokens;
for (auto i = 0u; i < max_output_length; i++) {
// Convert the input_tokens from a vector of int64_t to EValue. EValue is a
// unified data type in the ExecuTorch runtime.
auto inputs = from_blob(
input_tokens.data(),
{1, static_cast<int>(input_tokens.size())},
ScalarType::Long);
// Run the model. It will return a tensor of logits (log-probabilities).
auto logits_evalue = llm_model.forward(inputs);
// Convert the output logits from EValue to std::vector, which is what the
// sampler expects.
Tensor logits_tensor = logits_evalue.get()[0].toTensor();
std::vector<float> logits(
logits_tensor.data_ptr<float>(),
logits_tensor.data_ptr<float>() + logits_tensor.numel());
// Sample the next token from the logits.
int64_t next_token = sampler.sample(logits);
// Break if we reached the end of the text.
if (next_token == ENDOFTEXT_TOKEN) {
break;
}
// Add the next token to the output.
output_tokens.push_back(next_token);
std::cout << tokenizer.decode({next_token});
std::cout.flush();
// Update next input.
input_tokens.push_back(next_token);
if (input_tokens.size() > max_input_length) {
input_tokens.erase(input_tokens.begin());
}
}
std::cout << std::endl;
// Convert the output tokens into a human-readable string.
std::string output_string = tokenizer.decode(output_tokens);
return output_string;
}
的 Module
类负责加载 .pte 文件并准备执行。
分词器负责将人类可读的提示字符串表示形式转换为模型所需的数值形式。为此,分词器将短子字符串与给定的 Token ID 相关联。Token 可以被认为是表示单词或单词的一部分,尽管实际上,它们可能是任意的字符序列。
分词器从文件中加载词汇表,该文件包含每个 Token ID 及其表示的文本之间的映射。调用 tokenizer.encode()
和 tokenizer.decode()
可在字符串和 Token 表示形式之间进行转换。
采样器负责根据模型输出的 logits(即对数概率)选择下一个 Token。LLM 为每个可能的下一个 Token 返回一个 logit 值。采样器根据某种策略选择要使用的 Token。此处使用的最简单方法是选择具有最高 logit 值的 Token。
采样器可以提供可配置的选项,例如输出选择的可配置随机性、重复 Token 的惩罚以及优先或非优先处理特定 Token 的偏置。
// main.cpp
int main() {
// Set up the prompt. This provides the seed text for the model to elaborate.
std::cout << "Enter model prompt: ";
std::string prompt;
std::getline(std::cin, prompt);
// The tokenizer is used to convert between tokens (used by the model) and
// human-readable strings.
BasicTokenizer tokenizer("vocab.json");
// The sampler is used to sample the next token from the logits.
BasicSampler sampler = BasicSampler();
// Load the exported nanoGPT program, which was generated via the previous
// steps.
Module model("nanogpt.pte", Module::LoadMode::MmapUseMlockIgnoreErrors);
const auto max_input_tokens = 1024;
const auto max_output_tokens = 30;
std::cout << prompt;
generate(
model, prompt, tokenizer, sampler, max_input_tokens, max_output_tokens);
}
最后,将以下文件下载到与 main.cpp 相同的目录中
curl -O https://raw.githubusercontent.com/pytorch/executorch/main/examples/llm_manual/basic_sampler.h
curl -O https://raw.githubusercontent.com/pytorch/executorch/main/examples/llm_manual/basic_tokenizer.h
要了解更多信息,请参阅运行时 API 教程。
构建和运行¶
ExecuTorch 使用 CMake 构建系统。要编译和链接 ExecuTorch 运行时,请通过 add_directory
包含 ExecuTorch 项目,并链接 executorch
和附加依赖项。
创建一个名为 CMakeLists.txt 的文件,内容如下
# CMakeLists.txt
cmake_minimum_required(VERSION 3.19)
project(nanogpt_runner)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED True)
# Set options for executorch build.
option(EXECUTORCH_ENABLE_LOGGING "" ON)
option(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER "" ON)
option(EXECUTORCH_BUILD_EXTENSION_MODULE "" ON)
option(EXECUTORCH_BUILD_EXTENSION_TENSOR "" ON)
option(EXECUTORCH_BUILD_KERNELS_OPTIMIZED "" ON)
# Include the executorch subdirectory.
add_subdirectory(
${CMAKE_CURRENT_SOURCE_DIR}/third-party/executorch
${CMAKE_BINARY_DIR}/executorch
)
add_executable(nanogpt_runner main.cpp)
target_link_libraries(
nanogpt_runner
PRIVATE executorch
extension_module_static # Provides the Module class
extension_tensor # Provides the TensorPtr class
optimized_native_cpu_ops_lib # Provides baseline cross-platform
# kernels
)
此时,工作目录应包含以下文件
CMakeLists.txt
main.cpp
basic_tokenizer.h
basic_sampler.h
export_nanogpt.py
model.py
vocab.json
nanogpt.pte
如果这些文件都已存在,您现在可以构建并运行
(mkdir cmake-out && cd cmake-out && cmake ..)
cmake --build cmake-out -j10
./cmake-out/nanogpt_runner
您应该看到消息
Enter model prompt:
为模型输入一些种子文本并按回车键。这里我们使用“Hello world!”作为示例提示。
Enter model prompt: Hello world!
Hello world!
I'm not sure if you've heard of the "Curse of the Dragon" or not, but it's a very popular game in
此时,它可能会运行得非常慢。这是因为 ExecuTorch 尚未被告知针对特定硬件进行优化(Delegation),并且它正在进行所有 32 位浮点计算(没有量化)。
Delegation¶
虽然 ExecuTorch 为所有算子提供了可移植的、跨平台的实现,但它也为许多不同的目标提供了专门的后端。这些后端包括但不限于通过 XNNPACK 后端实现的 x86 和 ARM CPU 加速,通过 Core ML 后端和 Metal Performance Shader (MPS) 后端实现的 Apple 加速,以及通过 Vulkan 后端实现的 GPU 加速。
由于优化是针对特定后端的,因此每个 pte 文件都是针对导出时指定的一个或多个后端。为了支持多种设备,例如针对 Android 的 XNNPACK 加速和针对 iOS 的 Core ML,需要为每个后端导出一个单独的 PTE 文件。
在导出过程中将模型 Delegation 到特定后端时,ExecuTorch 使用 to_edge_transform_and_lower()
函数。此函数接受从 torch.export
导出的程序以及一个后端特定的 partitioner 对象。Partitioner 识别计算图中可以由目标后端优化的部分。在 to_edge_transform_and_lower()
中,导出的程序被转换为 Edge Dialect 程序。然后,Partitioner 将兼容的图部分 Delegation 给后端进行加速和优化。任何未 Delegation 的图部分将由 ExecuTorch 的默认算子实现执行。
要将导出的模型 Delegation 到特定后端,我们需要首先从 ExecuTorch 代码库导入其 partitioner 以及 Edge 编译配置,然后调用 to_edge_transform_and_lower
。
以下是如何将 nanoGPT Delegation 给 XNNPACK 的示例(例如,如果您要部署到 Android 手机)
# export_nanogpt.py
# Load partitioner for Xnnpack backend
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
# Model to be delegated to specific backend should use specific edge compile config
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
import torch
from torch.export import export
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.export import export_for_training
from model import GPT
# Load the nanoGPT model.
model = GPT.from_pretrained('gpt2')
# Create example inputs. This is used in the export process to provide
# hints on the expected shape of the model input.
example_inputs = (
torch.randint(0, 100, (1, model.config.block_size - 1), dtype=torch.long),
)
# Set up dynamic shape configuration. This allows the sizes of the input tensors
# to differ from the sizes of the tensors in `example_inputs` during runtime, as
# long as they adhere to the rules specified in the dynamic shape configuration.
# Here we set the range of 0th model input's 1st dimension as
# [0, model.config.block_size].
# See https://pytorch.ac.cn/executorch/0.6/concepts.html#dynamic-shapes
# for details about creating dynamic shapes.
dynamic_shape = (
{1: torch.export.Dim("token_dim", max=model.config.block_size - 1)},
)
# Trace the model, converting it to a portable intermediate representation.
# The torch.no_grad() call tells PyTorch to exclude training-specific logic.
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
m = export_for_training(model, example_inputs, dynamic_shapes=dynamic_shape).module()
traced_model = export(m, example_inputs, dynamic_shapes=dynamic_shape)
# Convert the model into a runnable ExecuTorch program.
# To be further lowered to Xnnpack backend, `traced_model` needs xnnpack-specific edge compile config
edge_config = get_xnnpack_edge_compile_config()
# Converted to edge program and then delegate exported model to Xnnpack backend
# by invoking `to` function with Xnnpack partitioner.
edge_manager = to_edge_transform_and_lower(traced_model, partitioner = [XnnpackPartitioner()], compile_config = edge_config)
et_program = edge_manager.to_executorch()
# Save the Xnnpack-delegated ExecuTorch program to a file.
with open("nanogpt.pte", "wb") as file:
file.write(et_program.buffer)
此外,更新 CMakeLists.txt 以构建 XNNPACK 后端并将其链接到 ExecuTorch 运行器。
cmake_minimum_required(VERSION 3.19)
project(nanogpt_runner)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED True)
# Set options for executorch build.
option(EXECUTORCH_ENABLE_LOGGING "" ON)
option(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER "" ON)
option(EXECUTORCH_BUILD_EXTENSION_MODULE "" ON)
option(EXECUTORCH_BUILD_EXTENSION_TENSOR "" ON)
option(EXECUTORCH_BUILD_KERNELS_OPTIMIZED "" ON)
option(EXECUTORCH_BUILD_XNNPACK "" ON) # Build with Xnnpack backend
# Include the executorch subdirectory.
add_subdirectory(
${CMAKE_CURRENT_SOURCE_DIR}/third-party/executorch
${CMAKE_BINARY_DIR}/executorch
)
add_executable(nanogpt_runner main.cpp)
target_link_libraries(
nanogpt_runner
PRIVATE executorch
extension_module_static # Provides the Module class
extension_tensor # Provides the TensorPtr class
optimized_native_cpu_ops_lib # Provides baseline cross-platform
# kernels
xnnpack_backend # Provides the XNNPACK CPU acceleration backend
)
保持代码的其余部分不变。更多详细信息请参阅导出到 ExecuTorch 和 调用运行时。
此时,工作目录应包含以下文件
CMakeLists.txt
main.cpp
basic_tokenizer.h
basic_sampler.h
export_nanogpt.py
model.py
vocab.json
如果这些都存在,您现在可以导出经过 Xnnpack Delegation 的 pte 模型
python export_nanogpt.py
它将在相同的工作目录下生成 nanogpt.pte
。
然后我们可以通过以下方式构建并运行模型
(rm -rf cmake-out && mkdir cmake-out && cd cmake-out && cmake ..)
cmake --build cmake-out -j10
./cmake-out/nanogpt_runner
您应该看到消息
Enter model prompt:
为模型输入一些种子文本并按回车键。这里我们使用“Hello world!”作为示例提示。
Enter model prompt: Hello world!
Hello world!
I'm not sure if you've heard of the "Curse of the Dragon" or not, but it's a very popular game in
经过 Delegation 的模型应该比未 Delegation 的模型明显更快。
有关后端 Delegation 的更多信息,请参阅 ExecuTorch 关于XNNPACK 后端、Core ML 后端 和 Qualcomm AI Engine Direct 后端的指南。
量化¶
量化是指使用较低精度类型执行计算和存储张量的一系列技术。与 32 位浮点相比,使用 8 位整数可以显著提高速度并减少内存使用。量化模型有多种方法,所需的预处理量、使用的数据类型以及对模型精度和性能的影响各不相同。
由于移动设备上的计算和内存资源高度受限,因此需要某种形式的量化才能在消费电子产品上部署大型模型。特别是大型语言模型,如 Llama2,可能需要将模型权重量化到 4 位或更低。
利用量化需要在导出之前对模型进行转换。PyTorch 为此目的提供了 pt2e (PyTorch 2 Export) API。本示例针对使用 XNNPACK delegate 的 CPU 加速。因此,它需要使用 XNNPACK 特定的量化器。针对不同的后端将需要使用相应的量化器。
要将 8 位整数动态量化与 XNNPACK delegate 结合使用,请调用 prepare_pt2e
,通过使用代表性输入运行来校准模型,然后调用 convert_pt2e
。这会更新计算图,以在可用时使用量化算子。
# export_nanogpt.py
from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
DuplicateDynamicQuantChainPass,
)
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
XNNPACKQuantizer,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
# Use dynamic, per-channel quantization.
xnnpack_quant_config = get_symmetric_quantization_config(
is_per_channel=True, is_dynamic=True
)
xnnpack_quantizer = XNNPACKQuantizer()
xnnpack_quantizer.set_global(xnnpack_quant_config)
m = export_for_training(model, example_inputs).module()
# Annotate the model for quantization. This prepares the model for calibration.
m = prepare_pt2e(m, xnnpack_quantizer)
# Calibrate the model using representative inputs. This allows the quantization
# logic to determine the expected range of values in each tensor.
m(*example_inputs)
# Perform the actual quantization.
m = convert_pt2e(m, fold_quantize=False)
DuplicateDynamicQuantChainPass()(m)
traced_model = export(m, example_inputs)
此外,添加或更新 to_edge_transform_and_lower()
调用以使用 XnnpackPartitioner
。这会指示 ExecuTorch 通过 XNNPACK 后端优化模型以进行 CPU 执行。
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
XnnpackPartitioner,
)
edge_config = get_xnnpack_edge_compile_config()
# Convert to edge dialect and lower to XNNPack.
edge_manager = to_edge_transform_and_lower(traced_model, partitioner = [XnnpackPartitioner()], compile_config = edge_config)
et_program = edge_manager.to_executorch()
with open("nanogpt.pte", "wb") as file:
file.write(et_program.buffer)
然后运行
python export_nanogpt.py
./cmake-out/nanogpt_runner
更多信息请参阅ExecuTorch 中的量化。
性能分析和调试¶
通过调用 to_edge_transform_and_lower()
对模型进行转换后,您可能想看看哪些部分被 Delegation 了,哪些没有。ExecuTorch 提供了实用方法来深入了解 Delegation 情况。您可以使用这些信息来查看底层计算并诊断潜在的性能问题。模型作者可以使用这些信息来构建与目标后端兼容的模型。
可视化 Delegation 情况¶
get_delegation_info()
方法提供了一个摘要,说明在调用 to_edge_transform_and_lower()
后模型发生了什么
from executorch.devtools.backend_debug import get_delegation_info
from tabulate import tabulate
# ... After call to to_edge_transform_and_lower(), but before to_executorch()
graph_module = edge_manager.exported_program().graph_module
delegation_info = get_delegation_info(graph_module)
print(delegation_info.get_summary())
df = delegation_info.get_operator_delegation_dataframe()
print(tabulate(df, headers="keys", tablefmt="fancy_grid"))
对于面向 XNNPACK 后端的 nanoGPT,您可能会看到以下内容(请注意,下面的数字仅用于说明目的,实际值可能有所不同)
Total delegated subgraphs: 145
Number of delegated nodes: 350
Number of non-delegated nodes: 760
算子类型 |
# 在 Delegation 图中 |
# 在非 Delegation 图中 |
|
---|---|---|---|
0 |
aten__softmax_default |
12 |
0 |
1 |
aten_add_tensor |
37 |
0 |
2 |
aten_addmm_default |
48 |
0 |
3 |
aten_any_dim |
0 |
12 |
… |
|||
25 |
aten_view_copy_default |
96 |
122 |
… |
|||
30 |
总计 |
350 |
760 |
从表中可以看出,算子 aten_view_copy_default
在 Delegation 图中出现 96 次,在非 Delegation 图中出现 122 次。要查看更详细的信息,可以使用 format_delegated_graph()
方法获取整个图的格式化字符串输出,或使用 print_delegated_graph()
直接打印。
from executorch.exir.backend.utils import format_delegated_graph
graph_module = edge_manager.exported_program().graph_module
print(format_delegated_graph(graph_module))
对于大型模型,这可能会生成大量输出。考虑使用“Control+F”或“Command+F”来查找您感兴趣的算子(例如“aten_view_copy_default”)。观察哪些实例不在 lowered 图下。
在下面的 nanoGPT 输出片段中,请注意 transformer 模块已被 Delegation 到 XNNPACK,而 where 算子则没有。
%aten_where_self_22 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.where.self](args = (%aten_logical_not_default_33, %scalar_tensor_23, %scalar_tensor_22), kwargs = {})
%lowered_module_144 : [num_users=1] = get_attr[target=lowered_module_144]
backend_id: XnnpackBackend
lowered graph():
%p_transformer_h_0_attn_c_attn_weight : [num_users=1] = placeholder[target=p_transformer_h_0_attn_c_attn_weight]
%p_transformer_h_0_attn_c_attn_bias : [num_users=1] = placeholder[target=p_transformer_h_0_attn_c_attn_bias]
%getitem : [num_users=1] = placeholder[target=getitem]
%sym_size : [num_users=2] = placeholder[target=sym_size]
%aten_view_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.view_copy.default](args = (%getitem, [%sym_size, 768]), kwargs = {})
%aten_permute_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.permute_copy.default](args = (%p_transformer_h_0_attn_c_attn_weight, [1, 0]), kwargs = {})
%aten_addmm_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.addmm.default](args = (%p_transformer_h_0_attn_c_attn_bias, %aten_view_copy_default, %aten_permute_copy_default), kwargs = {})
%aten_view_copy_default_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.view_copy.default](args = (%aten_addmm_default, [1, %sym_size, 2304]), kwargs = {})
return [aten_view_copy_default_1]
性能分析¶
通过 ExecuTorch 开发者工具,用户能够对模型执行进行性能分析,获取模型中每个算子的计时信息。
前提条件¶
ETRecord 生成(可选)¶
ETRecord 是在导出时生成的一种工件,包含模型图和将 ExecuTorch 程序链接到原始 PyTorch 模型的源级元数据。您可以在没有 ETRecord 的情况下查看所有性能分析事件,但使用 ETRecord,您还可以将每个事件链接到正在执行的算子类型、模块层次结构以及原始 PyTorch 源代码的堆栈跟踪。更多信息请参阅ETRecord 文档。
在您的导出脚本中,调用 to_edge()
和 to_executorch()
后,使用来自 to_edge()
的 EdgeProgramManager
和来自 to_executorch()
的 ExecuTorchProgramManager
调用 generate_etrecord()
。请务必复制 EdgeProgramManager
,因为对 to_edge_transform_and_lower()
的调用会原地修改图。
# export_nanogpt.py
import copy
from executorch.devtools import generate_etrecord
# Make the deep copy immediately after to to_edge()
edge_manager_copy = copy.deepcopy(edge_manager)
# ...
# Generate ETRecord right after to_executorch()
etrecord_path = "etrecord.bin"
generate_etrecord(etrecord_path, edge_manager_copy, et_program)
运行导出脚本,ETRecord 将生成为 etrecord.bin
。
ETDump 生成¶
ETDump 是在运行时生成的一种工件,包含模型执行的跟踪信息。更多信息请参阅ETDump 文档。
在您的代码中包含 ETDump 头文件和命名空间。
// main.cpp
#include <executorch/devtools/etdump/etdump_flatcc.h>
using executorch::etdump::ETDumpGen;
using torch::executor::etdump_result;
创建 ETDumpGen 类的一个实例,并将其传递给 Module 构造函数。
std::unique_ptr<ETDumpGen> etdump_gen_ = std::make_unique<ETDumpGen>();
Module model("nanogpt.pte", Module::LoadMode::MmapUseMlockIgnoreErrors, std::move(etdump_gen_));
调用 generate()
后,将 ETDump 保存到文件。如果需要,您可以在单个跟踪文件中捕获多次模型运行。
ETDumpGen* etdump_gen = static_cast<ETDumpGen*>(model.event_tracer());
ET_LOG(Info, "ETDump size: %zu blocks", etdump_gen->get_num_blocks());
etdump_result result = etdump_gen->get_etdump_data();
if (result.buf != nullptr && result.size > 0) {
// On a device with a file system, users can just write it to a file.
FILE* f = fopen("etdump.etdp", "w+");
fwrite((uint8_t*)result.buf, 1, result.size, f);
fclose(f);
free(result.buf);
}
此外,更新 CMakeLists.txt 以使用开发者工具进行构建,并启用事件跟踪和记录到 ETDump。
option(EXECUTORCH_ENABLE_EVENT_TRACER "" ON)
option(EXECUTORCH_BUILD_DEVTOOLS "" ON)
# ...
target_link_libraries(
# ... omit existing ones
etdump) # Provides event tracing and logging
target_compile_options(executorch PUBLIC -DET_EVENT_TRACER_ENABLED)
target_compile_options(portable_ops_lib PUBLIC -DET_EVENT_TRACER_ENABLED)
构建并运行运行器,您将看到生成了一个名为“etdump.etdp”的文件。(请注意,这次我们在 Release 模式下构建是为了解决 flatccrt 的构建限制。)
(rm -rf cmake-out && mkdir cmake-out && cd cmake-out && cmake -DCMAKE_BUILD_TYPE=Release ..)
cmake --build cmake-out -j10
./cmake-out/nanogpt_runner
使用 Inspector API 进行分析¶
收集到调试工件 ETDump(以及可选的 ETRecord)后,您可以使用 Inspector API 查看性能信息。
from executorch.devtools import Inspector
inspector = Inspector(etdump_path="etdump.etdp")
# If you also generated an ETRecord, then pass that in as well: `inspector = Inspector(etdump_path="etdump.etdp", etrecord="etrecord.bin")`
with open("inspector_out.txt", "w") as file:
inspector.print_data_tabular(file)
这会将性能数据以表格格式打印到“inspector_out.txt”中,每行是一个性能分析事件。前几行如下所示: 查看全尺寸图像
要了解更多关于 Inspector 及其提供的丰富功能的信息,请参阅Inspector API 参考。
自定义内核¶
借助 ExecuTorch 自定义算子 API,自定义算子和内核的作者可以轻松地将其内核引入 PyTorch/ExecuTorch。
在 ExecuTorch 中使用自定义内核有三个步骤
使用 ExecuTorch 类型编写自定义内核。
将自定义内核编译并链接到 AOT Python 环境和运行时二进制文件。
源到源转换,将一个算子替换为自定义算子。
更多信息请参阅PyTorch 自定义算子 和 ExecuTorch 内核注册。
如何构建移动应用¶
请参阅在 iOS 和 Android 上使用 ExecuTorch 构建和运行 LLMs 的说明。