• 文档 >
  • Torch 导出到 StableHLO
快捷方式

Torch 导出到 StableHLO

本文档介绍了如何使用 torch export + torch xla 导出到 StableHLO 格式。

from torch.export import export
from torch_xla.stablehlo import exported_program_to_stablehlo
import torch_xla.core.xla_model as xm
import torchvision
import torch

xla_device = xm.xla_device()

resnet18 = torchvision.models.resnet18()
# Sample input is a tuple
sample_input = (torch.randn(4, 3, 224, 224), )
output = resnet18(*sample_input)
exported = export(resnet18, sample_input)
stablehlo_program = exported_program_to_stablehlo(exported)

# Now stablehlo_program is a callable backed by stablehlo IR.

# we can see it's stablehlo code with
#   here 'forward' is the name of function. Currently we only support
#   one entry point per program, but in the future we will support
#   multiple entry points in a program.
print(stablehlo_program.get_stablehlo_text('forward'))

# we can also print out the bytecode
print(stablehlo_program.get_stablehlo_bytecode('forward'))

# we can also run the module, to run the stablehlo module, we need to move
# our tensors to XLA device.
sample_input_xla = tuple(s.to(xla_device) for s in sample_input)

output2 = stablehlo_program(*sample_input_xla)
print(torch.allclose(output, output2.cpu(), atol=1e-5))

将 StableHLO 字节码保存到磁盘

现在可以使用以下命令将 stablehlo 保存到磁盘:

stablehlo_program.save('/tmp/stablehlo_dir')

路径应该是空目录的路径。如果目录不存在,则会创建它。此目录可以作为另一个 stablehlo_program 重新加载

from torch_xla.stablehlo import StableHLOGraphModule
stablehlo_program2 = StableHLOGraphModule.load('/tmp/stablehlo_dir')
output3 = stablehlo_program2(*sample_input_xla)

转换保存的 StableHLO 以进行 Serving

StableHLO 是一种开放格式,并且 tensorflow.serving 模型服务器支持使用该格式进行 Serving。但是,在将其提供给 tf.serving 之前,我们需要首先将生成的 StableHLO 字节码包装成 tf.saved_model 格式。

为此,首先确保您在当前的 python 环境中安装了最新的 tensorflow,如果没有,请使用以下命令安装:

pip install tf-nightly

现在,您可以运行转换器(在 torch/xla 安装中提供):

stablehlo-to-saved-model /tmp/stablehlo_dir /tmp/resnet_tf/1

之后,您可以使用 tf serving 二进制文件在新生成的 tf.saved_model 上运行模型服务器。

docker pull tensorflow/serving
docker run -p 8500:8500 \
--mount type=bind,source=/tmp/resnet_tf,target=/models/resnet_tf \
-e MODEL_NAME=resnet_tf -t tensorflow/serving &

您也可以直接使用 tf.serving 二进制文件,而无需 docker。有关更多详细信息,请遵循 tf serving 指南

常用包装器

我想直接保存 tf.saved_model 格式,而无需运行单独的命令。

您可以使用此辅助函数来实现此目的

from torch_xla.tf_saved_model_integration import save_torch_module_as_tf_saved_model

save_torch_module_as_tf_saved_model(
    resnet18,  # original pytorch torch.nn.Module
    sample_inputs, # sample inputs used to trace
    '/tmp/resnet_tf'   # directory for tf.saved_model
)

其他常用包装器

def save_as_stablehlo(exported_model: 'ExportedProgram',
                      stablehlo_dir: os.PathLike,
                      options: Optional[StableHLOExportOptions] = None):

save_as_stablehlo (也别名为 torch_xla.save_as_stablehlo)接受 ExportedProgram 并在磁盘上保存 StableHLO。即与 exported_program_to_stablehlo(…).save(…) 相同

def save_torch_model_as_stablehlo(
    torchmodel: torch.nn.Module,
    args: Tuple[Any],
    path: os.PathLike,
    options: Optional[StableHLOExportOptions] = None) -> None:
    """Convert a torch model to a callable backed by StableHLO.

接受 torch.nn.Module 并在磁盘上保存 StableHLO。即与 torch.export.export 后跟 save_as_stablehlo 相同

save_as_stablehlo 生成的文件。

在上面示例中的 /tmp/stablehlo_dir 内部,您将找到 3 个目录:dataconstantsfunctions。data 和 constants 都将包含程序使用的张量,这些张量使用 numpy.save 保存为 numpy.ndarray

functions 目录将包含 StableHLO 字节码,此处命名为 forward.bytecode,人类可读的 StableHLO 代码(MLIR 格式)forward.mlir,以及一个 JSON 文件,该文件指定了哪些权重和原始用户的输入成为此 StableHLO 函数的哪些位置参数;以及每个参数的 dtypes 和形状。

示例

$ find /tmp/stablehlo_dir
./functions
./functions/forward.mlir
./functions/forward.bytecode
./functions/forward.meta
./constants
./constants/3
./constants/1
./constants/0
./constants/2
./data
./data/L__fn___layers_15_feed_forward_w2.weight
./data/L__fn___layers_13_feed_forward_w1.weight
./data/L__fn___layers_3_attention_wo.weight
./data/L__fn___layers_12_ffn_norm_weight
./data/L__fn___layers_25_attention_wo.weight
...

JSON 文件是 torch_xla.stablehlo.StableHLOFunc 类的序列化形式。

此格式目前也处于原型阶段,并且不保证向后兼容性。未来的计划是标准化主要框架(PyTorch、JAX、TensorFlow)可以接受的格式。

通过生成 stablehlo.composite 在 StableHLO 中保留高级 PyTorch 操作

高级 PyTorch 操作(例如 F.scaled_dot_product_attention)将在 PyTorch -> StableHLO 降低期间分解为低级操作。在下游 ML 编译器中捕获高级操作对于生成高性能、高效的专用内核至关重要。虽然在 ML 编译器中模式匹配一堆低级操作可能具有挑战性且容易出错,但我们提供了一种更强大的方法来在 StableHLO 程序中概述高级 PyTorch 操作 - 通过为高级 PyTorch 操作生成 stablehlo.composite

使用 StableHLOCompositeBuilder,用户可以概述 torch.nn.Moduleforward 函数内的任意区域。然后在导出的 StableHLO 程序中,将为概述区域生成一个 composite 操作。

注意: 由于非张量输入到概述区域的值将硬编码在导出的图中,如果希望从下游编译器检索这些值,请将这些值存储为 composite 属性。

以下示例显示了一个实际用例 - 捕获 scaled_product_attention

import torch
import torch.nn.functional as F
from torch_xla import stablehlo
from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder


class M(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.q_proj = torch.nn.Linear(128, 128, bias=False)
        self.k_proj = torch.nn.Linear(128, 128, bias=False)
        self.v_proj = torch.nn.Linear(128, 128, bias=False)
        # Initialize the StableHLOCompositeBuilder with the name of the composite op and its attributes
        # Note: To capture the value of non-tensor inputs, please pass them as attributes to the builder
        self.b = StableHLOCompositeBuilder("test.sdpa", {"scale": 0.25, "other_attr": "val"})

    def forward(self, x):
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        q, k, v = self.b.mark_inputs(q, k, v)
        attn_out = F.scaled_dot_product_attention(q, k, v, scale=0.25)
        attn_out = self.b.mark_outputs(attn_out)
        attn_out = attn_out + x
        return attn_out

input_args = (torch.randn((10, 8, 128)), )
# torch.export to Exported Program
exported = torch.export.export(M(), input_args)
# Exported Program to StableHLO
stablehlo_gm = stablehlo.exported_program_to_stablehlo(exported)
stablehlo = stablehlo_gm.get_stablehlo_text()
print(stablehlo)

主 StableHLO 图如下所示

module @IrToHlo.56 attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func @main(%arg0: tensor<10x8x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>, %arg3: tensor<128x128xf32>) -> tensor<10x8x128xf32> {
    ...
    %10 = stablehlo.composite "test.sdpa" %3, %6, %9 {composite_attributes = {other_attr = "val", scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl} : (tensor<10x8x128xf32>, tensor<10x8x128xf32>, tensor<10x8x128xf32>) -> tensor<10x8x128xf32>
    %11 = stablehlo.add %10, %arg0 : tensor<10x8x128xf32>
    return %11 : tensor<10x8x128xf32>
  }

  func.func private @test.sdpa.impl(%arg0: tensor<10x8x128xf32>, %arg1: tensor<10x8x128xf32>, %arg2: tensor<10x8x128xf32>) -> tensor<10x8x128xf32> {
    // Actual implementation of the composite
    ...
    return %11 : tensor<10x8x128xf32>
  }

sdpa 操作被封装为主图中的 stablehlo composite 调用。torch.nn.Module 中指定的名称和属性会得到传播。

%10 = stablehlo.composite "test.sdpa" %3, %6, %9 {composite_attributes = {other_attr = "val", scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl}

sdpa 操作的参考 PyTorch 分解被捕获在 StableHLO 函数中

func.func private @test.sdpa.impl(%arg0: tensor<10x8x128xf32>, %arg1: tensor<10x8x128xf32>, %arg2: tensor<10x8x128xf32>) -> tensor<10x8x128xf32> {
    // Actual implementation of the composite
    ...
    return %11 : tensor<10x8x128xf32>
  }

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源