(原型) PyTorch 2 导出量化感知训练 (QAT)¶
**作者**:Andrew Or
本教程展示了如何在基于 torch.export.export 的图模式下执行量化感知训练 (QAT)。有关 PyTorch 2 导出量化的更多详细信息,请参阅 训练后量化教程。
PyTorch 2 导出 QAT 流程如下所示——在大多数情况下,它类似于训练后量化 (PTQ) 流程
import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import (
prepare_qat_pt2e,
convert_pt2e,
)
from torch.ao.quantization.quantizer import (
XNNPACKQuantizer,
get_symmetric_quantization_config,
)
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(5, 10)
def forward(self, x):
return self.linear(x)
example_inputs = (torch.randn(1, 5),)
m = M()
# Step 1. program capture
# NOTE: this API will be updated to torch.export API in the future, but the captured
# result shoud mostly stay the same
m = capture_pre_autograd_graph(m, *example_inputs)
# we get a model with aten ops
# Step 2. quantization-aware training
# backend developer will write their own Quantizer and expose methods to allow
# users to express how they want the model to be quantized
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
m = prepare_qat_pt2e(m, quantizer)
# train omitted
m = convert_pt2e(m)
# we have a model with aten ops doing integer computations when possible
# move the quantized model to eval mode, equivalent to `m.eval()`
torch.ao.quantization.move_exported_model_to_eval(m)
请注意,在程序捕获后调用 model.eval()
或 model.train()
是不允许的,因为这些方法不再正确地更改某些操作(如 dropout 和批归一化)的行为。相反,请分别使用 torch.ao.quantization.move_exported_model_to_eval()
和 torch.ao.quantization.move_exported_model_to_train()
(即将推出)。
定义辅助函数并准备数据集¶
要使用本教程中的代码运行整个 ImageNet 数据集,请首先按照 ImageNet 数据 中的说明下载 ImageNet。将下载的文件解压缩到 data_path
文件夹中。
接下来,下载 torchvision resnet18 模型 并将其重命名为 data/resnet18_pretrained_float.pth
。
我们将从进行必要的导入、定义一些辅助函数和准备数据开始。这些步骤与 静态 Eager 模式训练后量化教程 中定义的步骤非常相似。
import os
import sys
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets
from torchvision.models.resnet import resnet18
import torchvision.transforms as transforms
# Set up warnings
import warnings
warnings.filterwarnings(
action='ignore',
category=DeprecationWarning,
module=r'.*'
)
warnings.filterwarnings(
action='default',
module=r'torch.ao.quantization'
)
# Specify random seed for repeatable results
_ = torch.manual_seed(191009)
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self, name, fmt=':f'):
self.name = name
self.fmt = fmt
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def __str__(self):
fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
return fmtstr.format(**self.__dict__)
def accuracy(output, target, topk=(1,)):
"""
Computes the accuracy over the k top predictions for the specified
values of k.
"""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
def evaluate(model, criterion, data_loader, device):
torch.ao.quantization.move_exported_model_to_eval(model)
top1 = AverageMeter('Acc@1', ':6.2f')
top5 = AverageMeter('Acc@5', ':6.2f')
cnt = 0
with torch.no_grad():
for image, target in data_loader:
image = image.to(device)
target = target.to(device)
output = model(image)
loss = criterion(output, target)
cnt += 1
acc1, acc5 = accuracy(output, target, topk=(1, 5))
top1.update(acc1[0], image.size(0))
top5.update(acc5[0], image.size(0))
print('')
return top1, top5
def load_model(model_file):
model = resnet18(pretrained=False)
state_dict = torch.load(model_file, weights_only=True)
model.load_state_dict(state_dict)
return model
def print_size_of_model(model):
if isinstance(model, torch.jit.RecursiveScriptModule):
torch.jit.save(model, "temp.p")
else:
torch.jit.save(torch.jit.script(model), "temp.p")
print("Size (MB):", os.path.getsize("temp.p")/1e6)
os.remove("temp.p")
def prepare_data_loaders(data_path):
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
dataset = torchvision.datasets.ImageNet(
data_path, split="train", transform=transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]))
dataset_test = torchvision.datasets.ImageNet(
data_path, split="val", transform=transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
]))
train_sampler = torch.utils.data.RandomSampler(dataset)
test_sampler = torch.utils.data.SequentialSampler(dataset_test)
data_loader = torch.utils.data.DataLoader(
dataset, batch_size=train_batch_size,
sampler=train_sampler)
data_loader_test = torch.utils.data.DataLoader(
dataset_test, batch_size=eval_batch_size,
sampler=test_sampler)
return data_loader, data_loader_test
def train_one_epoch(model, criterion, optimizer, data_loader, device, ntrain_batches):
# Note: do not call model.train() here, since this doesn't work on an exported model.
# Instead, call `torch.ao.quantization.move_exported_model_to_train(model)`, which will
# be added in the near future
top1 = AverageMeter('Acc@1', ':6.2f')
top5 = AverageMeter('Acc@5', ':6.2f')
avgloss = AverageMeter('Loss', '1.5f')
cnt = 0
for image, target in data_loader:
start_time = time.time()
print('.', end = '')
cnt += 1
image, target = image.to(device), target.to(device)
output = model(image)
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
acc1, acc5 = accuracy(output, target, topk=(1, 5))
top1.update(acc1[0], image.size(0))
top5.update(acc5[0], image.size(0))
avgloss.update(loss, image.size(0))
if cnt >= ntrain_batches:
print('Loss', avgloss.avg)
print('Training: * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
.format(top1=top1, top5=top5))
return
print('Full imagenet train set: * Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f}'
.format(top1=top1, top5=top5))
return
data_path = '~/.data/imagenet'
saved_model_dir = 'data/'
float_model_file = 'resnet18_pretrained_float.pth'
train_batch_size = 32
eval_batch_size = 32
data_loader, data_loader_test = prepare_data_loaders(data_path)
example_inputs = (next(iter(data_loader))[0])
criterion = nn.CrossEntropyLoss()
float_model = load_model(saved_model_dir + float_model_file).to("cuda")
使用 torch.export 导出模型¶
以下是如何使用 torch.export
导出模型的方法
from torch._export import capture_pre_autograd_graph
example_inputs = (torch.rand(2, 3, 224, 224),)
exported_model = capture_pre_autograd_graph(float_model, example_inputs)
# or, to capture with dynamic dimensions:
from torch._export import dynamic_dim
example_inputs = (torch.rand(2, 3, 224, 224),)
exported_model = capture_pre_autograd_graph(
float_model,
example_inputs,
constraints=[dynamic_dim(example_inputs[0], 0)],
)
注意
capture_pre_autograd_graph
是一个短期 API,当官方的 torch.export
API 就绪时,它将更新为使用该 API。
导入后端特定的量化器并配置如何量化模型¶
以下代码片段描述了如何量化模型
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
XNNPACKQuantizer,
get_symmetric_quantization_config,
)
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_qat=True))
Quantizer
是后端特定的,每个 Quantizer
将提供自己的方式来允许用户配置他们的模型。
注意
查看我们的 教程,其中描述了如何编写新的 Quantizer
。
准备模型进行量化感知训练¶
prepare_qat_pt2e
在模型的适当位置插入伪量化,并执行适当的 QAT“融合”,例如 Conv2d
+ BatchNorm2d
,以获得更好的训练精度。融合后的操作在准备好的图中表示为 ATen 操作的子图。
prepared_model = prepare_qat_pt2e(exported_model, quantizer)
print(prepared_model)
注意
如果您的模型包含批归一化,则您在图中获得的实际 ATen 操作取决于您导出模型时模型的设备。如果模型在 CPU 上,则您将获得 torch.ops.aten._native_batch_norm_legit
。如果模型在 CUDA 上,则您将获得 torch.ops.aten.cudnn_batch_norm
。但是,这不是根本性的,将来可能会发生变化。
在这两个操作之间,已经证明 torch.ops.aten.cudnn_batch_norm
在 MobileNetV2 等模型上提供了更好的数值。要获得此操作,请在导出前调用 model.cuda()
,或在准备后运行以下操作以手动交换操作
for n in prepared_model.graph.nodes:
if n.target == torch.ops.aten._native_batch_norm_legit.default:
n.target = torch.ops.aten.cudnn_batch_norm.default
prepared_model.recompile()
将来,我们计划整合批归一化操作,以便不再需要上述操作。
训练循环¶
训练循环类似于先前版本 QAT 中的循环。为了获得更好的精度,您可以选择在一定数量的 epochs 后禁用观察器和更新批归一化统计信息,或者每 N
个 epochs 评估 QAT 或迄今为止训练的量化模型。
num_epochs = 10
num_train_batches = 20
num_eval_batches = 20
num_observer_update_epochs = 4
num_batch_norm_update_epochs = 3
num_epochs_between_evals = 2
# QAT takes time and one needs to train over a few epochs.
# Train and check accuracy after each epoch
for nepoch in range(num_epochs):
train_one_epoch(prepared_model, criterion, optimizer, data_loader, "cuda", num_train_batches)
# Optionally disable observer/batchnorm stats after certain number of epochs
if epoch >= num_observer_update_epochs:
print("Disabling observer for subseq epochs, epoch = ", epoch)
prepared_model.apply(torch.ao.quantization.disable_observer)
if epoch >= num_batch_norm_update_epochs:
print("Freezing BN for subseq epochs, epoch = ", epoch)
for n in prepared_model.graph.nodes:
# Args: input, weight, bias, running_mean, running_var, training, momentum, eps
# We set the `training` flag to False here to freeze BN stats
if n.target in [
torch.ops.aten._native_batch_norm_legit.default,
torch.ops.aten.cudnn_batch_norm.default,
]:
new_args = list(n.args)
new_args[5] = False
n.args = new_args
prepared_model.recompile()
# Check the quantized accuracy every N epochs
# Note: If you wish to just evaluate the QAT model (not the quantized model),
# then you can just call `torch.ao.quantization.move_exported_model_to_eval/train`.
# However, the latter API is not ready yet and will be available in the near future.
if (nepoch + 1) % num_epochs_between_evals == 0:
prepared_model_copy = copy.deepcopy(prepared_model)
quantized_model = convert_pt2e(prepared_model_copy)
top1, top5 = evaluate(quantized_model, criterion, data_loader_test, neval_batches=num_eval_batches)
print('Epoch %d: Evaluation accuracy on %d images, %2.2f' % (nepoch, num_eval_batches * eval_batch_size, top1.avg))
保存和加载模型检查点¶
PyTorch 2 导出 QAT 流程的模型检查点与任何其他训练流程中的模型检查点相同。它们可用于暂停训练并在以后恢复训练、从失败的训练运行中恢复以及在以后的某个时间在不同的机器上执行推理。您可以在训练期间或训练后保存模型检查点,如下所示
checkpoint_path = "/path/to/my/checkpoint_%s.pth" % nepoch
torch.save(prepared_model.state_dict(), "checkpoint_path")
要加载检查点,您必须以最初导出和准备模型的完全相同的方式导出和准备模型。例如
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
XNNPACKQuantizer,
get_symmetric_quantization_config,
)
from torchvision.models.resnet import resnet18
example_inputs = (torch.rand(2, 3, 224, 224),)
float_model = resnet18(pretrained=False)
exported_model = capture_pre_autograd_graph(float_model, example_inputs)
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_qat=True))
prepared_model = prepare_qat_pt2e(exported_model, quantizer)
prepared_model.load_state_dict(torch.load(checkpoint_path))
# resume training or perform inference
将训练后的模型转换为量化模型¶
convert_pt2e
获取一个校准后的模型并生成一个量化模型。请注意,在推理之前,您必须首先调用 torch.ao.quantization.move_exported_model_to_eval()
以确保某些操作(如 dropout)在评估图中正常运行。否则,例如,我们将继续在推理期间的前向传递中错误地应用 dropout。
quantized_model = convert_pt2e(prepared_model)
# move certain ops like dropout to eval mode, equivalent to `m.eval()`
torch.ao.quantization.move_exported_model_to_eval(m)
print(quantized_model)
top1, top5 = evaluate(quantized_model, criterion, data_loader_test, neval_batches=num_eval_batches)
print('Final evaluation accuracy on %d images, %2.2f' % (num_eval_batches * eval_batch_size, top1.avg))