
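Bundled Program – a Tool for ExecuTorch Model Validation

Introduction

BundledProgram is a wrapper around the core ExecuTorch program, designed to help users package test cases together with the model they deploy. BundledProgram is not necessarily a core part of the program and is not required for its execution, but it is especially important for various other use cases, such as model correctness evaluation, including end-to-end testing during model bring-up.

Overall, the process can be divided into two stages, and in each stage we support:

  • Emit stage: bundling the test I/O cases together with the ExecuTorch program and serializing them into a flatbuffer.

  • Runtime stage: accessing, executing, and verifying the bundled test cases at runtime.

Emit Stage

This stage focuses on creating a BundledProgram and dumping it to disk as a flatbuffer file. The main steps are as follows:

  1. Create a model and emit its ExecuTorch program.

  2. Construct a List[MethodTestSuite] to record all test cases that need to be bundled.

  3. Generate a BundledProgram from the emitted model and the List[MethodTestSuite].

  4. Serialize the BundledProgram and dump it to disk.

Step 1: Create a Model and Emit Its ExecuTorch Program.

An ExecuTorch program can be emitted from the user's model using the ExecuTorch APIs. Follow the Generate Sample ExecuTorch program or Exporting to ExecuTorch tutorial.

Step 2: Construct List[MethodTestSuite] to Hold Test Information

In BundledProgram, we create two new classes, MethodTestCase and MethodTestSuite, to hold the information needed to verify an ExecuTorch program.

MethodTestCase represents a single test case. Each MethodTestCase contains the inputs and the expected outputs for a single execution.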

MethodTestCase
executorch.devtools.bundled_program.config.MethodTestCase.__init__(self, inputs, expected_outputs=None)

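Single test case for verifying a specific method.

Parameters
  • inputs –

    All inputs required by the eager_model with the specific inference method for one execution.

    Note that although both the bundled program and the ET runtime APIs support setting inputs of types other than torch.tensor, only inputs of type torch.tensor are actually updated in the method; the remaining inputs are only sanity-checked to see whether they match the method's default values.

  • expected_outputs – The expected outputs for the given inputs, used for verification. Can be None if the user only wants to use the test case for profiling.

Returns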

self
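As a rough mental model, a test case is just one execution's inputs plus optional expected outputs. The sketch below uses a hypothetical `ToyTestCase` dataclass in plain Python (no ExecuTorch or PyTorch imports; it is not the real `MethodTestCase` class) and fills `expected_outputs` by running an eager function on the same inputs, wrapped in a tuple, the same way the Emit example calls `getattr(model, method_name)(*input)`. The scalar `eager_forward` is a hypothetical analogue of `SampleModel.forward`.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional, Sequence, Tuple


@dataclass
class ToyTestCase:
    """Hypothetical stand-in for MethodTestCase (not the real class)."""

    inputs: Sequence[Any]
    # None means the case can only drive execution (e.g. profiling), not verification.
    expected_outputs: Optional[Tuple[Any, ...]] = None

    @property
    def verifiable(self) -> bool:
        return self.expected_outputs is not None


def make_case(fn: Callable[..., Any], inputs: Sequence[Any]) -> ToyTestCase:
    # Expected outputs come from running the eager function on the same inputs,
    # wrapped in a tuple as the Emit example does.
    return ToyTestCase(inputs=list(inputs), expected_outputs=(fn(*inputs),))


def eager_forward(x: int, q: int) -> int:
    # Scalar analogue of SampleModel.forward: (3 * x) + 2 + q.
    return 3 * x + 2 + q


case = make_case(eager_forward, [1, 4])
profiling_only = ToyTestCase(inputs=[1, 4])
```

A case built without expected outputs (like `profiling_only`) can still be executed, but there is nothing to compare against.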

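MethodTestSuite contains all test information for a single method, including a string representing the method name and a List[MethodTestCase] holding all of its test cases:

MethodTestSuite
executorch.devtools.bundled_program.config.MethodTestSuite(method_name, test_cases)

All test information related to verifying a method.

executorch.devtools.bundled_program.config.method_name

The name of the method to verify.

executorch.devtools.bundled_program.config.test_cases

All test cases for verifying the method.

Since each model may have multiple inference methods, we need to generate a List[MethodTestSuite] to hold all of the necessary information.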
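To illustrate how one suite per method fits together, here is a plain-Python sketch with a hypothetical two-method toy "model" (a dict of callables) and a hypothetical `ToySuite` type standing in for `MethodTestSuite`. It builds suites only for the methods you pass inputs for, which matches the point that not every method needs test cases.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Sequence, Tuple


@dataclass
class ToySuite:
    """Hypothetical stand-in for MethodTestSuite: a method name plus its cases."""

    method_name: str
    # Each case is (inputs, expected_outputs).
    test_cases: List[Tuple[Sequence[Any], Tuple[Any, ...]]]


# Hypothetical toy "model" with two inference methods.
toy_methods: Dict[str, Callable[..., Any]] = {
    "forward": lambda x, q: 3 * x + 2 + q,
    "encode": lambda x: x * 10,
}


def build_suites(
    methods: Dict[str, Callable[..., Any]],
    inputs_per_method: Dict[str, List[Sequence[Any]]],
) -> List[ToySuite]:
    suites = []
    # Only methods that appear in inputs_per_method get a suite.
    for name, input_sets in inputs_per_method.items():
        fn = methods[name]
        cases = [(list(inp), (fn(*inp),)) for inp in input_sets]
        suites.append(ToySuite(method_name=name, test_cases=cases))
    return suites


# Bundle cases for "forward" only; "encode" is intentionally left without a suite.
suites = build_suites(toy_methods, {"forward": [[0, 0], [1, 4]]})
```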

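Step 3: Generate BundledProgram

We provide the BundledProgram class under executorch/devtools/bundled_program/core.py to bundle an ExecutorchProgram-like variable, including ExecutorchProgram, MultiMethodExecutorchProgram, or ExecutorchProgramManager, with the List[MethodTestSuite]:

BundledProgram
executorch.devtools.bundled_program.core.BundledProgram.__init__(self, executorch_program, method_test_suites, pte_file_path=None)

Create a BundledProgram by bundling the given program and method_test_suites together.

Parameters
  • executorch_program – The program to bundle.

  • method_test_suites – The test cases for the methods to bundle.

  • pte_file_path – The path to a pte file from which to deserialize the program if executorch_program is not provided.

The BundledProgram constructor internally performs sanity checks to see whether the given List[MethodTestSuite] matches the requirements of the given program. Specifically:

  1. The method_name of each MethodTestSuite in the List[MethodTestSuite] must also exist in the program. Note that you do not need to provide test cases for every method in the program.

  2. The metadata of each test case must meet the input requirements of the corresponding inference method.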
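The first check can be pictured as a set comparison. The sketch below mirrors the assertion visible in the Common Errors traceback (a subset test between the test suites' method names and the program's `execution_plan` names, after sorting suites by method name); the function name and message wording here are hypothetical, and the real checks live in `executorch/devtools/bundled_program/core.py`.

```python
from typing import Iterable, List


def check_method_names(program_methods: Iterable[str], suite_methods: List[str]) -> List[str]:
    """Hypothetical sketch of sanity check 1; returns suite names in sorted order."""
    known = set(program_methods)
    requested = set(suite_methods)
    missing = requested - known
    # Every bundled method must exist in the program; extra program methods are fine.
    if missing:
        raise AssertionError(f"method names not found in program: {sorted(missing)}")
    # The real core.py also sorts suites alphabetically by method name before bundling.
    return sorted(suite_methods)
```

Passing `["forward", "encode"]` as the program's methods and `["forward"]` as the suites succeeds, because untested methods are allowed.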

Step 4: Serialize BundledProgram to Flatbuffer.

To serialize a BundledProgram so that the runtime APIs can use it, we provide two APIs, both under executorch/devtools/bundled_program/serialize/__init__.py.

Serialize and Deserialize
executorch.devtools.bundled_program.serialize.serialize_from_bundled_program_to_flatbuffer(bundled_program)

Serialize a BundledProgram into FlatBuffer binary format.

Parameters

bundled_program (BundledProgram) – The BundledProgram variable to serialize.

Returns

The serialized FlatBuffer binary data as bytes.

executorch.devtools.bundled_program.serialize.deserialize_from_flatbuffer_to_bundled_program(flatbuffer)

Deserialize FlatBuffer binary data into a BundledProgram.

Parameters

flatbuffer (bytes) – FlatBuffer binary data as bytes.

Returns

A BundledProgram instance.
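Since `serialize_from_bundled_program_to_flatbuffer` returns plain `bytes` and `deserialize_from_flatbuffer_to_bundled_program` accepts `bytes`, any byte-oriented file I/O works for storing the result. A minimal sketch with `pathlib`; the helper names are hypothetical, the file lives in a temporary directory, and the placeholder payload stands in for real serialized output.

```python
import tempfile
from pathlib import Path


def save_flatbuffer(path: Path, data: bytes) -> None:
    # Write the serialized BundledProgram bytes verbatim.
    path.write_bytes(data)


def load_flatbuffer(path: Path) -> bytes:
    # Read the bytes back; these can be passed to
    # deserialize_from_flatbuffer_to_bundled_program.
    return path.read_bytes()


with tempfile.TemporaryDirectory() as tmp:
    bpte_path = Path(tmp) / "bundled_program.bpte"
    payload = b"\x00\x01placeholder-bytes"  # stand-in for real serialized output
    save_flatbuffer(bpte_path, payload)
    round_tripped = load_flatbuffer(bpte_path)
```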

Emit Example

Here is a flow that shows how to generate a BundledProgram given a PyTorch model and the representative inputs we want to test it with.

import torch

from executorch.exir import to_edge
from executorch.devtools import BundledProgram

from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite
from executorch.devtools.bundled_program.serialize import (
    serialize_from_bundled_program_to_flatbuffer,
)
from torch.export import export, export_for_training


# Step 1: ExecuTorch Program Export
class SampleModel(torch.nn.Module):
    """An example model whose forward method takes multiple inputs and returns a single output."""

    def __init__(self) -> None:
        super().__init__()
        self.a: torch.Tensor = 3 * torch.ones(2, 2, dtype=torch.int32)
        self.b: torch.Tensor = 2 * torch.ones(2, 2, dtype=torch.int32)

    def forward(self, x: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
        z = x.clone()
        torch.mul(self.a, x, out=z)
        y = x.clone()
        torch.add(z, self.b, out=y)
        torch.add(y, q, out=y)
        return y


# Name of the SampleModel inference method we want to bundle test cases for.
# Note that we do not need to bundle test cases for every inference method.
method_name = "forward"
model = SampleModel()

# Inputs for graph capture.
capture_input = (
    (torch.rand(2, 2) - 0.5).to(dtype=torch.int32),
    (torch.rand(2, 2) - 0.5).to(dtype=torch.int32),
)

# Export method's FX Graph.
method_graph = export(
    export_for_training(model, capture_input).module(),
    capture_input,
)


# Emit the traced method into ET Program.
et_program = to_edge(method_graph).to_executorch()

# Step 2: Construct MethodTestSuite for Each Method

# Prepare the Test Inputs.

# Number of input sets to be verified
n_input = 10

# Input sets to be verified.
inputs = [
    # Each list below is an individual input set.
    # The number of inputs, dtype and size of each input follow Program's spec.
    [
        (torch.rand(2, 2) - 0.5).to(dtype=torch.int32),
        (torch.rand(2, 2) - 0.5).to(dtype=torch.int32),
    ]
    for _ in range(n_input)
]

# Generate Test Suites
method_test_suites = [
    MethodTestSuite(
        method_name=method_name,
        test_cases=[
            MethodTestCase(
                inputs=input,
                expected_outputs=(getattr(model, method_name)(*input), ),
            )
            for input in inputs
        ],
    ),
]

# Step 3: Generate BundledProgram
bundled_program = BundledProgram(et_program, method_test_suites)

# Step 4: Serialize BundledProgram to flatbuffer.
serialized_bundled_program = serialize_from_bundled_program_to_flatbuffer(
    bundled_program
)
save_path = "bundled_program.bpte"
with open(save_path, "wb") as f:
    f.write(serialized_bundled_program)

If needed, we can also regenerate a BundledProgram from the flatbuffer file:

from executorch.devtools.bundled_program.serialize import deserialize_from_flatbuffer_to_bundled_program
save_path = "bundled_program.bpte"
with open(save_path, "rb") as f:
    serialized_bundled_program = f.read()

regenerate_bundled_program = deserialize_from_flatbuffer_to_bundled_program(serialized_bundled_program)

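Runtime Stage

This stage focuses on executing the model with the bundled inputs and comparing the model's outputs with the bundled expected outputs. We provide several APIs to handle its key parts.

Get the ExecuTorch Program Pointer from a BundledProgram Buffer

We need a pointer to the ExecuTorch program in order to execute it. To unify the process of loading and executing BundledProgram and Program flatbuffers, we provide an API: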

get_program_data


Here is an example of how to use the get_program_data API:

// Assume that the user has read the contents of the file into file_data using
// whatever method works best for their application. The file could contain
// either BundledProgram data or Program data.
void* file_data = ...;
size_t file_data_len = ...;

// If file_data contains a BundledProgram, get_program_data() will return a
// pointer to the Program data embedded inside it. Otherwise it will return
// file_data, which already pointed to Program data.
const void* program_ptr;
size_t program_len;
Error status = executorch::bundled_program::get_program_data(
    file_data, file_data_len, &program_ptr, &program_len);
ET_CHECK_MSG(
    status == Error::Ok,
    "get_program_data() failed with status 0x%" PRIx32,
    (uint32_t)status);
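Conceptually, get_program_data only has to decide whether the buffer is a BundledProgram wrapper or a plain Program. The Python sketch below shows that decision using the FlatBuffers file identifier, which FlatBuffers stores in bytes 4 to 8 of a buffer (right after the 4-byte root-table offset). The identifier value `b"BP08"`, the function name, and the `extract_program` callback are assumptions for illustration; check the bundled program schema in your ExecuTorch checkout for the real identifier, and expect the real C++ implementation to differ in detail.

```python
from typing import Callable

# Assumed 4-byte FlatBuffers file identifier for BundledProgram buffers;
# verify against the schema in your ExecuTorch checkout.
BUNDLED_PROGRAM_ID = b"BP08"


def get_program_data_sketch(
    file_data: bytes, extract_program: Callable[[bytes], bytes]
) -> bytes:
    """Hypothetical Python analogue of get_program_data()."""
    # FlatBuffers places the optional file identifier right after the
    # 4-byte root-table offset, i.e. at bytes 4..8.
    if len(file_data) >= 8 and file_data[4:8] == BUNDLED_PROGRAM_ID:
        # Bundled: hand back the Program embedded inside the wrapper.
        return extract_program(file_data)
    # Not bundled: the buffer already is Program data.
    return file_data
```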

Load Bundled Input into the Method

To execute the program on a bundled input, we need to load that input into the method. For this we provide an API called executorch::bundled_program::load_bundled_input:

load_bundled_input

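The description of the MethodTestCase inputs notes that only tensor inputs are actually written into the method, while other inputs are only sanity-checked against the method's default values. The plain-Python sketch below models that rule; `ToyTensor`, the list of current method inputs, and the function name are hypothetical stand-ins, not the C++ load_bundled_input implementation.

```python
from dataclasses import dataclass
from typing import Any, List


@dataclass
class ToyTensor:
    """Hypothetical stand-in for a tensor input."""

    values: List[float]


def load_bundled_input_sketch(method_inputs: List[Any], bundled: List[Any]) -> None:
    """Copy tensor inputs into method_inputs; only sanity-check the rest."""
    if len(bundled) != len(method_inputs):
        raise ValueError("input count mismatch")
    for i, value in enumerate(bundled):
        if isinstance(value, ToyTensor):
            # Tensor inputs are actually updated in the method.
            method_inputs[i] = value
        elif value != method_inputs[i]:
            # Non-tensor inputs must match the method's default value.
            raise ValueError(
                f"input {i}: expected default {method_inputs[i]!r}, got {value!r}"
            )


method_inputs: List[Any] = [ToyTensor([0.0, 0.0]), 2]
load_bundled_input_sketch(method_inputs, [ToyTensor([1.5, -1.0]), 2])
```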

Verify the Method's Output

We call executorch::bundled_program::verify_method_outputs to verify whether the method's output matches the bundled expected output. Here are the details of this API:

verify_method_outputs

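The runtime example passes FLAGS_rtol and FLAGS_atol to verify_method_outputs. Assuming these tolerances behave like `torch.allclose` (an element passes when |actual - expected| <= atol + rtol * |expected|), the comparison can be sketched in plain Python as follows. The function name and default tolerances are illustrative, outputs are treated as flattened lists of floats, and the exact rule used by ExecuTorch may differ.

```python
import math
from typing import Sequence


def outputs_close(
    actual: Sequence[float],
    expected: Sequence[float],
    rtol: float = 1e-5,
    atol: float = 1e-8,
) -> bool:
    """Hypothetical allclose-style check for one flattened output tensor."""
    if len(actual) != len(expected):
        return False
    for a, e in zip(actual, expected):
        # NaNs never compare equal here, matching allclose's default behavior.
        if math.isnan(a) or math.isnan(e):
            return False
        if abs(a - e) > atol + rtol * abs(e):
            return False
    return True
```

The relative term scales with the expected value, so larger outputs get proportionally looser bounds, while atol keeps outputs near zero from failing on tiny absolute differences.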

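Runtime Example

Here is an example of how to run a bundled program step by step. Most of the code is borrowed from executor_runner; please check that file if you need more information and context: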

// method_name is the name of the method we want to test.
// memory_manager is the executor::MemoryManager variable for executor memory allocation.
// program is the ExecuTorch program.
// program_data exposes the loaded BundledProgram buffer via bundled_program_data()
// (see executor_runner).
Result<Method> method = program->load_method(method_name, &memory_manager);

ET_CHECK_MSG(
    method.ok(),
    "load_method() failed with status 0x%" PRIx32,
    (uint32_t)method.error());

// Load the testset_idx-th bundled input into the method.
Error status = executorch::bundled_program::load_bundled_input(
        *method,
        program_data.bundled_program_data(),
        FLAGS_testset_idx);
ET_CHECK_MSG(
    status == Error::Ok,
    "load_bundled_input failed with status 0x%" PRIx32,
    (uint32_t)status);

// Execute the method.
status = method->execute();
ET_CHECK_MSG(
    status == Error::Ok,
    "method->execute() failed with status 0x%" PRIx32,
    (uint32_t)status);

// Verify the result against the bundled expected outputs.
status = executorch::bundled_program::verify_method_outputs(
        *method,
        program_data.bundled_program_data(),
        FLAGS_testset_idx,
        FLAGS_rtol,
        FLAGS_atol);
ET_CHECK_MSG(
    status == Error::Ok,
    "Bundle verification failed with status 0x%" PRIx32,
    (uint32_t)status);
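The runtime example verifies a single test set selected by FLAGS_testset_idx. To check every bundled case, one would repeat load, execute, and verify for each index. The sketch below shows that control flow in plain Python with hypothetical callbacks standing in for load_bundled_input, method->execute(), and verify_method_outputs; it illustrates the loop, not an ExecuTorch API.

```python
from typing import Callable, List


def run_all_test_sets(
    num_test_sets: int,
    load: Callable[[int], None],
    execute: Callable[[], None],
    verify: Callable[[int], bool],
) -> List[int]:
    """Run every bundled test set and return the indices that failed verification."""
    failures = []
    for idx in range(num_test_sets):
        load(idx)  # stands in for load_bundled_input(method, data, idx)
        execute()  # stands in for method->execute()
        if not verify(idx):  # stands in for verify_method_outputs(..., idx, rtol, atol)
            failures.append(idx)
    return failures
```

Collecting failing indices instead of aborting on the first mismatch makes it easier to see whether a problem affects all inputs or only a few.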

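Common Errors

An error is raised if the List[MethodTestSuite] does not match the Program. Here are two common situations:

Test input does not match the model's requirements.

Each inference method of a PyTorch model has its own requirements for its inputs, such as the number of inputs and the dtype of each input. BundledProgram raises an error if a test input does not meet these requirements.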
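The second sanity check compares each test input against the method's input spec. The traceback in this section shows the real check comparing tensor inputs by dtype and `int`, `bool`, and `float` inputs by exact Python type. The sketch below models only the primitive-type branch, plus an input-count check that is an assumption here, in plain Python; the spec format (a list of expected types) and the function name are hypothetical.

```python
from typing import Any, List, Sequence


def check_inputs(expected_types: Sequence[type], test_inputs: Sequence[Any]) -> List[str]:
    """Hypothetical sketch: return a list of problems (empty means the inputs fit the spec)."""
    problems = []
    if len(test_inputs) != len(expected_types):
        problems.append(f"expected {len(expected_types)} inputs, got {len(test_inputs)}")
        return problems
    for j, (value, want) in enumerate(zip(test_inputs, expected_types)):
        # type(...) rather than isinstance(...), so a bool is not accepted where an int is expected.
        if type(value) is not want:
            problems.append(
                f"input {j}: type shall be {want.__name__}, but now is {type(value).__name__}"
            )
    return problems
```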

Here is an example where the dtype of a test input does not meet the model's requirement:

import torch

from executorch.exir import to_edge
from executorch.devtools import BundledProgram

from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite
from torch.export import export, export_for_training


class Module(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = 3 * torch.ones(2, 2, dtype=torch.float)
        self.b = 2 * torch.ones(2, 2, dtype=torch.float)

    def forward(self, x):
        out_1 = torch.ones(2, 2, dtype=torch.float)
        out_2 = torch.ones(2, 2, dtype=torch.float)
        torch.mul(self.a, x, out=out_1)
        torch.add(out_1, self.b, out=out_2)
        return out_2


model = Module()
method_names = ["forward"]

inputs = (torch.ones(2, 2, dtype=torch.float), )

# Find each model method that needs to be traced by its name and export its FX graph.
method_graph = export(
    export_for_training(model, inputs).module(),
    inputs,
)

# Emit the traced methods into ET Program.
et_program = to_edge(method_graph).to_executorch()

# number of input sets to be verified
n_input = 10

# Input sets to be verified for each inference method.
# To simplify, here we create same inputs for all methods.
inputs = {
    # Inference method name corresponding to its test cases.
    m_name: [
        # NOTE: executorch program needs torch.float, but here is torch.int
        [
            torch.randint(-5, 5, (2, 2), dtype=torch.int),
        ]
        for _ in range(n_input)
    ]
    for m_name in method_names
}

# Generate Test Suites
method_test_suites = [
    MethodTestSuite(
        method_name=m_name,
        test_cases=[
            MethodTestCase(
                inputs=input,
                expected_outputs=(getattr(model, m_name)(*input),),
            )
            for input in inputs[m_name]
        ],
    )
    for m_name in method_names
]

# Generate BundledProgram

bundled_program = BundledProgram(et_program, method_test_suites)
Raised error:
The input tensor tensor([[-2,  0],
        [-2, -1]], dtype=torch.int32) dtype shall be torch.float32, but now is torch.int32
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[1], line 72
     56 method_test_suites = [
     57     MethodTestSuite(
     58         method_name=m_name,
   (...)
     67     for m_name in method_names
     68 ]
     70 # Step 3: Generate BundledProgram
---> 72 bundled_program = create_bundled_program(program, method_test_suites)
File /executorch/devtools/bundled_program/core.py:276, in create_bundled_program(program, method_test_suites)
    264 """Create bp_schema.BundledProgram by bundling the given program and method_test_suites together.
    265
    266 Args:
   (...)
    271     The `BundledProgram` variable contains given ExecuTorch program and test cases.
    272 """
    274 method_test_suites = sorted(method_test_suites, key=lambda x: x.method_name)
--> 276 assert_valid_bundle(program, method_test_suites)
    278 bundled_method_test_suites: List[bp_schema.BundledMethodTestSuite] = []
    280 # Emit data and metadata of bundled tensor
File /executorch/devtools/bundled_program/core.py:219, in assert_valid_bundle(program, method_test_suites)
    215 # type of tensor input should match execution plan
    216 if type(cur_plan_test_inputs[j]) == torch.Tensor:
    217     # pyre-fixme[16]: Undefined attribute [16]: Item `bool` of `typing.Union[bool, float, int, torch._tensor.Tensor]`
    218     # has no attribute `dtype`.
--> 219     assert cur_plan_test_inputs[j].dtype == get_input_dtype(
    220         program, program_plan_id, j
    221     ), "The input tensor {} dtype shall be {}, but now is {}".format(
    222         cur_plan_test_inputs[j],
    223         get_input_dtype(program, program_plan_id, j),
    224         cur_plan_test_inputs[j].dtype,
    225     )
    226 elif type(cur_plan_test_inputs[j]) in (
    227     int,
    228     bool,
    229     float,
    230 ):
    231     assert type(cur_plan_test_inputs[j]) == get_input_type(
    232         program, program_plan_id, j
    233     ), "The input primitive dtype shall be {}, but now is {}".format(
    234         get_input_type(program, program_plan_id, j),
    235         type(cur_plan_test_inputs[j]),
    236     )
AssertionError: The input tensor tensor([[-2,  0],
        [-2, -1]], dtype=torch.int32) dtype shall be torch.float32, but now is torch.int32

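Method name in MethodTestSuite does not exist.

Another common error is a method name in some MethodTestSuite that does not exist in the model. BundledProgram raises an error and shows the non-existent method name: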

import torch

from executorch.exir import to_edge
from executorch.devtools import BundledProgram

from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite
from torch.export import export, export_for_training


class Module(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = 3 * torch.ones(2, 2, dtype=torch.float)
        self.b = 2 * torch.ones(2, 2, dtype=torch.float)

    def forward(self, x):
        out_1 = torch.ones(2, 2, dtype=torch.float)
        out_2 = torch.ones(2, 2, dtype=torch.float)
        torch.mul(self.a, x, out=out_1)
        torch.add(out_1, self.b, out=out_2)
        return out_2


model = Module()
method_names = ["forward"]

inputs = (torch.ones(2, 2, dtype=torch.float),)

# Find each model method that needs to be traced by its name and export its FX graph.
method_graph = export(
    export_for_training(model, inputs).module(),
    inputs,
)

# Emit the traced methods into ET Program.
et_program = to_edge(method_graph).to_executorch()

# number of input sets to be verified
n_input = 10

# Input sets to be verified for each inference method.
# To simplify, here we create same inputs for all methods.
inputs = {
    # Inference method name corresponding to its test cases.
    m_name: [
        [
            torch.randint(-5, 5, (2, 2), dtype=torch.float),
        ]
        for _ in range(n_input)
    ]
    for m_name in method_names
}

# Generate Test Suites
method_test_suites = [
    MethodTestSuite(
        method_name=m_name,
        test_cases=[
            MethodTestCase(
                inputs=input,
                expected_outputs=(getattr(model, m_name)(*input),),
            )
            for input in inputs[m_name]
        ],
    )
    for m_name in method_names
]

# NOTE: MISSING_METHOD_NAME is not an inference method in the above model.
method_test_suites[0].method_name = "MISSING_METHOD_NAME"

# Generate BundledProgram
bundled_program = BundledProgram(et_program, method_test_suites)

Raised error:
All method names in bundled config should be found in program.execution_plan,          but {'MISSING_METHOD_NAME'} does not include.
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[3], line 73
     70 method_test_suites[0].method_name = "MISSING_METHOD_NAME"
     72 # Generate BundledProgram
---> 73 bundled_program = create_bundled_program(program, method_test_suites)
File /executorch/devtools/bundled_program/core.py:276, in create_bundled_program(program, method_test_suites)
    264 """Create bp_schema.BundledProgram by bundling the given program and method_test_suites together.
    265
    266 Args:
   (...)
    271     The `BundledProgram` variable contains given ExecuTorch program and test cases.
    272 """
    274 method_test_suites = sorted(method_test_suites, key=lambda x: x.method_name)
--> 276 assert_valid_bundle(program, method_test_suites)
    278 bundled_method_test_suites: List[bp_schema.BundledMethodTestSuite] = []
    280 # Emit data and metadata of bundled tensor
File /executorch/devtools/bundled_program/core.py:141, in assert_valid_bundle(program, method_test_suites)
    138 method_name_of_program = {e.name for e in program.execution_plan}
    139 method_name_of_test_suites = {t.method_name for t in method_test_suites}
--> 141 assert method_name_of_test_suites.issubset(
    142     method_name_of_program
    143 ), f"All method names in bundled config should be found in program.execution_plan, \
    144      but {str(method_name_of_test_suites - method_name_of_program)} does not include."
    146 # check if method_tesdt_suites has been sorted in ascending alphabetical order of method name.
    147 for test_suite_id in range(1, len(method_test_suites)):
AssertionError: All method names in bundled config should be found in program.execution_plan,          but {'MISSING_METHOD_NAME'} does not include.
