Optimizing Vision Transformer Model for Deployment
Created On: Mar 15, 2021 | Last Updated: Jan 19, 2024 | Last Verified: Nov 05, 2024
Vision Transformer models apply cutting-edge attention-based transformer models, introduced in Natural Language Processing to achieve all kinds of state-of-the-art (SOTA) results, to Computer Vision tasks. Facebook's Data-efficient Image Transformers, DeiT, is a Vision Transformer model trained on ImageNet for image classification.
In this tutorial, we will first cover what DeiT is and how to use it, then go through the complete steps of scripting, quantizing, optimizing, and using the model in iOS and Android apps. We will also compare the performance of the quantized, optimized model with the non-quantized, non-optimized model, and show the benefits of applying quantization and optimization at each step.
What is DeiT
Convolutional Neural Networks (CNNs) have been the main models for image classification since deep learning took off in 2012, but CNNs typically require hundreds of millions of images for training to achieve SOTA results. DeiT is a Vision Transformer model that requires much less data and computing resources for training to compete with the leading CNNs in image classification, which is made possible by two key components of DeiT:
Data augmentation that simulates training on a much larger dataset;
Native distillation that allows the transformer network to learn from a CNN's output.
DeiT shows that Transformers can be successfully applied to computer vision tasks, even with limited access to data and resources. For more details on DeiT, see the repository and paper.
Classifying Images with DeiT
Follow the README.md in the DeiT repository for detailed information on how to classify images using DeiT, or for a quick test, first install the required packages:
pip install torch torchvision timm pandas requests
To run in Google Colab, install the dependencies by running the following command:
!pip install timm pandas requests
then run the script below:
from PIL import Image
import torch
import timm
import requests
import torchvision.transforms as transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

print(torch.__version__)
# should be 1.8.0 or above

# load the pretrained DeiT model from Torch Hub
model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()

# standard ImageNet preprocessing: bicubic resize, center-crop to 224x224, normalize
transform = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])

# download a test image and add a batch dimension
img = Image.open(requests.get("https://raw.githubusercontent.com/pytorch/ios-demo-app/master/HelloWorld/HelloWorld/HelloWorld/image.png", stream=True).raw)
img = transform(img)[None,]
out = model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
2.6.0+cu124
Downloading: "https://github.com/facebookresearch/deit/zipball/main" to /var/lib/ci-user/.cache/torch/hub/main.zip
/usr/local/lib/python3.10/dist-packages/timm/models/registry.py:4: FutureWarning:
Importing from timm.models.registry is deprecated, please import via timm.models
/usr/local/lib/python3.10/dist-packages/timm/models/layers/__init__.py:48: FutureWarning:
Importing from timm.models.layers is deprecated, please import via timm.layers
/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:63: UserWarning:
Overwriting deit_tiny_patch16_224 in registry with models.deit_tiny_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
(the same UserWarning is repeated for the other deit_* model variants in the registry)
Downloading: "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth" to /var/lib/ci-user/.cache/torch/hub/checkpoints/deit_base_patch16_224-b5f2ef4d.pth
100%|##########| 330M/330M [00:01<00:00, 197MB/s]
269
The output should be 269, which, according to the ImageNet list of class index to labels file, maps to timber wolf, grey wolf, gray wolf, Canis lupus.
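To double-check the mapping yourself, you can look the index up in a labels file. Here is a minimal sketch, assuming the plain-text ImageNet class list hosted in the pytorch/hub repository is available at the URL below:
import requests

# assumed URL: the 1000 ImageNet class names, one per line
LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
labels = requests.get(LABELS_URL).text.strip().split("\n")

# index 269 should print the (timber) wolf class
print(labels[clsidx.item()])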
Now that we have verified that we can use the DeiT model to classify images, let's see how to modify the model so it can run on iOS and Android apps.
Scripting DeiT
To use the model on mobile, we first need to script it. See the Script and Optimize recipe for a quick overview. Run the code below to convert the DeiT model used in the previous step to the TorchScript format that can run on mobile.
# reload the pretrained model and script it with TorchScript
model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()
scripted_model = torch.jit.script(model)
scripted_model.save("fbdeit_scripted.pt")
Using cache found in /var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main
This generates the scripted model file fbdeit_scripted.pt, about 346MB in size.
Quantizing DeiT
To reduce the trained model size significantly while keeping the inference accuracy about the same, quantization can be applied to the model. Thanks to the transformer model used in DeiT, we can easily apply dynamic quantization to it, because dynamic quantization works best for LSTM and transformer models (see here for more details).
Now run the code below:
# Use 'x86' for server inference (the old 'fbgemm' is still available but
# 'x86' is the recommended default) and 'qnnpack' for mobile inference.
backend = "x86"  # replacing with 'qnnpack' causes much worse inference speed for the quantized model on this notebook
model.qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend

# dynamically quantize all Linear layers to int8, then script the result
quantized_model = torch.quantization.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
scripted_quantized_model = torch.jit.script(quantized_model)
scripted_quantized_model.save("fbdeit_scripted_quantized.pt")
/usr/local/lib/python3.10/dist-packages/torch/ao/quantization/observer.py:229: UserWarning:
Please use quant_min and quant_max to specify the range for observers. reduce_range will be deprecated in a future release of PyTorch.
This generates the scripted and quantized version of the model, fbdeit_scripted_quantized.pt, with a size of about 89MB, a 74% reduction from the non-quantized model size of 346MB!
You can use scripted_quantized_model to generate the same inference result:
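For example, re-running the classification step from earlier with the quantized model:
out = scripted_quantized_model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
# the same output 269 should be printed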
269
Optimizing DeiT
The final step before using the quantized and scripted model on mobile is to optimize it:
from torch.utils.mobile_optimizer import optimize_for_mobile

# apply mobile-specific graph optimizations to the scripted, quantized model
optimized_scripted_quantized_model = optimize_for_mobile(scripted_quantized_model)
optimized_scripted_quantized_model.save("fbdeit_optimized_scripted_quantized.pt")
The generated fbdeit_optimized_scripted_quantized.pt file is about the same size as the quantized, scripted, but non-optimized model. The inference result remains the same:
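For example:
out = optimized_scripted_quantized_model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
# again, the same output 269 should be printed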
269
Using Lite Interpreter
To see how much model size reduction and inference speedup the Lite Interpreter can result in, let's create the lite version of the model.
# save for the Lite Interpreter, then load the .ptl file back for profiling
optimized_scripted_quantized_model._save_for_lite_interpreter("fbdeit_optimized_scripted_quantized_lite.ptl")
ptl = torch.jit.load("fbdeit_optimized_scripted_quantized_lite.ptl")
Although the lite model size is comparable to the non-lite version, the inference speedup is expected when running the lite version on mobile. You can also compare the on-disk sizes of the saved models, as shown in the sketch below.
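A minimal sketch for checking the file sizes locally, assuming the four model files saved in the previous steps are in the current working directory:
import os

# compare the on-disk sizes of the models saved in the previous steps
for name in ["fbdeit_scripted.pt",
             "fbdeit_scripted_quantized.pt",
             "fbdeit_optimized_scripted_quantized.pt",
             "fbdeit_optimized_scripted_quantized_lite.ptl"]:
    print("{}: {:.1f}MB".format(name, os.path.getsize(name) / 1e6))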
Comparing Inference Speed
To see how the inference speed differs among the five models - the original model, the scripted model, the scripted and quantized model, the optimized scripted-and-quantized model, and the lite model - run the code below:
with torch.autograd.profiler.profile(use_cuda=False) as prof1:
    out = model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof2:
    out = scripted_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof3:
    out = scripted_quantized_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof4:
    out = optimized_scripted_quantized_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof5:
    out = ptl(img)

print("original model: {:.2f}ms".format(prof1.self_cpu_time_total/1000))
print("scripted model: {:.2f}ms".format(prof2.self_cpu_time_total/1000))
print("scripted & quantized model: {:.2f}ms".format(prof3.self_cpu_time_total/1000))
print("scripted & quantized & optimized model: {:.2f}ms".format(prof4.self_cpu_time_total/1000))
print("lite model: {:.2f}ms".format(prof5.self_cpu_time_total/1000))
original model: 96.01ms
scripted model: 105.84ms
scripted & quantized model: 124.12ms
scripted & quantized & optimized model: 135.20ms
lite model: 147.87ms
The results running on Google Colab are:
original model: 1236.69ms
scripted model: 1226.72ms
scripted & quantized model: 593.19ms
scripted & quantized & optimized model: 598.01ms
lite model: 600.72ms
The following results summarize the inference time taken by each model and the percentage reduction of each model relative to the original model.
import pandas as pd
import numpy as np

df = pd.DataFrame({'Model': ['original model','scripted model', 'scripted & quantized model', 'scripted & quantized & optimized model', 'lite model']})
df = pd.concat([df, pd.DataFrame([
    ["{:.2f}ms".format(prof1.self_cpu_time_total/1000), "0%"],
    ["{:.2f}ms".format(prof2.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof2.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
    ["{:.2f}ms".format(prof3.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof3.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
    ["{:.2f}ms".format(prof4.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof4.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
    ["{:.2f}ms".format(prof5.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof5.self_cpu_time_total)/prof1.self_cpu_time_total*100)]],
    columns=['Inference Time', 'Reduction'])], axis=1)

print(df)

"""
   Model                                   Inference Time  Reduction
0  original model                          1236.69ms       0%
1  scripted model                          1226.72ms       0.81%
2  scripted & quantized model              593.19ms        52.03%
3  scripted & quantized & optimized model  598.01ms        51.64%
4  lite model                              600.72ms        51.43%
"""
Model ... Reduction
0 original model ... 0%
1 scripted model ... -10.25%
2 scripted & quantized model ... -29.28%
3 scripted & quantized & optimized model ... -40.82%
4 lite model ... -54.02%
[5 rows x 3 columns]