注意
点击这里下载完整示例代码
优化视觉Transformer模型以进行部署¶
视觉Transformer模型将自然语言处理中引入的尖端基于注意力的Transformer模型应用于计算机视觉任务,以实现各种最先进(SOTA)的结果。Facebook数据高效图像Transformer DeiT是一个在ImageNet上训练用于图像分类的视觉Transformer模型。
在本教程中,我们将首先介绍DeiT是什么以及如何使用它,然后介绍脚本、量化、优化以及在iOS和Android应用程序中使用模型的完整步骤。我们还将比较量化、优化和非量化、非优化模型的性能,并展示在模型中应用量化和优化的优势。
什么是DeiT¶
自2012年深度学习兴起以来,卷积神经网络(CNN)一直是图像分类的主要模型,但CNN通常需要数亿张图像才能在训练中达到最先进的结果。DeiT是一种视觉Transformer模型,它在训练中需要更少的数据和计算资源,从而与领先的CNN在图像分类方面的性能相媲美,这是DeiT的两个关键组成部分所实现的。
模拟在更大数据集上训练的数据增强;
原生蒸馏,使Transformer网络能够从CNN的输出中学习。
DeiT表明,Transformer可以在数据和资源有限的情况下成功应用于计算机视觉任务。有关DeiT的更多详细信息,请参阅repo和paper.
使用DeiT对图像进行分类¶
请遵循DeiT存储库中的README.md
,以获取有关如何使用DeiT对图像进行分类的详细信息,或者进行快速测试,首先安装所需的软件包
pip install torch torchvision timm pandas requests
要在Google Colab中运行,请通过运行以下命令安装依赖项
!pip install timm pandas requests
然后运行下面的脚本
from PIL import Image
import torch
import timm
import requests
import torchvision.transforms as transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
print(torch.__version__)
# should be 1.8.0
model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()
transform = transforms.Compose([
transforms.Resize(256, interpolation=3),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])
img = Image.open(requests.get("https://raw.githubusercontent.com/pytorch/ios-demo-app/master/HelloWorld/HelloWorld/HelloWorld/image.png", stream=True).raw)
img = transform(img)[None,]
out = model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
2.5.0+cu124
Downloading: "https://github.com/facebookresearch/deit/zipball/main" to /var/lib/ci-user/.cache/torch/hub/main.zip
/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:63: UserWarning:
Overwriting deit_tiny_patch16_224 in registry with models.deit_tiny_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:78: UserWarning:
Overwriting deit_small_patch16_224 in registry with models.deit_small_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:93: UserWarning:
Overwriting deit_base_patch16_224 in registry with models.deit_base_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:108: UserWarning:
Overwriting deit_tiny_distilled_patch16_224 in registry with models.deit_tiny_distilled_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:123: UserWarning:
Overwriting deit_small_distilled_patch16_224 in registry with models.deit_small_distilled_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:138: UserWarning:
Overwriting deit_base_distilled_patch16_224 in registry with models.deit_base_distilled_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:153: UserWarning:
Overwriting deit_base_patch16_384 in registry with models.deit_base_patch16_384. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:168: UserWarning:
Overwriting deit_base_distilled_patch16_384 in registry with models.deit_base_distilled_patch16_384. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
Downloading: "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth" to /var/lib/ci-user/.cache/torch/hub/checkpoints/deit_base_patch16_224-b5f2ef4d.pth
0%| | 0.00/330M [00:00<?, ?B/s]
5%|4 | 15.5M/330M [00:00<00:02, 162MB/s]
9%|9 | 31.4M/330M [00:00<00:01, 164MB/s]
14%|#4 | 47.5M/330M [00:00<00:01, 166MB/s]
19%|#9 | 63.4M/330M [00:00<00:01, 166MB/s]
24%|##4 | 80.0M/330M [00:00<00:01, 169MB/s]
29%|##9 | 97.1M/330M [00:00<00:01, 172MB/s]
35%|###4 | 114M/330M [00:00<00:01, 175MB/s]
40%|###9 | 132M/330M [00:00<00:01, 176MB/s]
45%|####5 | 149M/330M [00:00<00:01, 178MB/s]
50%|##### | 166M/330M [00:01<00:00, 178MB/s]
56%|#####5 | 183M/330M [00:01<00:00, 179MB/s]
61%|###### | 201M/330M [00:01<00:00, 179MB/s]
66%|######5 | 218M/330M [00:01<00:00, 180MB/s]
71%|#######1 | 235M/330M [00:01<00:00, 180MB/s]
76%|#######6 | 253M/330M [00:01<00:00, 181MB/s]
82%|########1 | 271M/330M [00:01<00:00, 183MB/s]
87%|########7 | 289M/330M [00:01<00:00, 185MB/s]
93%|#########2| 307M/330M [00:01<00:00, 186MB/s]
98%|#########8| 325M/330M [00:01<00:00, 187MB/s]
100%|##########| 330M/330M [00:01<00:00, 179MB/s]
269
输出应该为269,根据ImageNet的类索引到标签文件,映射到timber wolf, grey wolf, gray wolf, Canis lupus
。
现在我们已经验证可以使用DeiT模型对图像进行分类,让我们看看如何修改该模型,使其能够在iOS和Android应用程序上运行。
编写DeiT脚本¶
要在移动设备上使用该模型,我们首先需要编写该模型的脚本。请参阅脚本和优化食谱,以了解快速概述。运行以下代码将上一步骤中使用的DeiT模型转换为可在移动设备上运行的TorchScript格式。
model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()
scripted_model = torch.jit.script(model)
scripted_model.save("fbdeit_scripted.pt")
Using cache found in /var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main
生成了大约346MB大小的脚本化模型文件fbdeit_scripted.pt
。
量化DeiT¶
为了在保持推理精度大致相同的情况下显着减小训练后的模型大小,可以对模型应用量化。由于DeiT中使用的Transformer模型,我们可以轻松地对模型应用动态量化,因为动态量化最适合LSTM和Transformer模型(有关更多详细信息,请参阅这里)。
现在运行下面的代码
# Use 'x86' for server inference (the old 'fbgemm' is still available but 'x86' is the recommended default) and ``qnnpack`` for mobile inference.
backend = "x86" # replaced with ``qnnpack`` causing much worse inference speed for quantized model on this notebook
model.qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend
quantized_model = torch.quantization.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
scripted_quantized_model = torch.jit.script(quantized_model)
scripted_quantized_model.save("fbdeit_scripted_quantized.pt")
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/ao/quantization/observer.py:229: UserWarning:
Please use quant_min and quant_max to specify the range for observers. reduce_range will be deprecated in a future release of PyTorch.
这将生成模型的脚本化和量化版本fbdeit_quantized_scripted.pt
,大小约为89MB,比未量化模型的346MB大小减少了74%!
您可以使用scripted_quantized_model
来生成相同的推理结果
269
优化DeiT¶
在移动设备上使用量化和脚本化的模型之前,最后的步骤是对它进行优化
from torch.utils.mobile_optimizer import optimize_for_mobile
optimized_scripted_quantized_model = optimize_for_mobile(scripted_quantized_model)
optimized_scripted_quantized_model.save("fbdeit_optimized_scripted_quantized.pt")
生成的fbdeit_optimized_scripted_quantized.pt
文件的大小与量化、脚本化但未优化的模型的大小大致相同。推理结果保持不变。
269
使用Lite Interpreter¶
为了了解Lite Interpreter可以带来的模型大小减少和推理速度提升,让我们创建模型的Lite版本。
optimized_scripted_quantized_model._save_for_lite_interpreter("fbdeit_optimized_scripted_quantized_lite.ptl")
ptl = torch.jit.load("fbdeit_optimized_scripted_quantized_lite.ptl")
虽然Lite模型的大小与非Lite版本相当,但在移动设备上运行Lite版本时,预计推理速度会有所提升。
比较推理速度¶
为了了解四种模型的推理速度差异 - 原始模型、脚本化模型、量化和脚本化模型、优化、量化和脚本化模型 - 请运行以下代码
with torch.autograd.profiler.profile(use_cuda=False) as prof1:
out = model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof2:
out = scripted_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof3:
out = scripted_quantized_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof4:
out = optimized_scripted_quantized_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof5:
out = ptl(img)
print("original model: {:.2f}ms".format(prof1.self_cpu_time_total/1000))
print("scripted model: {:.2f}ms".format(prof2.self_cpu_time_total/1000))
print("scripted & quantized model: {:.2f}ms".format(prof3.self_cpu_time_total/1000))
print("scripted & quantized & optimized model: {:.2f}ms".format(prof4.self_cpu_time_total/1000))
print("lite model: {:.2f}ms".format(prof5.self_cpu_time_total/1000))
original model: 165.32ms
scripted model: 137.03ms
scripted & quantized model: 126.08ms
scripted & quantized & optimized model: 137.81ms
lite model: 149.98ms
在Google Colab上运行的结果是
original model: 1236.69ms
scripted model: 1226.72ms
scripted & quantized model: 593.19ms
scripted & quantized & optimized model: 598.01ms
lite model: 600.72ms
以下结果总结了每个模型所花费的推理时间以及每个模型相对于原始模型的百分比减少。
import pandas as pd
import numpy as np
df = pd.DataFrame({'Model': ['original model','scripted model', 'scripted & quantized model', 'scripted & quantized & optimized model', 'lite model']})
df = pd.concat([df, pd.DataFrame([
["{:.2f}ms".format(prof1.self_cpu_time_total/1000), "0%"],
["{:.2f}ms".format(prof2.self_cpu_time_total/1000),
"{:.2f}%".format((prof1.self_cpu_time_total-prof2.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
["{:.2f}ms".format(prof3.self_cpu_time_total/1000),
"{:.2f}%".format((prof1.self_cpu_time_total-prof3.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
["{:.2f}ms".format(prof4.self_cpu_time_total/1000),
"{:.2f}%".format((prof1.self_cpu_time_total-prof4.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
["{:.2f}ms".format(prof5.self_cpu_time_total/1000),
"{:.2f}%".format((prof1.self_cpu_time_total-prof5.self_cpu_time_total)/prof1.self_cpu_time_total*100)]],
columns=['Inference Time', 'Reduction'])], axis=1)
print(df)
"""
Model Inference Time Reduction
0 original model 1236.69ms 0%
1 scripted model 1226.72ms 0.81%
2 scripted & quantized model 593.19ms 52.03%
3 scripted & quantized & optimized model 598.01ms 51.64%
4 lite model 600.72ms 51.43%
"""
Model ... Reduction
0 original model ... 0%
1 scripted model ... 17.11%
2 scripted & quantized model ... 23.73%
3 scripted & quantized & optimized model ... 16.64%
4 lite model ... 9.28%
[5 rows x 3 columns]
'\n Model Inference Time Reduction\n0\toriginal model 1236.69ms 0%\n1\tscripted model 1226.72ms 0.81%\n2\tscripted & quantized model 593.19ms 52.03%\n3\tscripted & quantized & optimized model 598.01ms 51.64%\n4\tlite model 600.72ms 51.43%\n'