AOTInductor:Torch.Export-ed 模型的预先编译¶
警告
AOTInductor 及其相关功能处于原型状态,可能会有向后不兼容的更改。
AOTInductor 是 TorchInductor 的一个专门版本,旨在处理导出的 PyTorch 模型,优化它们,并生成共享库以及其他相关工件。这些编译后的工件是专门为部署在非 Python 环境中而设计的,这些环境通常用于服务器端的推理部署。
在本教程中,您将深入了解如何获取 PyTorch 模型,导出它,将其编译成共享库,并使用 C++ 进行模型预测的过程。
模型编译¶
使用 AOTInductor,您仍然可以使用 Python 编写模型。以下示例演示了如何调用 aoti_compile_and_package
将模型转换为共享库。
此 API 使用 torch.export.export
将模型捕获到计算图中,然后使用 TorchInductor 生成一个 .so 文件,该文件可以在非 Python 环境中运行。有关 torch._inductor.aoti_compile_and_package
API 的全面详细信息,您可以参考 此处 的代码。有关 torch.export.export
的更多详细信息,您可以参考 torch.export 文档。
注意
如果您的机器上装有支持 CUDA 的设备,并且您安装了支持 CUDA 的 PyTorch,则以下代码会将模型编译为用于 CUDA 执行的共享库。否则,编译后的工件将在 CPU 上运行。为了在 CPU 推理期间获得更好的性能,建议在运行下面的 Python 脚本之前,通过设置 export TORCHINDUCTOR_FREEZING=1 来启用冻结。
import os
import torch
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc1 = torch.nn.Linear(10, 16)
self.relu = torch.nn.ReLU()
self.fc2 = torch.nn.Linear(16, 1)
self.sigmoid = torch.nn.Sigmoid()
def forward(self, x):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.sigmoid(x)
return x
with torch.no_grad():
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Model().to(device=device)
example_inputs=(torch.randn(8, 10, device=device),)
batch_dim = torch.export.Dim("batch", min=1, max=1024)
# [Optional] Specify the first dimension of the input x as dynamic.
exported = torch.export.export(model, example_inputs, dynamic_shapes={"x": {0: batch_dim}})
# [Note] In this example we directly feed the exported module to aoti_compile_and_package.
# Depending on your use case, e.g. if your training platform and inference platform
# are different, you may choose to save the exported model using torch.export.save and
# then load it back using torch.export.load on your inference platform to run AOT compilation.
output_path = torch._inductor.aoti_compile_and_package(
exported,
# [Optional] Specify the generated shared library path. If not specified,
# the generated artifact is stored in your system temp directory.
package_path=os.path.join(os.getcwd(), "model.pt2"),
)
在此示例中,Dim
参数用于将输入变量 “x” 的第一个维度指定为动态维度。值得注意的是,编译库的路径和名称仍然未指定,导致共享库存储在临时目录中。为了从 C++ 端访问此路径,我们将其保存到一个文件中,以便稍后在 C++ 代码中检索。
Python 中的推理¶
有多种部署编译后的工件以进行推理的方法,其中一种是使用 Python。我们在 Python 中提供了一个方便的实用程序 API torch._inductor.aoti_load_package
,用于加载和运行工件,如下例所示
import os
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch._inductor.aoti_load_package(os.path.join(os.getcwd(), "model.pt2"))
print(model(torch.randn(8, 10, device=device)))
C++ 中的推理¶
接下来,我们使用以下示例 C++ 文件 inference.cpp
来加载编译后的工件,使我们能够在 C++ 环境中直接进行模型预测。
#include <iostream>
#include <vector>
#include <torch/torch.h>
#include <torch/csrc/inductor/aoti_package/model_package_loader.h>
int main() {
c10::InferenceMode mode;
torch::inductor::AOTIModelPackageLoader loader("model.pt2");
torch::inductor::AOTIModelContainerRunner* runner = loader.get_runner();
// Assume running on CUDA
std::vector<torch::Tensor> inputs = {torch::randn({8, 10}, at::kCUDA)};
std::vector<torch::Tensor> outputs = runner->run(inputs);
std::cout << "Result from the first inference:"<< std::endl;
std::cout << outputs[0] << std::endl;
// The second inference uses a different batch size and it works because we
// specified that dimension as dynamic when compiling model.pt2.
std::cout << "Result from the second inference:"<< std::endl;
// Assume running on CUDA
std::cout << runner->run({torch::randn({1, 10}, at::kCUDA)})[0] << std::endl;
return 0;
}
为了构建 C++ 文件,您可以使用提供的 CMakeLists.txt
文件,该文件自动化了调用 python model.py
以进行模型的 AOT 编译并将 inference.cpp
编译为名为 aoti_example
的可执行二进制文件的过程。
cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
project(aoti_example)
find_package(Torch REQUIRED)
add_executable(aoti_example inference.cpp model.pt2)
add_custom_command(
OUTPUT model.pt2
COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/model.py
DEPENDS model.py
)
target_link_libraries(aoti_example "${TORCH_LIBRARIES}")
set_property(TARGET aoti_example PROPERTY CXX_STANDARD 17)
如果目录结构如下所示,则可以执行以下命令来构建二进制文件。需要注意的是,CMAKE_PREFIX_PATH
变量对于 CMake 定位 LibTorch 库至关重要,应将其设置为绝对路径。请注意,您的路径可能与此示例中所示的路径有所不同。
aoti_example/
CMakeLists.txt
inference.cpp
model.py
$ mkdir build
$ cd build
$ CMAKE_PREFIX_PATH=/path/to/python/install/site-packages/torch/share/cmake cmake ..
$ cmake --build . --config Release
在 aoti_example
二进制文件在 build
目录中生成后,执行它将显示类似于以下的結果
$ ./aoti_example
Result from the first inference:
0.4866
0.5184
0.4462
0.4611
0.4744
0.4811
0.4938
0.4193
[ CUDAFloatType{8,1} ]
Result from the second inference:
0.4883
0.4703
[ CUDAFloatType{2,1} ]