torch.export 流程、常见挑战及解决方案演示¶
作者: Ankith Gunapal, Jordi Ramon, Marcos Carranza
在`torch.export` 入门教程中,我们学习了如何使用torch.export。本教程在前一个教程的基础上进行扩展,探讨了使用代码导出常用模型的过程,并解决了使用torch.export
时可能出现的一些常见挑战。
在本教程中,您将学习如何导出适用于以下用例的模型:
视频分类器(MViT)
自动语音识别(OpenAI Whisper-Tiny)
图像字幕生成(BLIP)
可提示式图像分割(SAM2)
选择这四种模型是为了演示 torch.export 的独特功能,以及实现中遇到的一些实际考量和问题。
先决条件¶
PyTorch 2.4 或更高版本
对
torch.export
和 PyTorch 即时推理有基本了解。
`torch.export` 的关键要求:无图断裂¶
torch.compile 通过使用 JIT 将 PyTorch 代码编译成优化的内核来加速 PyTorch 代码。它使用 TorchDynamo
优化给定的模型,并创建优化的图,然后使用 API 中指定的后端将其转换为硬件特定的代码(lowered)。当 TorchDynamo 遇到不支持的 Python 特性时,它会中断计算图,让默认的 Python 解释器处理不支持的代码,然后恢复图捕获。计算图中的这种中断称为图断裂。
`torch.export` 和 torch.compile
的关键区别之一在于 torch.export
不支持图断裂,这意味着您要导出的整个模型或部分模型需要是一个单一的图。这是因为处理图断裂涉及使用默认 Python 评估来解释不受支持的操作,这与 torch.export
的设计目的不兼容。您可以在此链接中阅读有关不同 PyTorch 框架之间差异的详细信息
您可以使用以下命令来识别程序中的图断裂:
TORCH_LOGS="graph_breaks" python <file_name>.py
您需要修改程序以消除图断裂。解决后,您就可以导出模型了。PyTorch 每晚都会在流行的 HuggingFace 和 TIMM 模型上运行 torch.compile 的基准测试。这些模型中的大多数都没有图断裂。
本代码示例集中的模型没有图断裂,但在使用 torch.export 时失败。
视频分类¶
MViT 是一类基于 MultiScale Vision Transformers 的模型。该模型已使用 Kinetics-400 数据集训练用于视频分类。该模型结合相关数据集可用于游戏场景中的动作识别。
下面的代码通过使用 batch_size=2
进行跟踪来导出 MViT,然后检查导出的程序 (ExportedProgram
) 是否可以使用 batch_size=4
运行。
import numpy as np
import torch
from torchvision.models.video import MViT_V1_B_Weights, mvit_v1_b
import traceback as tb
model = mvit_v1_b(weights=MViT_V1_B_Weights.DEFAULT)
# Create a batch of 2 videos, each with 16 frames of shape 224x224x3.
input_frames = torch.randn(2,16, 224, 224, 3)
# Transpose to get [1, 3, num_clips, height, width].
input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))
# Export the model.
exported_program = torch.export.export(
model,
(input_frames,),
)
# Create a batch of 4 videos, each with 16 frames of shape 224x224x3.
input_frames = torch.randn(4,16, 224, 224, 3)
input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))
try:
exported_program.module()(input_frames)
except Exception:
tb.print_exc()
错误:静态批量大小¶
raise RuntimeError(
RuntimeError: Expected input at *args[0].shape[0] to be equal to 2, but got 4
默认情况下,导出流程会在跟踪程序时假定所有输入形状都是静态的,因此如果您使用与跟踪时不同的输入形状运行程序,将会遇到错误。
解决方案¶
为解决该错误,我们将输入的第一维(batch_size
)指定为动态的,并指定 batch_size
的预期范围。在下面显示的更正后的示例中,我们指定预期的 batch_size
范围为 1 到 16。需要注意的一个细节是 min=2
并不是一个 bug,并在The 0/1 Specialization Problem中进行了解释。关于 torch.export
动态形状的详细说明可以在导出教程中找到。下面显示的代码演示了如何导出具有动态批量大小的 mViT:
import numpy as np
import torch
from torchvision.models.video import MViT_V1_B_Weights, mvit_v1_b
import traceback as tb
model = mvit_v1_b(weights=MViT_V1_B_Weights.DEFAULT)
# Create a batch of 2 videos, each with 16 frames of shape 224x224x3.
input_frames = torch.randn(2,16, 224, 224, 3)
# Transpose to get [1, 3, num_clips, height, width].
input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))
# Export the model.
batch_dim = torch.export.Dim("batch", min=2, max=16)
exported_program = torch.export.export(
model,
(input_frames,),
# Specify the first dimension of the input x as dynamic
dynamic_shapes={"x": {0: batch_dim}},
)
# Create a batch of 4 videos, each with 16 frames of shape 224x224x3.
input_frames = torch.randn(4,16, 224, 224, 3)
input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))
try:
exported_program.module()(input_frames)
except Exception:
tb.print_exc()
自动语音识别¶
自动语音识别 (ASR) 是指使用机器学习将口语转录成文本。Whisper 是 OpenAI 基于 Transformer 的编码器-解码器模型,该模型使用 68 万小时的标注数据进行 ASR 和语音翻译的训练。下面的代码尝试导出用于 ASR 的 whisper-tiny
模型。
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset
# load model
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
# dummy inputs for exporting the model
input_features = torch.randn(1,80, 3000)
attention_mask = torch.ones(1, 3000)
decoder_input_ids = torch.tensor([[1, 1, 1 , 1]]) * model.config.decoder_start_token_id
model.eval()
exported_program: torch.export.ExportedProgram= torch.export.export(model, args=(input_features, attention_mask, decoder_input_ids,))
错误:使用 TorchDynamo 进行严格跟踪¶
torch._dynamo.exc.InternalTorchDynamoError: AttributeError: 'DynamicCache' object has no attribute 'key_cache'
默认情况下,torch.export
使用TorchDynamo(一种字节码分析引擎)跟踪您的代码,它对您的代码进行符号分析并构建图。这种分析提供了更强的安全性保证,但并非所有 Python 代码都受支持。当我们使用默认的严格模式导出 whisper-tiny
模型时,通常会在 Dynamo 中因不支持的特性而返回错误。要了解为何在 Dynamo 中出现此错误,您可以参考此GitHub issue。
解决方案¶
为解决上述错误,torch.export
支持 non_strict
模式,在该模式下,程序使用 Python 解释器进行跟踪,其工作方式类似于 PyTorch 即时执行。唯一的区别是所有 Tensor
对象都将被 ProxyTensors
替换,ProxyTensors
会将它们的所有操作记录到一个图中。通过使用 strict=False
,我们能够导出程序。
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset
# load model
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
# dummy inputs for exporting the model
input_features = torch.randn(1,80, 3000)
attention_mask = torch.ones(1, 3000)
decoder_input_ids = torch.tensor([[1, 1, 1 , 1]]) * model.config.decoder_start_token_id
model.eval()
exported_program: torch.export.ExportedProgram= torch.export.export(model, args=(input_features, attention_mask, decoder_input_ids,), strict=False)
图像字幕生成¶
图像字幕生成 是将图像内容用文字描述的任务。在游戏场景中,图像字幕生成可用于通过动态生成场景中各种游戏对象的文本描述来增强游戏体验,从而为玩家提供额外细节。BLIP 是 SalesForce Research 发布的一个流行的图像字幕生成模型。下面的代码尝试使用 batch_size=1
导出 BLIP。
import torch
from models.blip import blip_decoder
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
image_size = 384
image = torch.randn(1, 3,384,384).to(device)
caption_input = ""
model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'
model = blip_decoder(pretrained=model_url, image_size=image_size, vit='base')
model.eval()
model = model.to(device)
exported_program: torch.export.ExportedProgram= torch.export.export(model, args=(image,caption_input,), strict=False)
错误:无法修改具有冻结存储的张量¶
在导出模型时,可能会失败,因为模型实现可能包含 torch.export
尚未支持的某些 Python 操作。其中一些失败可能有变通方法。BLIP 就是一个例子,原始模型会出错,可以通过对代码进行少量修改来解决。torch.export
在 ExportDB 中列出了支持和不支持操作的常见情况,并展示了如何修改代码以使其兼容导出。
File "/BLIP/models/blip.py", line 112, in forward
text.input_ids[:,0] = self.tokenizer.bos_token_id
File "/anaconda3/envs/export/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py", line 545, in __torch_dispatch__
outs_unwrapped = func._op_dk(
RuntimeError: cannot mutate tensors with frozen storage
可提示式图像分割¶
图像分割是一种计算机视觉技术,根据像素的特征将其分为不同的像素组或区域(segments)。Segment Anything Model (SAM)) 引入了可提示式图像分割,它根据指示所需对象的提示来预测对象掩码。SAM 2 是第一个用于跨图像和视频分割对象的统一模型。SAM2ImagePredictor 类为模型提供了一个简单的接口,用于对模型进行提示。该模型可以同时接受点和框提示作为输入,也可以接受前一次预测产生的掩码作为输入。由于 SAM2 在目标跟踪方面提供了强大的零样本性能,因此可用于跟踪场景中的游戏对象。
在 SAM2ImagePredictor 的 predict 方法中的张量操作在 _predict 方法中进行。因此,我们尝试像这样导出:
ep = torch.export.export(
self._predict,
args=(unnorm_coords, labels, unnorm_box, mask_input, multimask_output),
kwargs={"return_logits": return_logits},
strict=False,
)
错误:模型不是 torch.nn.Module
类型¶
torch.export
要求模块类型为 torch.nn.Module
。然而,我们尝试导出的模块是一个类方法。因此会出错。
Traceback (most recent call last):
File "/sam2/image_predict.py", line 20, in <module>
masks, scores, _ = predictor.predict(
File "/sam2/sam2/sam2_image_predictor.py", line 312, in predict
ep = torch.export.export(
File "python3.10/site-packages/torch/export/__init__.py", line 359, in export
raise ValueError(
ValueError: Expected `mod` to be an instance of `torch.nn.Module`, got <class 'method'>.
解决方案¶
我们编写一个辅助类,该类继承自 torch.nn.Module
,并在该类的 forward
方法中调用 _predict 方法
。完整的代码可以在这里找到。
class ExportHelper(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(_, *args, **kwargs):
return self._predict(*args, **kwargs)
model_to_export = ExportHelper()
ep = torch.export.export(
model_to_export,
args=(unnorm_coords, labels, unnorm_box, mask_input, multimask_output),
kwargs={"return_logits": return_logits},
strict=False,
)
结论¶
在本教程中,我们学习了如何通过正确的配置和简单的代码修改来解决挑战,从而使用 torch.export
导出适用于流行用例的模型。成功导出模型后,对于服务器,您可以使用 AOTInductor 将 ExportedProgram
转换为硬件可执行代码;对于边缘设备,则可以使用 ExecuTorch。要了解有关 AOTInductor
(AOTI) 的更多信息,请参考 AOTI 教程。要了解有关 ExecuTorch
的更多信息,请参考 ExecuTorch 教程。