• 文档 >
  • Torch Export 到 StableHLO
快捷方式

Torch Export 到 StableHLO

本文档描述了如何使用 torch export + torch xla 导出为 StableHLO 格式。

from torch.export import export
from torch_xla.stablehlo import exported_program_to_stablehlo
import torch_xla.core.xla_model as xm
import torchvision
import torch

xla_device = xm.xla_device()

resnet18 = torchvision.models.resnet18()
# Sample input is a tuple
sample_input = (torch.randn(4, 3, 224, 224), )
output = resnet18(*sample_input)
exported = export(resnet18, sample_input)
stablehlo_program = exported_program_to_stablehlo(exported)

# Now stablehlo_program is a callable backed by stablehlo IR.

# we can see it's stablehlo code with
#   here 'forward' is the name of function. Currently we only support
#   one entry point per program, but in the future we will support
#   multiple entry points in a program.
print(stablehlo_program.get_stablehlo_text('forward'))

# we can also print out the bytecode
print(stablehlo_program.get_stablehlo_bytecode('forward'))

# we can also run the module, to run the stablehlo module, we need to move
# our tensors to XLA device.
sample_input_xla = tuple(s.to(xla_device) for s in sample_input)

output2 = stablehlo_program(*sample_input_xla)
print(torch.allclose(output, output2.cpu(), atol=1e-5))

将 StableHLO 字节码保存到磁盘

现在可以使用以下方式将 stablehlo 保存到磁盘:

stablehlo_program.save('/tmp/stablehlo_dir')

路径应指向一个空目录。如果目录不存在,则会创建。该目录可以作为另一个 stablehlo_program 再次加载。

from torch_xla.stablehlo import StableHLOGraphModule
stablehlo_program2 = StableHLOGraphModule.load('/tmp/stablehlo_dir')
output3 = stablehlo_program2(*sample_input_xla)

转换已保存的 StableHLO 用于服务

StableHLO 是一种开放格式,在 tensorflow.serving 模型服务器中支持服务。然而,在将其提供给 tf.serving 之前,我们需要先将生成的 StableHLO 字节码包装成 tf.saved_model 格式。

为此,首先确保您当前的 python 环境中安装了最新的 tensorflow;如果未安装,请使用以下命令安装:

pip install tf-nightly

现在,您可以运行一个转换器(由 torch/xla 安装提供)

stablehlo-to-saved-model /tmp/stablehlo_dir /tmp/resnet_tf/1

之后,您可以使用 tf serving 二进制文件在新生成的 tf.saved_model 上运行您的模型服务器。

docker pull tensorflow/serving
docker run -p 8500:8500 \
--mount type=bind,source=/tmp/resnet_tf,target=/models/resnet_tf \
-e MODEL_NAME=resnet_tf -t tensorflow/serving &

您也可以直接使用 tf.serving 二进制文件,无需 docker。更多详细信息,请遵循 tf serving 指南

常用包装器

我想直接保存为 tf.saved_model 格式,无需运行单独的命令。

您可以通过使用此帮助函数来实现此目的:

from torch_xla.tf_saved_model_integration import save_torch_module_as_tf_saved_model

save_torch_module_as_tf_saved_model(
    resnet18,  # original pytorch torch.nn.Module
    sample_inputs, # sample inputs used to trace
    '/tmp/resnet_tf'   # directory for tf.saved_model
)

其他常用包装器

def save_as_stablehlo(exported_model: 'ExportedProgram',
                      stablehlo_dir: os.PathLike,
                      options: Optional[StableHLOExportOptions] = None):

save_as_stablehlo(也别名为 torch_xla.save_as_stablehlo)接受 ExportedProgram 并将 StableHLO 保存到磁盘。即,与 exported_program_to_stablehlo(…).save(…) 相同

def save_torch_model_as_stablehlo(
    torchmodel: torch.nn.Module,
    args: Tuple[Any],
    path: os.PathLike,
    options: Optional[StableHLOExportOptions] = None) -> None:
    """Convert a torch model to a callable backed by StableHLO.

接受 torch.nn.Module 并将 StableHLO 保存到磁盘。即,与先 torch.export.export 再 save_as_stablehlo 相同

save_as_stablehlo 生成的文件。

在上面示例的 /tmp/stablehlo_dir 中,您会找到 3 个目录:dataconstantsfunctions。data 和 constants 目录都包含程序使用的张量,这些张量使用 numpy.save 保存为 numpy.ndarray 格式。

functions 目录将包含 StableHLO 字节码,此处命名为 forward.bytecode,以及人类可读的 StableHLO 代码(MLIR 格式)forward.mlir,还有一个 JSON 文件,指定哪些权重和原始用户输入成为此 StableHLO 函数的哪些位置参数;以及每个参数的数据类型(dtypes)和形状(shapes)。

示例

$ find /tmp/stablehlo_dir
./functions
./functions/forward.mlir
./functions/forward.bytecode
./functions/forward.meta
./constants
./constants/3
./constants/1
./constants/0
./constants/2
./data
./data/L__fn___layers_15_feed_forward_w2.weight
./data/L__fn___layers_13_feed_forward_w1.weight
./data/L__fn___layers_3_attention_wo.weight
./data/L__fn___layers_12_ffn_norm_weight
./data/L__fn___layers_25_attention_wo.weight
...

该 JSON 文件是 torch_xla.stablehlo.StableHLOFunc 类的序列化形式。

这种格式目前仍处于原型阶段,不保证向后兼容性。未来的计划是标准化一种主要框架(PyTorch、JAX、TensorFlow)都能接受的格式。

通过生成 stablehlo.composite 在 StableHLO 中保留高层 PyTorch 操作

高层 PyTorch ops(例如 F.scaled_dot_product_attention)在 PyTorch -> StableHLO 转换过程中会被分解为低层 ops。在下游 ML 编译器中捕获高层 op 对于生成高性能、高效的专业内核至关重要。虽然在 ML 编译器中对一组低层 ops 进行模式匹配可能具有挑战性且容易出错,但我们提供了一种更可靠的方法来在 StableHLO 程序中概括高层 PyTorch op - 通过为高层 PyTorch ops 生成 stablehlo.composite

使用 StableHLOCompositeBuilder,用户可以在 torch.nn.Moduleforward 函数中概括任意区域。然后在导出的 StableHLO 程序中,将为概括的区域生成一个 composite op。

注意:由于概括区域的非张量输入值将在导出图中硬编码,如果希望从下游编译器中检索,请将这些值存储为复合属性(composite attributes)。

以下示例展示了一个实际用例 - 捕获 scaled_product_attention

import torch
import torch.nn.functional as F
from torch_xla import stablehlo
from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder


class M(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.q_proj = torch.nn.Linear(128, 128, bias=False)
        self.k_proj = torch.nn.Linear(128, 128, bias=False)
        self.v_proj = torch.nn.Linear(128, 128, bias=False)
        # Initialize the StableHLOCompositeBuilder with the name of the composite op and its attributes
        # Note: To capture the value of non-tensor inputs, please pass them as attributes to the builder
        self.b = StableHLOCompositeBuilder("test.sdpa", {"scale": 0.25, "other_attr": "val"})

    def forward(self, x):
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        q, k, v = self.b.mark_inputs(q, k, v)
        attn_out = F.scaled_dot_product_attention(q, k, v, scale=0.25)
        attn_out = self.b.mark_outputs(attn_out)
        attn_out = attn_out + x
        return attn_out

input_args = (torch.randn((10, 8, 128)), )
# torch.export to Exported Program
exported = torch.export.export(M(), input_args)
# Exported Program to StableHLO
stablehlo_gm = stablehlo.exported_program_to_stablehlo(exported)
stablehlo = stablehlo_gm.get_stablehlo_text()
print(stablehlo)

主要的 StableHLO 图如下所示:

module @IrToHlo.56 attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func @main(%arg0: tensor<10x8x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>, %arg3: tensor<128x128xf32>) -> tensor<10x8x128xf32> {
    ...
    %10 = stablehlo.composite "test.sdpa" %3, %6, %9 {composite_attributes = {other_attr = "val", scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl} : (tensor<10x8x128xf32>, tensor<10x8x128xf32>, tensor<10x8x128xf32>) -> tensor<10x8x128xf32>
    %11 = stablehlo.add %10, %arg0 : tensor<10x8x128xf32>
    return %11 : tensor<10x8x128xf32>
  }

  func.func private @test.sdpa.impl(%arg0: tensor<10x8x128xf32>, %arg1: tensor<10x8x128xf32>, %arg2: tensor<10x8x128xf32>) -> tensor<10x8x128xf32> {
    // Actual implementation of the composite
    ...
    return %11 : tensor<10x8x128xf32>
  }

sdpa 操作在主图中被封装为一个 stablehlo composite 调用。torch.nn.Module 中指定的名称和属性会被传播。

%10 = stablehlo.composite "test.sdpa" %3, %6, %9 {composite_attributes = {other_attr = "val", scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl}

sdpa 操作的 PyTorch 参考分解被捕获在 StableHLO 函数中

func.func private @test.sdpa.impl(%arg0: tensor<10x8x128xf32>, %arg1: tensor<10x8x128xf32>, %arg2: tensor<10x8x128xf32>) -> tensor<10x8x128xf32> {
    // Actual implementation of the composite
    ...
    return %11 : tensor<10x8x128xf32>
  }

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源