• 教程 >
  • (原型) 在 BERT 上进行图模式动态量化
快捷方式

(原型) 在 BERT 上进行图模式动态量化

创建日期:2020 年 7 月 28 日 | 最后更新:2024 年 1 月 16 日 | 最后验证:2024 年 11 月 5 日

作者Supriya Rao

引言

本教程介绍了如何使用图模式量化进行训练后动态量化的步骤。动态量化将浮点模型转换为量化模型,其中权重使用静态 int8 数据类型,激活使用动态量化。激活会动态(按批次)量化为 int8,而权重会静态量化为 int8。图模式量化流程在模型图上运行,只需最少的用户干预即可量化模型。要使用图模式,浮点模型需要先进行 tracing 或 scripting。

图模式量化的优点包括

  • 在图模式中,我们可以检查在 forward 函数中执行的代码(例如 aten 函数调用),并通过模块和图操作实现量化。

  • 量化流程简单,人工步骤最少。

  • 解锁了进行更高级优化(如自动精度选择)的可能性。

有关图模式量化的更多详细信息,请参阅图模式静态量化教程

tl;dr 图模式动态量化 API

import torch
from torch.quantization import per_channel_dynamic_qconfig
from torch.quantization import quantize_dynamic_jit

ts_model = torch.jit.script(float_model) # or torch.jit.trace(float_model, input)

quantized = quantize_dynamic_jit(ts_model, {'': per_channel_dynamic_qconfig})

1. 量化 BERT 模型

安装步骤和模型详情与 Eager Mode 教程中的步骤相同。请参阅此处的教程了解更多详情。

1.1 设置

下载并安装所有必要的包后,我们进行代码设置。首先导入必要的库并为模型进行设置。

import logging
import numpy as np
import os
import random
import sys
import time
import torch

from argparse import Namespace
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
from tqdm import tqdm
from transformers import (BertConfig, BertForSequenceClassification, BertTokenizer,)
from transformers import glue_compute_metrics as compute_metrics
from transformers import glue_output_modes as output_modes
from transformers import glue_processors as processors
from transformers import glue_convert_examples_to_features as convert_examples_to_features
from torch.quantization import per_channel_dynamic_qconfig
from torch.quantization import quantize_dynamic_jit

def ids_tensor(shape, vocab_size):
    #  Creates a random int32 tensor of the shape within the vocab size
    return torch.randint(0, vocab_size, shape=shape, dtype=torch.int, device='cpu')

# Setup logging
logger = logging.getLogger(__name__)
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt = '%m/%d/%Y %H:%M:%S',
                    level = logging.WARN)

logging.getLogger("transformers.modeling_utils").setLevel(
   logging.WARN)  # Reduce logging

print(torch.__version__)

torch.set_num_threads(1)
print(torch.__config__.parallel_info())

1.2 下载 GLUE 数据集

在运行 MRPC 任务之前,我们通过运行此脚本下载 GLUE 数据,并将其解压到 glue_data 目录中。

python download_glue_data.py --data_dir='glue_data' --tasks='MRPC'

1.3 设置全局 BERT 配置

要运行此实验,我们首先需要一个微调过的 BERT 模型。我们提供了用于 MRPC 任务的微调 BERT 模型,请点击此处下载。为了节省时间,你可以将模型文件(约 400 MB)直接下载到你的本地文件夹 $OUT_DIR 中。

configs = Namespace()

# The output directory for the fine-tuned model, $OUT_DIR.
configs.output_dir = "./MRPC/"

# The data directory for the MRPC task in the GLUE benchmark, $GLUE_DIR/$TASK_NAME.
configs.data_dir = "./glue_data/MRPC"

# The model name or path for the pre-trained model.
configs.model_name_or_path = "bert-base-uncased"
# The maximum length of an input sequence
configs.max_seq_length = 128

# Prepare GLUE task.
configs.task_name = "MRPC".lower()
configs.processor = processors[configs.task_name]()
configs.output_mode = output_modes[configs.task_name]
configs.label_list = configs.processor.get_labels()
configs.model_type = "bert".lower()
configs.do_lower_case = True

# Set the device, batch size, topology, and caching flags.
configs.device = "cpu"
configs.per_gpu_eval_batch_size = 8
configs.n_gpu = 0
configs.local_rank = -1
configs.overwrite_cache = False

# Set random seed for reproducibility.
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
set_seed(42)

tokenizer = BertTokenizer.from_pretrained(
    configs.output_dir, do_lower_case=configs.do_lower_case)

model = BertForSequenceClassification.from_pretrained(configs.output_dir, torchscript=True)
model.to(configs.device)

1.4 使用图模式量化 BERT 模型

1.4.1 对模型进行 Scripting/Tracing

图模式量化的输入是 TorchScript 模型,因此你需要先对模型进行 scripting 或 tracing。目前,scripting BERT 模型尚不支持,所以我们在这里对模型进行 tracing。

我们首先确定要传递给模型的输入。在这里,我们使用评估过程中将传递的最大可能的输入大小来对模型进行 tracing。根据下面评估步骤中传入的输入大小,我们选择批次大小为 8,序列长度为 128。在 tracing 时使用推理期间的最大可能形状是 huggingface BERT 模型的一个限制,正如此处所述。

我们使用 torch.jit.trace 对模型进行 tracing。

input_ids = ids_tensor([8, 128], 2)
token_type_ids = ids_tensor([8, 128], 2)
attention_mask = ids_tensor([8, 128], vocab_size=2)
dummy_input = (input_ids, attention_mask, token_type_ids)
traced_model = torch.jit.trace(model, dummy_input)

1.4.2 指定 qconfig_dict

qconfig_dict = {'': per_channel_dynamic_qconfig}

qconfig 是一个包含激活和权重观察器的命名元组 (named tuple)。对于动态量化,我们使用一个虚拟的激活观察器来模拟运行时在算子中发生的动态量化过程。对于权重张量,我们建议使用逐通道量化 (per-channel quantization),这有助于提高最终精度。qconfig_dict 是一个字典,其键是子模块的名称,值是该模块的 qconfig,空键表示 qconfig 将应用于整个模型,除非被更具体的配置覆盖,每个模块的 qconfig 要么在字典中找到,要么回退到父模块的 qconfig。

目前 qconfig_dict 是配置模型如何量化的唯一方式,它是以模块粒度进行的,也就是说,我们只支持为每个模块配置一种 qconfig,并且子模块的 qconfig 将覆盖父模块的 qconfig。例如,如果我们有

qconfig = {
    '' : qconfig_global,
    'sub' : qconfig_sub,
    'sub.fc1' : qconfig_fc,
    'sub.fc2': None
}

模块 sub.fc1 将使用 qconfig_fc 进行配置,sub 中的所有其他子模块将使用 qconfig_sub 进行配置,而 sub.fc2 将不会被量化。模型中的所有其他模块将使用 qconfig_global 进行量化。

qconfig_dict = {'': per_channel_dynamic_qconfig}

1.4.3 量化模型(一行 API)

我们调用一行 API(类似于 eager mode)来执行量化,如下所示。

quantized_model = quantize_dynamic_jit(traced_model, qconfig_dict)

2. 评估

我们重用 Huggingface 中的分词和评估函数。

def evaluate(args, model, tokenizer, prefix=""):
    # Loop to handle MNLI double evaluation (matched, mis-matched)
    eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
    eval_outputs_dirs = (args.output_dir, args.output_dir + '-MM') if args.task_name == "mnli" else (args.output_dir,)

    results = {}
    for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
        eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)

        if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(eval_output_dir)

        args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
        # Note that DistributedSampler samples randomly
        eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
        eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

        # multi-gpu eval
        if args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Eval!
        logger.info("***** Running evaluation {} *****".format(prefix))
        logger.info("  Num examples = %d", len(eval_dataset))
        logger.info("  Batch size = %d", args.eval_batch_size)
        nb_eval_steps = 0
        preds = None
        out_label_ids = None
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            model.eval()
            batch = tuple(t.to(args.device) for t in batch)

            with torch.no_grad():
                inputs = {'input_ids':      batch[0],
                          'attention_mask': batch[1]}
                labels = batch[3]
                if args.model_type != 'distilbert':
                    inputs['input'] = batch[2] if args.model_type in ['bert', 'xlnet'] else None  # XLM, DistilBERT and RoBERTa don't use segment_ids
                outputs = model(**inputs)
                logits = outputs[0]
            nb_eval_steps += 1
            if preds is None:
                preds = logits.detach().cpu().numpy()
                out_label_ids = labels.detach().cpu().numpy()
            else:
                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
                out_label_ids = np.append(out_label_ids, labels.detach().cpu().numpy(), axis=0)

        if args.output_mode == "classification":
            preds = np.argmax(preds, axis=1)
        elif args.output_mode == "regression":
            preds = np.squeeze(preds)
        result = compute_metrics(eval_task, preds, out_label_ids)
        results.update(result)

        output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results {} *****".format(prefix))
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

    return results

def load_and_cache_examples(args, task, tokenizer, evaluate=False):
    if args.local_rank not in [-1, 0] and not evaluate:
        torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache

    processor = processors[task]()
    output_mode = output_modes[task]
    # Load data features from cache or dataset file
    cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}'.format(
        'dev' if evaluate else 'train',
        list(filter(None, args.model_name_or_path.split('/'))).pop(),
        str(args.max_seq_length),
        str(task)))
    if os.path.exists(cached_features_file) and not args.overwrite_cache:
        logger.info("Loading features from cached file %s", cached_features_file)
        features = torch.load(cached_features_file)
    else:
        logger.info("Creating features from dataset file at %s", args.data_dir)
        label_list = processor.get_labels()
        if task in ['mnli', 'mnli-mm'] and args.model_type in ['roberta']:
            # HACK(label indices are swapped in RoBERTa pretrained model)
            label_list[1], label_list[2] = label_list[2], label_list[1]
        examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
        features = convert_examples_to_features(examples,
                                                tokenizer,
                                                label_list=label_list,
                                                max_length=args.max_seq_length,
                                                output_mode=output_mode,)
        if args.local_rank in [-1, 0]:
            logger.info("Saving features into cached file %s", cached_features_file)
            torch.save(features, cached_features_file)

    if args.local_rank == 0 and not evaluate:
        torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache

    # Convert to Tensors and build dataset
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
    all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
    if output_mode == "classification":
        all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
    elif output_mode == "regression":
        all_labels = torch.tensor([f.label for f in features], dtype=torch.float)

    dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
    return dataset

def time_model_evaluation(model, configs, tokenizer):
    eval_start_time = time.time()
    result = evaluate(configs, model, tokenizer, prefix="")
    eval_end_time = time.time()
    eval_duration_time = eval_end_time - eval_start_time
    print(result)
    print("Evaluate total time (seconds): {0:.1f}".format(eval_duration_time))

2.1 检查模型大小

我们打印模型大小,以体现量化带来的收益。

def print_size_of_model(model):
    if isinstance(model, torch.jit.RecursiveScriptModule):
        torch.jit.save(model, "temp.p")
    else:
        torch.jit.save(torch.jit.script(model), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

print("Size of model before quantization")
print_size_of_model(traced_model)
print("Size of model after quantization")

print_size_of_model(quantized_model)
Size of model before quantization
Size (MB): 438.242141
Size of model after quantization
Size (MB): 184.354759

2.2 运行评估

我们评估 FP32 模型和量化模型,并比较它们的 F1 分数。请注意,下面的性能数据是在开发机器上获得的,在生产服务器上可能会有所提升。

time_model_evaluation(traced_model, configs, tokenizer)
time_model_evaluation(quantized_model, configs, tokenizer)
FP32 model results -
'f1': 0.901
Time taken - 188.0s

INT8 model results -
'f1': 0.902
Time taken - 157.4s

3. 调试量化模型

我们可以通过传入 debug 选项来调试量化模型。

quantized_model = quantize_dynamic_jit(traced_model, qconfig_dict, debug=True)

如果将 debug 设置为 True

  • 我们可以像在 torchscript 模型中一样访问量化模型的属性,例如 model.fc1.weight(如果你使用模块列表或顺序模型,可能会更困难)。

  • 所有算术运算都在浮点数中进行,其数值与最终的量化模型相同,从而允许进行调试。

quantized_model_debug = quantize_dynamic_jit(traced_model, qconfig_dict, debug=True)

调用 quantize_dynamic_jit 等效于先调用 prepare_dynamic_jit 再调用 convert_dynamic_jit。推荐使用一行 API。但如果你希望在每个步骤后调试或分析模型,则可以使用多行 API。

3.1. 评估调试模型

# Evaluate the debug model
time_model_evaluation(quantized_model_debug, configs, tokenizer)
Size (MB): 438.406429

INT8 (debug=True) model results -
'f1': 0.897

请注意,调试版本的精度接近但与非调试版本不完全相同,因为调试版本使用浮点运算来模拟量化运算,且数值匹配是近似的。这种情况仅出现在逐通道量化中(我们正在努力改进)。逐张量量化(使用 default_dynamic_qconfig)的数值与非调试版本完全匹配。

print(str(quantized_model_debug.graph))

打印的图片段 -

%111 : Tensor = prim::GetAttr[name="bias"](%110)
%112 : Tensor = prim::GetAttr[name="weight"](%110)
%113 : Float(768:1) = prim::GetAttr[name="4_scale_0"](%110)
%114 : Int(768:1) = prim::GetAttr[name="4_zero_point_0"](%110)
%115 : int = prim::GetAttr[name="4_axis_0"](%110)
%116 : int = prim::GetAttr[name="4_scalar_type_0"](%110)
%4.quant.6 : Tensor = aten::quantize_per_channel(%112, %113, %114, %115, %116)
%4.dequant.6 : Tensor = aten::dequantize(%4.quant.6)
%1640 : bool = prim::Constant[value=1]()
%input.5.scale.1 : float, %input.5.zero_point.1 : int = aten::_choose_qparams_per_tensor(%input.5, %1640)
%input.5.quant.1 : Tensor = aten::quantize_per_tensor(%input.5, %input.5.scale.1, %input.5.zero_point.1, %74)
%input.5.dequant.1 : Float(8:98304, 128:768, 768:1) = aten::dequantize(%input.5.quant.1)
%119 : Tensor = aten::linear(%input.5.dequant.1, %4.dequant.6, %111)

我们可以看到模型中没有 quantized::linear_dynamic,而是数值上等效的模式:aten::_choose_qparams_per_tensor - aten::quantize_per_tensor - aten::dequantize - aten::linear

# Get the size of the debug model
print_size_of_model(quantized_model_debug)
Size (MB): 438.406429

调试模型的大小与浮点模型接近,因为所有权重都是浮点数,尚未量化和冻结,这使得人们可以检查权重。你可以直接在 torchscript 模型中访问权重属性。在调试模型中访问权重与在 TorchScript 模型中访问权重相同。

print(quantized_model.bert.encoder.layer._c.getattr('0').attention.self.query.weight)
tensor([[-0.0157,  0.0257, -0.0269,  ...,  0.0158,  0.0764,  0.0548],
        [-0.0325,  0.0345, -0.0423,  ..., -0.0528,  0.1382,  0.0069],
        [ 0.0106,  0.0335,  0.0113,  ..., -0.0275,  0.0253, -0.0457],
        ...,
        [-0.0090,  0.0512,  0.0555,  ...,  0.0277,  0.0543, -0.0539],
        [-0.0195,  0.0943,  0.0619,  ..., -0.1040,  0.0598,  0.0465],
        [ 0.0009, -0.0949,  0.0097,  ..., -0.0183, -0.0511, -0.0085]],
        grad_fn=<CloneBackward>)

访问对应权重的 scale 和 zero_point 可以按如下方式进行 -

print(quantized_model.bert.encoder.layer._c.getattr('0').attention.self.query.getattr('4_scale_0'))
print(quantized_model.bert.encoder.layer._c.getattr('0').attention.self.query.getattr('4_zero_point_0'))

由于我们使用逐通道量化,因此得到逐通道的 scales 张量。

tensor([0.0009, 0.0011, 0.0010, 0.0011, 0.0034, 0.0013, 0.0010, 0.0010, 0.0013,
        0.0012, 0.0011, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0009, 0.0015,
        0.0016, 0.0036, 0.0012, 0.0009, 0.0010, 0.0014, 0.0008, 0.0008, 0.0008,
        ...,
        0.0019, 0.0023, 0.0013, 0.0018, 0.0012, 0.0031, 0.0015, 0.0013, 0.0014,
        0.0022, 0.0011, 0.0024])

零点 (Zero-point) 张量 -

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        ..,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       dtype=torch.int32)

4. 与 Eager Mode 结果比较

以下结果显示了通过遵循教程中提到的步骤,对同一模型进行 Eager Mode 量化获得的 F1 分数和模型大小。结果表明,对模型进行 Eager Mode 和图模式量化产生了相同的结果。

FP32 model results -
Size (MB): 438.016605
'f1': 0.901

INT8 model results -
Size (MB): 182.878029
'f1': 0.902

5. 对模型进行基准测试

我们使用虚拟输入对模型进行基准测试,并在生产服务器机器上比较浮点模型与 Eager Mode 和图模式量化模型的性能。

def benchmark(model):
    model = torch.jit.load(model)
    model.eval()
    torch.set_num_threads(1)
    input_ids = ids_tensor([8, 128], 2)
    token_type_ids = ids_tensor([8, 128], 2)
    attention_mask = ids_tensor([8, 128], vocab_size=2)
    elapsed = 0
    for _i in range(50):
        start = time.time()
        output = model(input_ids, token_type_ids, attention_mask)
        end = time.time()
        elapsed = elapsed + (end - start)
    print('Elapsed time: ', (elapsed / 50), ' s')
    return
print("Running benchmark for Float model")
benchmark(args.jit_model_path_float)
print("Running benchmark for Eager Mode Quantized model")
benchmark(args.jit_model_path_eager)
print("Running benchmark for Graph Mode Quantized model")
benchmark(args.jit_model_path_graph)
Running benchmark for Float model
Elapsed time: 4.49 s
Running benchmark for Eager Mode Quantized model
Elapsed time: 2.67 s
Running benchmark for Graph Mode Quantized model
Elapsed time: 2.69 s
As we can see both graph mode and eager mode quantized model have a similar speed up over the floating point model.

结论

在本教程中,我们演示了如何使用图模式将 BERT 等著名的最先进 NLP 模型转换为动态量化模型,其性能与 eager mode 相同。动态量化可以减小模型大小,同时对精度影响有限。

感谢阅读!一如既往,我们欢迎任何反馈意见,如果你有任何问题,请随时在此处创建 issue。

文档

查阅 PyTorch 的完整开发者文档

查看文档

教程

获取适合初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源