
Optimizing Vision Transformer Models for Deployment

Jeff Tang, Geeta Chauhan

Vision Transformer models apply cutting-edge attention-based transformer models, introduced in Natural Language Processing, to Computer Vision tasks in order to achieve various state-of-the-art (SOTA) results. Facebook's Data-efficient Image Transformers (DeiT) is a Vision Transformer model trained on ImageNet for image classification.

In this tutorial, we will first cover what DeiT is and how to use it, then go through the complete steps of scripting, quantizing, optimizing, and using the model in iOS and Android apps. We will also compare the performance of the quantized, optimized models against the non-quantized, non-optimized models, and show the benefits of applying quantization and optimization along the way.

What is DeiT

Convolutional Neural Networks (CNNs) have been the main models for image classification since deep learning took off in 2012, but CNNs typically require hundreds of millions of images for training to achieve SOTA results. DeiT is a Vision Transformer model that requires much less data and computing resources for training to compete with the leading CNNs in image classification, which is made possible by two key components of DeiT:

  • Data augmentation that simulates training on a much larger dataset;

  • Native distillation that allows the transformer network to learn from a CNN's output.

DeiT shows that Transformers can be successfully applied to computer vision tasks even with limited access to data and resources. For more details on DeiT, see the repo and the paper.
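
The distilled variants of DeiT are published as torch.hub entry points alongside the regular models (the entry point name below also appears in the registry output captured later in this tutorial). As a minimal, optional sketch, a distilled model can be loaded the same way as the regular one:

# Optional sketch: load a distilled DeiT variant from torch.hub.
# The entry point name comes from the DeiT repository's model registry.
import torch

distilled_model = torch.hub.load('facebookresearch/deit:main',
                                 'deit_base_distilled_patch16_224', pretrained=True)
distilled_model.eval()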

Classifying Images with DeiT

Follow the README.md in the DeiT repository for detailed information on how to classify images using DeiT, or, for a quick test, first install the required packages:

pip install torch torchvision timm pandas requests

To run in Google Colab, install the dependencies by running the following command:

!pip install timm pandas requests

Then run the script below:

from PIL import Image
import torch
import timm
import requests
import torchvision.transforms as transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

print(torch.__version__)
# should be 1.8.0 or above


model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()

transform = transforms.Compose([
    transforms.Resize(256, interpolation=3),  # 3 corresponds to bicubic interpolation
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])

img = Image.open(requests.get("https://raw.githubusercontent.com/pytorch/ios-demo-app/master/HelloWorld/HelloWorld/HelloWorld/image.png", stream=True).raw)
img = transform(img)[None,]
out = model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
2.5.0+cu124
Downloading: "https://github.com/facebookresearch/deit/zipball/main" to /var/lib/ci-user/.cache/torch/hub/main.zip
/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:63: UserWarning:

Overwriting deit_tiny_patch16_224 in registry with models.deit_tiny_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:78: UserWarning:

Overwriting deit_small_patch16_224 in registry with models.deit_small_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:93: UserWarning:

Overwriting deit_base_patch16_224 in registry with models.deit_base_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:108: UserWarning:

Overwriting deit_tiny_distilled_patch16_224 in registry with models.deit_tiny_distilled_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:123: UserWarning:

Overwriting deit_small_distilled_patch16_224 in registry with models.deit_small_distilled_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:138: UserWarning:

Overwriting deit_base_distilled_patch16_224 in registry with models.deit_base_distilled_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:153: UserWarning:

Overwriting deit_base_patch16_384 in registry with models.deit_base_patch16_384. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:168: UserWarning:

Overwriting deit_base_distilled_patch16_384 in registry with models.deit_base_distilled_patch16_384. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

Downloading: "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth" to /var/lib/ci-user/.cache/torch/hub/checkpoints/deit_base_patch16_224-b5f2ef4d.pth

  0%|          | 0.00/330M [00:00<?, ?B/s]
  5%|4         | 15.5M/330M [00:00<00:02, 162MB/s]
  9%|9         | 31.4M/330M [00:00<00:01, 164MB/s]
 14%|#4        | 47.5M/330M [00:00<00:01, 166MB/s]
 19%|#9        | 63.4M/330M [00:00<00:01, 166MB/s]
 24%|##4       | 80.0M/330M [00:00<00:01, 169MB/s]
 29%|##9       | 97.1M/330M [00:00<00:01, 172MB/s]
 35%|###4      | 114M/330M [00:00<00:01, 175MB/s]
 40%|###9      | 132M/330M [00:00<00:01, 176MB/s]
 45%|####5     | 149M/330M [00:00<00:01, 178MB/s]
 50%|#####     | 166M/330M [00:01<00:00, 178MB/s]
 56%|#####5    | 183M/330M [00:01<00:00, 179MB/s]
 61%|######    | 201M/330M [00:01<00:00, 179MB/s]
 66%|######5   | 218M/330M [00:01<00:00, 180MB/s]
 71%|#######1  | 235M/330M [00:01<00:00, 180MB/s]
 76%|#######6  | 253M/330M [00:01<00:00, 181MB/s]
 82%|########1 | 271M/330M [00:01<00:00, 183MB/s]
 87%|########7 | 289M/330M [00:01<00:00, 185MB/s]
 93%|#########2| 307M/330M [00:01<00:00, 186MB/s]
 98%|#########8| 325M/330M [00:01<00:00, 187MB/s]
100%|##########| 330M/330M [00:01<00:00, 179MB/s]
269

The output should be 269, which, according to the ImageNet class index to labels file, maps to timber wolf, grey wolf, gray wolf, Canis lupus.
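
If you also want a human-readable label instead of the raw index, you can look the index up in an ImageNet class-index file. The sketch below assumes a commonly mirrored copy of imagenet_class_index.json; the URL is an assumption, and any copy of the ImageNet class-index mapping will work:

# Optional: map the predicted class index to a label.
# NOTE: the URL below is an assumption (a commonly mirrored copy of the
# ImageNet class-index JSON); substitute your own labels file if needed.
import json
import requests

LABELS_URL = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
class_idx = json.loads(requests.get(LABELS_URL).text)
print(class_idx[str(clsidx.item())])  # for index 269 this should show a wolf label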

Now that we have verified that we can use the DeiT model to classify images, let's see how to modify the model so it can run on iOS and Android apps.

Scripting DeiT

To use the model on mobile, we first need to script it. See the Script and Optimize recipe for a quick overview. Run the code below to convert the DeiT model used in the previous step to the TorchScript format that can run on mobile.

model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()
scripted_model = torch.jit.script(model)
scripted_model.save("fbdeit_scripted.pt")
Using cache found in /var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main

The scripted model file fbdeit_scripted.pt, about 346MB in size, is generated.

Quantizing DeiT

To reduce the trained model size significantly while keeping the inference accuracy about the same, quantization can be applied to the model. Thanks to the transformer architecture used in DeiT, we can easily apply dynamic quantization to the model, because dynamic quantization works best for LSTM and transformer models (see here for more details).

Now run the code below:

# Use 'x86' for server inference (the old 'fbgemm' is still available but 'x86' is the recommended default) and ``qnnpack`` for mobile inference.
backend = "x86" # replaced with ``qnnpack`` causing much worse inference speed for quantized model on this notebook
model.qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend

quantized_model = torch.quantization.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
scripted_quantized_model = torch.jit.script(quantized_model)
scripted_quantized_model.save("fbdeit_scripted_quantized.pt")
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/ao/quantization/observer.py:229: UserWarning:

Please use quant_min and quant_max to specify the range for observers. reduce_range will be deprecated in a future release of PyTorch.

This generates the scripted and quantized version of the model, fbdeit_scripted_quantized.pt, which is about 89MB, a 74% reduction from the 346MB size of the non-quantized model!
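
A quick way to confirm the size reduction on your own run is to check the saved files on disk (a minimal sketch, assuming both .pt files were saved to the current working directory as in the code above):

# Optional: compare the on-disk sizes of the scripted and the scripted & quantized models.
import os

for f in ["fbdeit_scripted.pt", "fbdeit_scripted_quantized.pt"]:
    print("{}: {:.1f} MB".format(f, os.path.getsize(f) / 1e6))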

You can use scripted_quantized_model to generate the same inference result:

out = scripted_quantized_model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
# The same output 269 should be printed
269

Optimizing DeiT

The final step before using the quantized and scripted model on mobile is to optimize it:

from torch.utils.mobile_optimizer import optimize_for_mobile
optimized_scripted_quantized_model = optimize_for_mobile(scripted_quantized_model)
optimized_scripted_quantized_model.save("fbdeit_optimized_scripted_quantized.pt")

The generated fbdeit_optimized_scripted_quantized.pt file is about the same size as the quantized, scripted, but non-optimized model. The inference result remains the same.

out = optimized_scripted_quantized_model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
# Again, the same output 269 should be printed
269

Using the Lite Interpreter

To see how much model size reduction and inference speedup the Lite Interpreter can bring, let's create the lite version of the model.

optimized_scripted_quantized_model._save_for_lite_interpreter("fbdeit_optimized_scripted_quantized_lite.ptl")
ptl = torch.jit.load("fbdeit_optimized_scripted_quantized_lite.ptl")

Although the lite model size is comparable to the non-lite version, an inference speedup is expected when running the lite version on mobile.
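
As a quick sanity check (not part of the original steps), you can confirm that the lite model loaded above still produces the same prediction before moving on to the speed comparison:

# Optional sanity check: the lite model should predict the same class index.
out = ptl(img)
print(torch.argmax(out).item())  # should print 269 again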

Comparing Inference Speed

To see how the inference speed differs across the five models - the original model, the scripted model, the quantized-and-scripted model, the optimized-quantized-and-scripted model, and the lite model - run the code below:

with torch.autograd.profiler.profile(use_cuda=False) as prof1:
    out = model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof2:
    out = scripted_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof3:
    out = scripted_quantized_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof4:
    out = optimized_scripted_quantized_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof5:
    out = ptl(img)

print("original model: {:.2f}ms".format(prof1.self_cpu_time_total/1000))
print("scripted model: {:.2f}ms".format(prof2.self_cpu_time_total/1000))
print("scripted & quantized model: {:.2f}ms".format(prof3.self_cpu_time_total/1000))
print("scripted & quantized & optimized model: {:.2f}ms".format(prof4.self_cpu_time_total/1000))
print("lite model: {:.2f}ms".format(prof5.self_cpu_time_total/1000))
original model: 165.32ms
scripted model: 137.03ms
scripted & quantized model: 126.08ms
scripted & quantized & optimized model: 137.81ms
lite model: 149.98ms

The results of running on Google Colab are:

original model: 1236.69ms
scripted model: 1226.72ms
scripted & quantized model: 593.19ms
scripted & quantized & optimized model: 598.01ms
lite model: 600.72ms

The following results summarize the inference time taken by each model and the percentage reduction of each model relative to the original model.

import pandas as pd
import numpy as np

df = pd.DataFrame({'Model': ['original model','scripted model', 'scripted & quantized model', 'scripted & quantized & optimized model', 'lite model']})
df = pd.concat([df, pd.DataFrame([
    ["{:.2f}ms".format(prof1.self_cpu_time_total/1000), "0%"],
    ["{:.2f}ms".format(prof2.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof2.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
    ["{:.2f}ms".format(prof3.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof3.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
    ["{:.2f}ms".format(prof4.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof4.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
    ["{:.2f}ms".format(prof5.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof5.self_cpu_time_total)/prof1.self_cpu_time_total*100)]],
    columns=['Inference Time', 'Reduction'])], axis=1)

print(df)

"""
        Model                             Inference Time    Reduction
0   original model                             1236.69ms           0%
1   scripted model                             1226.72ms        0.81%
2   scripted & quantized model                  593.19ms       52.03%
3   scripted & quantized & optimized model      598.01ms       51.64%
4   lite model                                  600.72ms       51.43%
"""
                                    Model  ... Reduction
0                          original model  ...        0%
1                          scripted model  ...    17.11%
2              scripted & quantized model  ...    23.73%
3  scripted & quantized & optimized model  ...    16.64%
4                              lite model  ...     9.28%

[5 rows x 3 columns]

