Torch 导出到 StableHLO¶
本文档介绍了如何使用 torch export + torch xla 导出到 StableHLO 格式。
from torch.export import export
from torch_xla.stablehlo import exported_program_to_stablehlo
import torch_xla.core.xla_model as xm
import torchvision
import torch
xla_device = xm.xla_device()
resnet18 = torchvision.models.resnet18()
# Sample input is a tuple
sample_input = (torch.randn(4, 3, 224, 224), )
output = resnet18(*sample_input)
exported = export(resnet18, sample_input)
stablehlo_program = exported_program_to_stablehlo(exported)
# Now stablehlo_program is a callable backed by stablehlo IR.
# we can see it's stablehlo code with
# here 'forward' is the name of function. Currently we only support
# one entry point per program, but in the future we will support
# multiple entry points in a program.
print(stablehlo_program.get_stablehlo_text('forward'))
# we can also print out the bytecode
print(stablehlo_program.get_stablehlo_bytecode('forward'))
# we can also run the module, to run the stablehlo module, we need to move
# our tensors to XLA device.
sample_input_xla = tuple(s.to(xla_device) for s in sample_input)
output2 = stablehlo_program(*sample_input_xla)
print(torch.allclose(output, output2.cpu(), atol=1e-5))
将 StableHLO 字节码保存到磁盘¶
现在可以使用以下命令将 stablehlo 保存到磁盘:
stablehlo_program.save('/tmp/stablehlo_dir')
路径应该是空目录的路径。如果目录不存在,则会创建它。此目录可以作为另一个 stablehlo_program 重新加载
from torch_xla.stablehlo import StableHLOGraphModule
stablehlo_program2 = StableHLOGraphModule.load('/tmp/stablehlo_dir')
output3 = stablehlo_program2(*sample_input_xla)
转换保存的 StableHLO 以进行 Serving¶
StableHLO 是一种开放格式,并且 tensorflow.serving 模型服务器支持使用该格式进行 Serving。但是,在将其提供给 tf.serving 之前,我们需要首先将生成的 StableHLO 字节码包装成 tf.saved_model
格式。
为此,首先确保您在当前的 python 环境中安装了最新的 tensorflow
,如果没有,请使用以下命令安装:
pip install tf-nightly
现在,您可以运行转换器(在 torch/xla 安装中提供):
stablehlo-to-saved-model /tmp/stablehlo_dir /tmp/resnet_tf/1
之后,您可以使用 tf serving 二进制文件在新生成的 tf.saved_model
上运行模型服务器。
docker pull tensorflow/serving
docker run -p 8500:8500 \
--mount type=bind,source=/tmp/resnet_tf,target=/models/resnet_tf \
-e MODEL_NAME=resnet_tf -t tensorflow/serving &
您也可以直接使用 tf.serving
二进制文件,而无需 docker。有关更多详细信息,请遵循 tf serving 指南。
常用包装器¶
我想直接保存 tf.saved_model 格式,而无需运行单独的命令。¶
您可以使用此辅助函数来实现此目的
from torch_xla.tf_saved_model_integration import save_torch_module_as_tf_saved_model
save_torch_module_as_tf_saved_model(
resnet18, # original pytorch torch.nn.Module
sample_inputs, # sample inputs used to trace
'/tmp/resnet_tf' # directory for tf.saved_model
)
其他常用包装器¶
def save_as_stablehlo(exported_model: 'ExportedProgram',
stablehlo_dir: os.PathLike,
options: Optional[StableHLOExportOptions] = None):
save_as_stablehlo
(也别名为 torch_xla.save_as_stablehlo
)接受 ExportedProgram 并在磁盘上保存 StableHLO。即与 exported_program_to_stablehlo(…).save(…) 相同
def save_torch_model_as_stablehlo(
torchmodel: torch.nn.Module,
args: Tuple[Any],
path: os.PathLike,
options: Optional[StableHLOExportOptions] = None) -> None:
"""Convert a torch model to a callable backed by StableHLO.
接受 torch.nn.Module
并在磁盘上保存 StableHLO。即与 torch.export.export 后跟 save_as_stablehlo 相同
save_as_stablehlo
生成的文件。¶
在上面示例中的 /tmp/stablehlo_dir
内部,您将找到 3 个目录:data
、constants
、functions
。data 和 constants 都将包含程序使用的张量,这些张量使用 numpy.save
保存为 numpy.ndarray
。
functions 目录将包含 StableHLO 字节码,此处命名为 forward.bytecode
,人类可读的 StableHLO 代码(MLIR 格式)forward.mlir
,以及一个 JSON 文件,该文件指定了哪些权重和原始用户的输入成为此 StableHLO 函数的哪些位置参数;以及每个参数的 dtypes 和形状。
示例
$ find /tmp/stablehlo_dir
./functions
./functions/forward.mlir
./functions/forward.bytecode
./functions/forward.meta
./constants
./constants/3
./constants/1
./constants/0
./constants/2
./data
./data/L__fn___layers_15_feed_forward_w2.weight
./data/L__fn___layers_13_feed_forward_w1.weight
./data/L__fn___layers_3_attention_wo.weight
./data/L__fn___layers_12_ffn_norm_weight
./data/L__fn___layers_25_attention_wo.weight
...
JSON 文件是 torch_xla.stablehlo.StableHLOFunc
类的序列化形式。
此格式目前也处于原型阶段,并且不保证向后兼容性。未来的计划是标准化主要框架(PyTorch、JAX、TensorFlow)可以接受的格式。
通过生成 stablehlo.composite
在 StableHLO 中保留高级 PyTorch 操作¶
高级 PyTorch 操作(例如 F.scaled_dot_product_attention
)将在 PyTorch -> StableHLO 降低期间分解为低级操作。在下游 ML 编译器中捕获高级操作对于生成高性能、高效的专用内核至关重要。虽然在 ML 编译器中模式匹配一堆低级操作可能具有挑战性且容易出错,但我们提供了一种更强大的方法来在 StableHLO 程序中概述高级 PyTorch 操作 - 通过为高级 PyTorch 操作生成 stablehlo.composite。
使用 StableHLOCompositeBuilder
,用户可以概述 torch.nn.Module
的 forward
函数内的任意区域。然后在导出的 StableHLO 程序中,将为概述区域生成一个 composite 操作。
注意: 由于非张量输入到概述区域的值将硬编码在导出的图中,如果希望从下游编译器检索这些值,请将这些值存储为 composite 属性。
以下示例显示了一个实际用例 - 捕获 scaled_product_attention
import torch
import torch.nn.functional as F
from torch_xla import stablehlo
from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.q_proj = torch.nn.Linear(128, 128, bias=False)
self.k_proj = torch.nn.Linear(128, 128, bias=False)
self.v_proj = torch.nn.Linear(128, 128, bias=False)
# Initialize the StableHLOCompositeBuilder with the name of the composite op and its attributes
# Note: To capture the value of non-tensor inputs, please pass them as attributes to the builder
self.b = StableHLOCompositeBuilder("test.sdpa", {"scale": 0.25, "other_attr": "val"})
def forward(self, x):
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
q, k, v = self.b.mark_inputs(q, k, v)
attn_out = F.scaled_dot_product_attention(q, k, v, scale=0.25)
attn_out = self.b.mark_outputs(attn_out)
attn_out = attn_out + x
return attn_out
input_args = (torch.randn((10, 8, 128)), )
# torch.export to Exported Program
exported = torch.export.export(M(), input_args)
# Exported Program to StableHLO
stablehlo_gm = stablehlo.exported_program_to_stablehlo(exported)
stablehlo = stablehlo_gm.get_stablehlo_text()
print(stablehlo)
主 StableHLO 图如下所示
module @IrToHlo.56 attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
func.func @main(%arg0: tensor<10x8x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>, %arg3: tensor<128x128xf32>) -> tensor<10x8x128xf32> {
...
%10 = stablehlo.composite "test.sdpa" %3, %6, %9 {composite_attributes = {other_attr = "val", scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl} : (tensor<10x8x128xf32>, tensor<10x8x128xf32>, tensor<10x8x128xf32>) -> tensor<10x8x128xf32>
%11 = stablehlo.add %10, %arg0 : tensor<10x8x128xf32>
return %11 : tensor<10x8x128xf32>
}
func.func private @test.sdpa.impl(%arg0: tensor<10x8x128xf32>, %arg1: tensor<10x8x128xf32>, %arg2: tensor<10x8x128xf32>) -> tensor<10x8x128xf32> {
// Actual implementation of the composite
...
return %11 : tensor<10x8x128xf32>
}
sdpa 操作被封装为主图中的 stablehlo composite 调用。torch.nn.Module 中指定的名称和属性会得到传播。
%10 = stablehlo.composite "test.sdpa" %3, %6, %9 {composite_attributes = {other_attr = "val", scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl}
sdpa 操作的参考 PyTorch 分解被捕获在 StableHLO 函数中
func.func private @test.sdpa.impl(%arg0: tensor<10x8x128xf32>, %arg1: tensor<10x8x128xf32>, %arg2: tensor<10x8x128xf32>) -> tensor<10x8x128xf32> {
// Actual implementation of the composite
...
return %11 : tensor<10x8x128xf32>
}