• 教程 >
  • (原型) 使用半结构化 (2:4) 稀疏性加速 BERT
快捷方式

(原型) 使用半结构化 (2:4) 稀疏性加速 BERT

创建于:2023 年 10 月 03 日 | 最后更新:2024 年 1 月 16 日 | 最后验证:2024 年 11 月 05 日

作者: Jesse Cai

与其他形式的稀疏性一样,半结构化稀疏性 是一种模型优化技术,旨在降低神经网络的内存开销和延迟,但会牺牲一定的模型精度。它也被称为 细粒度结构化稀疏性2:4 结构化稀疏性

半结构化稀疏性得名于其独特的稀疏模式,其中每 2n 个元素中有 n 个被剪枝。我们最常见的情况是 n=2,因此是 2:4 稀疏性。半结构化稀疏性特别有趣,因为它可以在 GPU 上高效加速,并且不像其他稀疏模式那样严重降低模型精度。

随着 半结构化稀疏性支持 的引入,无需离开 PyTorch 即可剪枝和加速半结构化稀疏模型。我们将在本教程中解释此过程。

../_static/img/pruning_flow.jpg

在本教程结束时,我们将把 BERT 问答模型稀疏化为 2:4 稀疏,对其进行微调以恢复几乎所有的 F1 损失(密集型为 86.92,稀疏型为 86.48)。最后,我们将加速此 2:4 稀疏模型以进行推理,从而实现 1.3 倍的加速。

要求

  • PyTorch >= 2.1。

  • 支持半结构化稀疏性的 NVIDIA GPU(计算能力 8.0+)。

注意

本教程专为半结构化稀疏性/一般稀疏性初学者设计。对于已有 2:4 稀疏模型的用户,使用 to_sparse_semi_structured 加速 nn.Linear 层以进行推理非常简单,只需

import torch
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from torch.utils.benchmark import Timer
SparseSemiStructuredTensor._FORCE_CUTLASS = True

# mask Linear weight to be 2:4 sparse
mask = torch.Tensor([0, 0, 1, 1]).tile((3072, 2560)).cuda().bool()
linear = torch.nn.Linear(10240, 3072).half().cuda().eval()
linear.weight = torch.nn.Parameter(mask * linear.weight)

x = torch.rand(3072, 10240).half().cuda()

with torch.inference_mode():
    dense_output = linear(x)
    dense_t = Timer(stmt="linear(x)",
                    globals={"linear": linear,
                             "x": x}).blocked_autorange().median * 1e3

    # accelerate via SparseSemiStructuredTensor
    linear.weight = torch.nn.Parameter(to_sparse_semi_structured(linear.weight))

    sparse_output = linear(x)
    sparse_t = Timer(stmt="linear(x)",
                    globals={"linear": linear,
                             "x": x}).blocked_autorange().median * 1e3

    # sparse and dense matmul are numerically equivalent
    assert torch.allclose(sparse_output, dense_output, atol=1e-3)
    print(f"Dense: {dense_t:.3f}ms Sparse: {sparse_t:.3f}ms | Speedup: {(dense_t / sparse_t):.3f}x")

在 A100 80GB 上,我们看到:密集型:0.870 毫秒 稀疏型:0.630 毫秒 | 加速:1.382 倍

半结构化稀疏性解决了什么问题?

稀疏性背后的总体动机很简单:如果您的网络中有零,您可以避免存储/使用这些参数进行计算。但是,稀疏性的具体细节很棘手。从表面上看,将参数置零并不会影响我们模型的延迟/内存开销。

这是因为密集张量仍然包含剪枝(零)元素,密集矩阵乘法内核仍将对这些元素进行运算。为了实现性能提升,我们需要将密集内核换成稀疏内核,后者会跳过涉及剪枝元素的计算。

为此,这些内核在稀疏矩阵上工作,稀疏矩阵不存储剪枝元素,而是以压缩格式存储指定的元素。

对于半结构化稀疏性,我们存储原始参数的一半,以及一些关于元素如何排列的压缩元数据。

存在许多不同的稀疏布局,每种布局都有其自身的优点和缺点。2:4 半结构化稀疏布局特别有趣,原因有二:1. 与以前的稀疏格式不同,半结构化稀疏性旨在在 GPU 上高效加速。

2020 年,NVIDIA 在其 Ampere 架构中引入了对半结构化稀疏性的硬件支持,并且还通过 CUTLASS/cuSPARSELt 发布了快速稀疏内核。

  1. 与此同时,与其他稀疏格式相比,半结构化稀疏性对模型精度的影响往往较小,尤其是在考虑更高级的剪枝/微调方法时。NVIDIA 在他们的 白皮书 中表明,简单的范例是进行一次幅度剪枝以达到 2:4 稀疏,然后重新训练模型,即可获得几乎相同的模型精度。

半结构化稀疏性处于一个最佳位置,在较低的稀疏度(50%)下提供 2 倍(理论上)的加速,同时仍然足够精细以保持模型精度。

网络

数据集

指标

密集型 FP16

稀疏型 FP16

ResNet-50

ImageNet

Top-1

76.1

76.2

ResNeXt-101_32x8d

ImageNet

Top-1

79.3

79.3

Xception

ImageNet

Top-1

79.2

79.2

SSD-RN50

COCO2017

bbAP

24.8

24.8

MaskRCNN-RN50

COCO2017

bbAP

37.9

37.9

FairSeq Transformer

EN-DE WMT14

BLEU

28.2

28.5

BERT-Large

SQuAD v1.1

F1

91.9

91.9

从工作流程的角度来看,半结构化稀疏性还具有额外的优势。由于稀疏度固定为 50%,因此更容易将稀疏化模型的问题分解为两个不同的子问题

  • 精度 - 我们如何找到一组 2:4 稀疏权重,以最大限度地减少模型的精度下降?

  • 性能 - 我们如何加速 2:4 稀疏权重以进行推理并减少内存开销?

\[\begin{bmatrix} 1 & 1 & 0 & 0 \\ 0 & 0 & 1 & 1 \\ 1 & 0 & 0 & 0 \\ 0 & 0 & 1 & 1 \\ \end{bmatrix} \]

这两个问题之间的自然交接点是置零的密集张量。我们的推理解决方案旨在压缩和加速这种格式的张量。我们预计许多用户会提出自定义掩码解决方案,因为这是一个活跃的研究领域。

现在我们已经对半结构化稀疏性有了一些了解,让我们将其应用于在问答任务 SQuAD 上训练的 BERT 模型。

简介和设置

让我们首先导入我们需要的所有包。

import collections
import datasets
import evaluate
import numpy as np
import torch
import torch.utils.benchmark as benchmark
from torch import nn
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from torch.ao.pruning import WeightNormSparsifier
import transformers

# force CUTLASS use if cuSPARSELt is not available
SparseSemiStructuredTensor._FORCE_CUTLASS = True
torch.manual_seed(100)

我们还需要定义一些特定于手头数据集/任务的辅助函数。这些函数改编自 这个 huggingface 课程作为参考。

def preprocess_validation_function(examples, tokenizer):
    inputs = tokenizer(
        [q.strip() for q in examples["question"]],
        examples["context"],
        max_length=384,
        truncation="only_second",
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    sample_map = inputs.pop("overflow_to_sample_mapping")
    example_ids = []

    for i in range(len(inputs["input_ids"])):
        sample_idx = sample_map[i]
        example_ids.append(examples["id"][sample_idx])
        sequence_ids = inputs.sequence_ids(i)
        offset = inputs["offset_mapping"][i]
        inputs["offset_mapping"][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

    inputs["example_id"] = example_ids
    return inputs


def preprocess_train_function(examples, tokenizer):
    inputs = tokenizer(
        [q.strip() for q in examples["question"]],
        examples["context"],
        max_length=384,
        truncation="only_second",
        return_offsets_mapping=True,
        padding="max_length",
    )

    offset_mapping = inputs["offset_mapping"]
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, (offset, answer) in enumerate(zip(offset_mapping, answers)):
        start_char = answer["answer_start"][0]
        end_char = start_char + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        # Find the start and end of the context
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label it (0, 0)
        if offset[context_start][0] > end_char or offset[context_end][1] < start_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs


def compute_metrics(start_logits, end_logits, features, examples):
    n_best = 20
    max_answer_length = 30
    metric = evaluate.load("squad")

    example_to_features = collections.defaultdict(list)
    for idx, feature in enumerate(features):
        example_to_features[feature["example_id"]].append(idx)

    predicted_answers = []
    # for example in tqdm(examples):
    for example in examples:
        example_id = example["id"]
        context = example["context"]
        answers = []

        # Loop through all features associated with that example
        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = features[feature_index]["offset_mapping"]

            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip answers that are not fully in the context
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # Skip answers with a length that is either < 0
                    # or > max_answer_length
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue

                    answer = {
                        "text": context[
                            offsets[start_index][0] : offsets[end_index][1]
                        ],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                    answers.append(answer)

        # Select the answer with the best score
        if len(answers) > 0:
            best_answer = max(answers, key=lambda x: x["logit_score"])
            predicted_answers.append(
                {"id": example_id, "prediction_text": best_answer["text"]}
            )
        else:
            predicted_answers.append({"id": example_id, "prediction_text": ""})

    theoretical_answers = [
        {"id": ex["id"], "answers": ex["answers"]} for ex in examples
    ]
    return metric.compute(predictions=predicted_answers, references=theoretical_answers)

现在这些都已定义,我们只需要一个额外的辅助函数,这将帮助我们对模型进行基准测试。

def measure_execution_time(model, batch_sizes, dataset):
    dataset_for_model = dataset.remove_columns(["example_id", "offset_mapping"])
    dataset_for_model.set_format("torch")
    model.cuda()
    batch_size_to_time_sec = {}
    for batch_size in batch_sizes:
        batch = {
            k: dataset_for_model[k][:batch_size].to(model.device)
            for k in dataset_for_model.column_names
        }

        with torch.inference_mode():
            timer = benchmark.Timer(
                stmt="model(**batch)", globals={"model": model, "batch": batch}
            )
            p50 = timer.blocked_autorange().median * 1000
        batch_size_to_time_sec[batch_size] = p50
    return batch_size_to_time_sec

我们将首先加载我们的模型和分词器,然后设置我们的数据集。

# load model
model_name = "bert-base-cased"
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.AutoModelForQuestionAnswering.from_pretrained(model_name)
print(f"Loading tokenizer: {model_name}")
print(f"Loading model: {model_name}")

# set up train and val dataset
squad_dataset = datasets.load_dataset("squad")
tokenized_squad_dataset = {}
tokenized_squad_dataset["train"] = squad_dataset["train"].map(
    lambda x: preprocess_train_function(x, tokenizer), batched=True
)
tokenized_squad_dataset["validation"] = squad_dataset["validation"].map(
    lambda x: preprocess_validation_function(x, tokenizer),
    batched=True,
    remove_columns=squad_dataset["train"].column_names,
)
data_collator = transformers.DataCollatorWithPadding(tokenizer=tokenizer)

接下来,我们将在 SQuAD 上训练模型的快速基线。此任务要求我们的模型在给定上下文(维基百科文章)中识别跨度或文本段,以回答给定的问题。运行以下代码,我得到的 F1 分数为 86.9。这非常接近 NVIDIA 报告的分数,差异可能是由于 BERT-base 与 BERT-large 或微调超参数造成的。

training_args = transformers.TrainingArguments(
    "trainer",
    num_train_epochs=1,
    lr_scheduler_type="constant",
    per_device_train_batch_size=64,
    per_device_eval_batch_size=512,
)

trainer = transformers.Trainer(
    model,
    training_args,
    train_dataset=tokenized_squad_dataset["train"],
    eval_dataset=tokenized_squad_dataset["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)

trainer.train()

# batch sizes to compare for eval
batch_sizes = [4, 16, 64, 256]
# 2:4 sparsity require fp16, so we cast here for a fair comparison
with torch.autocast("cuda"):
    with torch.inference_mode():
        predictions = trainer.predict(tokenized_squad_dataset["validation"])
    start_logits, end_logits = predictions.predictions
    fp16_baseline = compute_metrics(
        start_logits,
        end_logits,
        tokenized_squad_dataset["validation"],
        squad_dataset["validation"],
    )
    fp16_time = measure_execution_time(
        model,
        batch_sizes,
        tokenized_squad_dataset["validation"],
    )
print("fp16", fp16_baseline)
print("cuda_fp16 time", fp16_time)

# fp16 {'exact_match': 78.53358561967833, 'f1': 86.9280493093186}
# cuda_fp16 time {4: 10.927572380751371, 16: 19.607915310189128, 64: 73.18846387788653, 256: 286.91255673766136}

将 BERT 剪枝为 2:4 稀疏

现在我们有了基线,是时候剪枝 BERT 了。有很多不同的剪枝策略,但最常见的一种是 幅度剪枝,它旨在删除 L1 范数最低的权重。NVIDIA 在他们的所有结果中都使用了幅度剪枝,这是一种常见的基线。

为此,我们将使用 torch.ao.pruning 包,其中包含权重范数(幅度)稀疏器。这些稀疏器的工作原理是将掩码参数化应用于模型中的权重张量。这使它们可以通过屏蔽剪枝权重来模拟稀疏性。

我们还必须决定将稀疏性应用于模型的哪些层,在本例中是所有的 nn.Linear 层,除了特定于任务的头部输出。这是因为半结构化稀疏性具有 形状约束,而特定于任务的 nn.Linear 层不满足这些约束。

sparsifier = WeightNormSparsifier(
    # apply sparsity to all blocks
    sparsity_level=1.0,
    # shape of 4 elemens is a block
    sparse_block_shape=(1, 4),
    # two zeros for every block of 4
    zeros_per_block=2
)

# add to config if nn.Linear and in the BERT model.
sparse_config = [
    {"tensor_fqn": f"{fqn}.weight"}
    for fqn, module in model.named_modules()
    if isinstance(module, nn.Linear) and "layer" in fqn
]

剪枝模型的第一步是插入参数化以掩码模型的权重。这通过 prepare 步骤完成。任何时候我们尝试访问 .weight,我们都会得到 mask * weight

# Prepare the model, insert fake-sparsity parameterizations for training
sparsifier.prepare(model, sparse_config)
print(model.bert.encoder.layer[0].output)

# BertOutput(
#   (dense): ParametrizedLinear(
#     in_features=3072, out_features=768, bias=True
#     (parametrizations): ModuleDict(
#       (weight): ParametrizationList(
#         (0-5): 6 x FakeSparsity()
#       )
#     )
#   )
#   (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
#   (dropout): Dropout(p=0.1, inplace=False)
# )

然后,我们将执行单个剪枝步骤。所有剪枝器都实现了一个 update_mask() 方法,该方法使用由剪枝器实现确定的逻辑更新掩码。step 方法为稀疏配置中指定的权重调用此 update_mask 函数。

我们还将评估模型,以显示零样本剪枝或不进行微调/重新训练的剪枝的精度下降。

sparsifier.step()
with torch.autocast("cuda"):
    with torch.inference_mode():
        predictions = trainer.predict(tokenized_squad_dataset["validation"])
    pruned = compute_metrics(
        *predictions.predictions,
        tokenized_squad_dataset["validation"],
        squad_dataset["validation"],
    )
print("pruned eval metrics:", pruned)
# pruned eval metrics: {'exact_match': 40.59602649006622, 'f1': 56.51610004515979}

在这种状态下,我们可以开始微调模型,更新不会被剪枝的元素,以更好地弥补精度损失。一旦我们达到满意的状态,我们就可以调用 squash_mask 将掩码和权重融合在一起。这将删除参数化,我们最终得到一个置零的 2:4 密集模型。

trainer.train()
sparsifier.squash_mask()
torch.set_printoptions(edgeitems=4)
print(model.bert.encoder.layer[0].intermediate.dense.weight)

# Parameter containing:
# tensor([[ 0.0000, -0.0237,  0.0000,  0.0130,  ..., -0.0462, -0.0000, 0.0000, -0.0272],
#        [ 0.0436, -0.0000, -0.0000,  0.0492,  ..., -0.0000,  0.0844,  0.0340, -0.0000],
#        [-0.0302, -0.0350,  0.0000,  0.0000,  ...,  0.0303,  0.0175, -0.0000,  0.0000],
#        [ 0.0000, -0.0000, -0.0529,  0.0327,  ...,  0.0213,  0.0000, -0.0000,  0.0735],
#        ...,
#        [ 0.0000, -0.0000, -0.0258, -0.0239,  ..., -0.0000, -0.0000,  0.0380,  0.0562],
#        [-0.0432, -0.0000,  0.0000, -0.0598,  ...,  0.0000, -0.0000,  0.0262  -0.0227],
#        [ 0.0244,  0.0921, -0.0000, -0.0000,  ..., -0.0000, -0.0784,  0.0000,  0.0761],
#        [ 0.0000,  0.0225, -0.0395, -0.0000,  ..., -0.0000,  0.0684, -0.0344, -0.0000]], device='cuda:0', requires_grad=True)

加速 2:4 稀疏模型进行推理 ——–i———————————— 现在我们有了这种格式的模型,我们可以像快速入门指南中一样加速它以进行推理。

model = model.cuda().half()
# accelerate for sparsity
for fqn, module in model.named_modules():
    if isinstance(module, nn.Linear) and "layer" in fqn:
        module.weight = nn.Parameter(to_sparse_semi_structured(module.weight))

with torch.inference_mode():
    predictions = trainer.predict(tokenized_squad_dataset["validation"])
start_logits, end_logits = predictions.predictions
metrics_sparse = compute_metrics(
    start_logits,
    end_logits,
    tokenized_squad_dataset["validation"],
    squad_dataset["validation"],
)
print("sparse eval metrics: ", metrics_sparse)
sparse_perf = measure_execution_time(
    model,
    batch_sizes,
    tokenized_squad_dataset["validation"],
)
print("sparse perf metrics: ", sparse_perf)

# sparse eval metrics:  {'exact_match': 78.43897824030275, 'f1': 86.48718950090766}
# sparse perf metrics:  {4: 12.621004460379481, 16: 15.368514601141214, 64: 58.702805917710066, 256: 244.19364519417286}

在幅度剪枝后重新训练我们的模型几乎恢复了模型剪枝时丢失的所有 F1。与此同时,我们实现了 bs=16 的 1.28 倍加速。请注意,并非所有形状都适合性能改进。当批量大小较小且计算时间有限时,稀疏内核可能比其密集对应内核慢。

结果

指标

fp16

2:4 稀疏

delta / 加速

完全匹配 (%)

78.53

78.44

-0.09

F1 (%)

86.93

86.49

-0.44

时间 (bs=4)

10.93

12.62

0.87 倍

时间 (bs=16)

19.61

15.37

1.28 倍

时间 (bs=64)

73.19

58.70

1.25 倍

时间 (bs=256)

286.91

244.19

1.18 倍

结论

在本教程中,我们展示了如何将 BERT 剪枝为 2:4 稀疏,以及如何加速 2:4 稀疏模型以进行推理。通过利用我们的 SparseSemiStructuredTensor 子类,我们能够实现比 fp16 基线快 1.3 倍的加速。我们还通过微调 BERT 来恢复任何丢失的 F1(密集型为 86.92,稀疏型为 86.48),证明了 2:4 稀疏性的优势。


评价本教程

© 版权所有 2024,PyTorch。

使用 Sphinx 构建,主题由 theme 提供,并由 Read the Docs 托管。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源