(原型) 利用半结构化 (2:4) 稀疏加速 BERT¶
创建日期: Oct 03, 2023 | 最后更新日期: Jan 16, 2024 | 最后验证日期: Nov 05, 2024
作者: Jesse Cai
与其他形式的稀疏一样,半结构化稀疏是一种模型优化技术,旨在降低神经网络的内存开销和延迟,代价是损失部分模型精度。它也称为细粒度结构化稀疏或 2:4 结构化稀疏。
半结构化稀疏得名于其独特的稀疏模式,即每 2n 个元素中修剪 n 个元素。我们最常看到 n=2,因此称为 2:4 稀疏。半结构化稀疏尤其令人关注,因为它可以在 GPU 上高效加速,并且不像其他稀疏模式那样严重降低模型精度。
随着 半结构化稀疏支持 的引入,无需离开 PyTorch 即可对半结构化稀疏模型进行剪枝和加速。本教程将解释这一过程。

在本教程结束时,我们将把一个 BERT 问答模型稀疏化为 2:4 稀疏,并通过微调来恢复几乎所有的 F1 损失(密集模型 86.92 vs 稀疏模型 86.48)。最后,我们将加速这个 2:4 稀疏模型用于推理,获得 1.3 倍的加速比。
要求¶
PyTorch >= 2.1。
支持半结构化稀疏的 NVIDIA GPU(计算能力 8.0+)。
注意
本教程面向半结构化稀疏/一般稀疏的初学者。对于拥有现有 2:4 稀疏模型的用户,使用 to_sparse_semi_structured
加速 nn.Linear
层的推理就像以下代码一样简单:
import torch
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from torch.utils.benchmark import Timer
SparseSemiStructuredTensor._FORCE_CUTLASS = True
# mask Linear weight to be 2:4 sparse
mask = torch.Tensor([0, 0, 1, 1]).tile((3072, 2560)).cuda().bool()
linear = torch.nn.Linear(10240, 3072).half().cuda().eval()
linear.weight = torch.nn.Parameter(mask * linear.weight)
x = torch.rand(3072, 10240).half().cuda()
with torch.inference_mode():
dense_output = linear(x)
dense_t = Timer(stmt="linear(x)",
globals={"linear": linear,
"x": x}).blocked_autorange().median * 1e3
# accelerate via SparseSemiStructuredTensor
linear.weight = torch.nn.Parameter(to_sparse_semi_structured(linear.weight))
sparse_output = linear(x)
sparse_t = Timer(stmt="linear(x)",
globals={"linear": linear,
"x": x}).blocked_autorange().median * 1e3
# sparse and dense matmul are numerically equivalent
assert torch.allclose(sparse_output, dense_output, atol=1e-3)
print(f"Dense: {dense_t:.3f}ms Sparse: {sparse_t:.3f}ms | Speedup: {(dense_t / sparse_t):.3f}x")
在 A100 80GB 上,我们看到: Dense: 0.870ms Sparse: 0.630ms | Speedup: 1.382x
半结构化稀疏解决了什么问题?¶
稀疏性背后的总体动机很简单:如果你的网络中存在零,你可以避免存储/计算这些参数。然而,稀疏性的具体细节很复杂。直接将参数置零并不会开箱即用地影响我们模型的延迟/内存开销。
这是因为密集张量仍然包含被剪枝(零)的元素,密集矩阵乘法核函数仍然会处理这些元素。为了实现性能提升,我们需要用稀疏核函数替换密集核函数,稀疏核函数会跳过涉及被剪枝元素的计算。
为此,这些核函数作用于稀疏矩阵,稀疏矩阵不存储被剪枝的元素,而是以压缩格式存储指定的元素。
对于半结构化稀疏,我们存储原始参数的一半,以及关于元素排列方式的一些压缩元数据。
存在许多不同的稀疏布局,每种布局都有各自的优缺点。2:4 半结构化稀疏布局因两个原因尤其有趣:1. 与之前的稀疏格式不同,半结构化稀疏旨在在 GPU 上高效加速。
2020 年,NVIDIA 推出了 Ampere 架构的硬件支持半结构化稀疏,并通过 CUTLASS/cuSPARSELt 发布了快速稀疏核函数。
同时,与其它稀疏格式相比,半结构化稀疏对模型精度的影响通常较小,尤其是在考虑更先进的剪枝/微调方法时。NVIDIA 在其白皮书中表明,简单地一次进行幅度剪枝以达到 2:4 稀疏度,然后重新训练模型,可以获得几乎相同的模型精度。
半结构化稀疏处于一个有利位置,在更低的稀疏度(50%)下提供 2 倍(理论上)的加速比,同时仍然足够细粒度以保留模型精度。
网络 |
数据集 |
指标 |
密集 FP16 |
稀疏 FP16 |
---|---|---|---|---|
ResNet-50 |
ImageNet |
Top-1 |
76.1 |
76.2 |
ResNeXt-101_32x8d |
ImageNet |
Top-1 |
79.3 |
79.3 |
Xception |
ImageNet |
Top-1 |
79.2 |
79.2 |
SSD-RN50 |
COCO2017 |
bbAP |
24.8 |
24.8 |
MaskRCNN-RN50 |
COCO2017 |
bbAP |
37.9 |
37.9 |
FairSeq Transformer |
EN-DE WMT14 |
BLEU |
28.2 |
28.5 |
BERT-Large |
SQuAD v1.1 |
F1 |
91.9 |
91.9 |
从工作流程的角度来看,半结构化稀疏还有一个额外的优势。由于稀疏度固定为 50%,将模型稀疏化的问题分解为两个不同的子问题会更容易:
精度 - 我们如何找到一组 2:4 稀疏权重,以最大程度地减少模型的精度下降?
性能 - 我们如何加速我们的 2:4 稀疏权重以用于推理并降低内存开销?
这两个问题之间的自然交接点是归零后的密集张量。我们的推理解决方案旨在以这种格式压缩和加速张量。我们预计许多用户会提出自定义遮罩解决方案,因为这是一个活跃的研究领域。
既然我们对半结构化稀疏有了更多了解,接下来将其应用于在问答任务 SQuAD 上训练的 BERT 模型。
介绍与设置¶
首先导入所有需要的包。
import collections
import datasets
import evaluate
import numpy as np
import torch
import torch.utils.benchmark as benchmark
from torch import nn
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from torch.ao.pruning import WeightNormSparsifier
import transformers
# force CUTLASS use if cuSPARSELt is not available
SparseSemiStructuredTensor._FORCE_CUTLASS = True
torch.manual_seed(100)
我们还需要定义一些针对当前数据集/任务的辅助函数。这些函数改编自 这篇 huggingface 课程作为参考。
def preprocess_validation_function(examples, tokenizer):
inputs = tokenizer(
[q.strip() for q in examples["question"]],
examples["context"],
max_length=384,
truncation="only_second",
return_overflowing_tokens=True,
return_offsets_mapping=True,
padding="max_length",
)
sample_map = inputs.pop("overflow_to_sample_mapping")
example_ids = []
for i in range(len(inputs["input_ids"])):
sample_idx = sample_map[i]
example_ids.append(examples["id"][sample_idx])
sequence_ids = inputs.sequence_ids(i)
offset = inputs["offset_mapping"][i]
inputs["offset_mapping"][i] = [
o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
]
inputs["example_id"] = example_ids
return inputs
def preprocess_train_function(examples, tokenizer):
inputs = tokenizer(
[q.strip() for q in examples["question"]],
examples["context"],
max_length=384,
truncation="only_second",
return_offsets_mapping=True,
padding="max_length",
)
offset_mapping = inputs["offset_mapping"]
answers = examples["answers"]
start_positions = []
end_positions = []
for i, (offset, answer) in enumerate(zip(offset_mapping, answers)):
start_char = answer["answer_start"][0]
end_char = start_char + len(answer["text"][0])
sequence_ids = inputs.sequence_ids(i)
# Find the start and end of the context
idx = 0
while sequence_ids[idx] != 1:
idx += 1
context_start = idx
while sequence_ids[idx] == 1:
idx += 1
context_end = idx - 1
# If the answer is not fully inside the context, label it (0, 0)
if offset[context_start][0] > end_char or offset[context_end][1] < start_char:
start_positions.append(0)
end_positions.append(0)
else:
# Otherwise it's the start and end token positions
idx = context_start
while idx <= context_end and offset[idx][0] <= start_char:
idx += 1
start_positions.append(idx - 1)
idx = context_end
while idx >= context_start and offset[idx][1] >= end_char:
idx -= 1
end_positions.append(idx + 1)
inputs["start_positions"] = start_positions
inputs["end_positions"] = end_positions
return inputs
def compute_metrics(start_logits, end_logits, features, examples):
n_best = 20
max_answer_length = 30
metric = evaluate.load("squad")
example_to_features = collections.defaultdict(list)
for idx, feature in enumerate(features):
example_to_features[feature["example_id"]].append(idx)
predicted_answers = []
# for example in tqdm(examples):
for example in examples:
example_id = example["id"]
context = example["context"]
answers = []
# Loop through all features associated with that example
for feature_index in example_to_features[example_id]:
start_logit = start_logits[feature_index]
end_logit = end_logits[feature_index]
offsets = features[feature_index]["offset_mapping"]
start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
for start_index in start_indexes:
for end_index in end_indexes:
# Skip answers that are not fully in the context
if offsets[start_index] is None or offsets[end_index] is None:
continue
# Skip answers with a length that is either < 0
# or > max_answer_length
if (
end_index < start_index
or end_index - start_index + 1 > max_answer_length
):
continue
answer = {
"text": context[
offsets[start_index][0] : offsets[end_index][1]
],
"logit_score": start_logit[start_index] + end_logit[end_index],
}
answers.append(answer)
# Select the answer with the best score
if len(answers) > 0:
best_answer = max(answers, key=lambda x: x["logit_score"])
predicted_answers.append(
{"id": example_id, "prediction_text": best_answer["text"]}
)
else:
predicted_answers.append({"id": example_id, "prediction_text": ""})
theoretical_answers = [
{"id": ex["id"], "answers": ex["answers"]} for ex in examples
]
return metric.compute(predictions=predicted_answers, references=theoretical_answers)
定义好这些之后,我们还需要一个额外的辅助函数,它将帮助我们对模型进行基准测试。
def measure_execution_time(model, batch_sizes, dataset):
dataset_for_model = dataset.remove_columns(["example_id", "offset_mapping"])
dataset_for_model.set_format("torch")
model.cuda()
batch_size_to_time_sec = {}
for batch_size in batch_sizes:
batch = {
k: dataset_for_model[k][:batch_size].to(model.device)
for k in dataset_for_model.column_names
}
with torch.inference_mode():
timer = benchmark.Timer(
stmt="model(**batch)", globals={"model": model, "batch": batch}
)
p50 = timer.blocked_autorange().median * 1000
batch_size_to_time_sec[batch_size] = p50
return batch_size_to_time_sec
我们将从加载模型和分词器开始,然后设置我们的数据集。
# load model
model_name = "bert-base-cased"
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.AutoModelForQuestionAnswering.from_pretrained(model_name)
print(f"Loading tokenizer: {model_name}")
print(f"Loading model: {model_name}")
# set up train and val dataset
squad_dataset = datasets.load_dataset("squad")
tokenized_squad_dataset = {}
tokenized_squad_dataset["train"] = squad_dataset["train"].map(
lambda x: preprocess_train_function(x, tokenizer), batched=True
)
tokenized_squad_dataset["validation"] = squad_dataset["validation"].map(
lambda x: preprocess_validation_function(x, tokenizer),
batched=True,
remove_columns=squad_dataset["train"].column_names,
)
data_collator = transformers.DataCollatorWithPadding(tokenizer=tokenizer)
接下来,我们将在 SQuAD 上快速训练一个模型基线。此任务要求我们的模型在给定上下文(维基百科文章)中识别文本跨度或片段,以回答给定的问题。运行以下代码会给我一个 F1 得分 86.9。这与 NVIDIA 报告的得分非常接近,差异可能是由于使用了 BERT-base 而非 BERT-large 模型,或者微调超参数不同造成的。
training_args = transformers.TrainingArguments(
"trainer",
num_train_epochs=1,
lr_scheduler_type="constant",
per_device_train_batch_size=64,
per_device_eval_batch_size=512,
)
trainer = transformers.Trainer(
model,
training_args,
train_dataset=tokenized_squad_dataset["train"],
eval_dataset=tokenized_squad_dataset["validation"],
data_collator=data_collator,
tokenizer=tokenizer,
)
trainer.train()
# batch sizes to compare for eval
batch_sizes = [4, 16, 64, 256]
# 2:4 sparsity require fp16, so we cast here for a fair comparison
with torch.autocast("cuda"):
with torch.inference_mode():
predictions = trainer.predict(tokenized_squad_dataset["validation"])
start_logits, end_logits = predictions.predictions
fp16_baseline = compute_metrics(
start_logits,
end_logits,
tokenized_squad_dataset["validation"],
squad_dataset["validation"],
)
fp16_time = measure_execution_time(
model,
batch_sizes,
tokenized_squad_dataset["validation"],
)
print("fp16", fp16_baseline)
print("cuda_fp16 time", fp16_time)
# fp16 {'exact_match': 78.53358561967833, 'f1': 86.9280493093186}
# cuda_fp16 time {4: 10.927572380751371, 16: 19.607915310189128, 64: 73.18846387788653, 256: 286.91255673766136}
将 BERT 剪枝为 2:4 稀疏¶
现在我们有了基线模型,是时候对 BERT 进行剪枝了。存在许多不同的剪枝策略,但最常见的一种是幅度剪枝,它旨在移除 L1 范数最低的权重。NVIDIA 在其所有结果中都使用了幅度剪枝,这是一种常见的基线方法。
为此,我们将使用 torch.ao.pruning
包,其中包含权重范数(幅度)稀疏器。这些稀疏器通过对模型中的权重张量应用遮罩参数化来工作。这使得它们可以通过遮罩掉被剪枝的权重来模拟稀疏性。
我们还需要决定将稀疏性应用于模型的哪些层,在这种情况下是所有的 nn.Linear 层,除了特定任务的头部输出层。这是因为半结构化稀疏具有形状约束,而特定任务的 nn.Linear 层不满足这些约束。
sparsifier = WeightNormSparsifier(
# apply sparsity to all blocks
sparsity_level=1.0,
# shape of 4 elemens is a block
sparse_block_shape=(1, 4),
# two zeros for every block of 4
zeros_per_block=2
)
# add to config if nn.Linear and in the BERT model.
sparse_config = [
{"tensor_fqn": f"{fqn}.weight"}
for fqn, module in model.named_modules()
if isinstance(module, nn.Linear) and "layer" in fqn
]
剪枝模型的第一步是插入参数化,用于对模型的权重进行遮罩。这通过 prepare 步骤完成。每当我们尝试访问 .weight
时,我们将改为获取 mask * weight
。
# Prepare the model, insert fake-sparsity parameterizations for training
sparsifier.prepare(model, sparse_config)
print(model.bert.encoder.layer[0].output)
# BertOutput(
# (dense): ParametrizedLinear(
# in_features=3072, out_features=768, bias=True
# (parametrizations): ModuleDict(
# (weight): ParametrizationList(
# (0-5): 6 x FakeSparsity()
# )
# )
# )
# (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
# (dropout): Dropout(p=0.1, inplace=False)
# )
然后,我们将执行一个剪枝步骤。所有剪枝器都实现了一个 update_mask()
方法,该方法根据剪枝器实现所确定的逻辑来更新遮罩。step 方法会为稀疏配置中指定的权重调用这个 update_mask
函数。
我们还将评估模型,以展示零样本剪枝(即未进行微调/重新训练的剪枝)的精度下降。
sparsifier.step()
with torch.autocast("cuda"):
with torch.inference_mode():
predictions = trainer.predict(tokenized_squad_dataset["validation"])
pruned = compute_metrics(
*predictions.predictions,
tokenized_squad_dataset["validation"],
squad_dataset["validation"],
)
print("pruned eval metrics:", pruned)
# pruned eval metrics: {'exact_match': 40.59602649006622, 'f1': 56.51610004515979}
在此状态下,我们可以开始微调模型,更新那些不会被剪枝的元素,以更好地弥补精度损失。达到满意状态后,我们可以调用 squash_mask
将遮罩和权重融合在一起。这将移除参数化,我们得到一个归零的 2:4 密集模型。
trainer.train()
sparsifier.squash_mask()
torch.set_printoptions(edgeitems=4)
print(model.bert.encoder.layer[0].intermediate.dense.weight)
# Parameter containing:
# tensor([[ 0.0000, -0.0237, 0.0000, 0.0130, ..., -0.0462, -0.0000, 0.0000, -0.0272],
# [ 0.0436, -0.0000, -0.0000, 0.0492, ..., -0.0000, 0.0844, 0.0340, -0.0000],
# [-0.0302, -0.0350, 0.0000, 0.0000, ..., 0.0303, 0.0175, -0.0000, 0.0000],
# [ 0.0000, -0.0000, -0.0529, 0.0327, ..., 0.0213, 0.0000, -0.0000, 0.0735],
# ...,
# [ 0.0000, -0.0000, -0.0258, -0.0239, ..., -0.0000, -0.0000, 0.0380, 0.0562],
# [-0.0432, -0.0000, 0.0000, -0.0598, ..., 0.0000, -0.0000, 0.0262 -0.0227],
# [ 0.0244, 0.0921, -0.0000, -0.0000, ..., -0.0000, -0.0784, 0.0000, 0.0761],
# [ 0.0000, 0.0225, -0.0395, -0.0000, ..., -0.0000, 0.0684, -0.0344, -0.0000]], device='cuda:0', requires_grad=True)
加速用于推理的 2:4 稀疏模型。现在我们有了这种格式的模型,就可以像快速入门指南中一样对其进行推理加速。
model = model.cuda().half()
# accelerate for sparsity
for fqn, module in model.named_modules():
if isinstance(module, nn.Linear) and "layer" in fqn:
module.weight = nn.Parameter(to_sparse_semi_structured(module.weight))
with torch.inference_mode():
predictions = trainer.predict(tokenized_squad_dataset["validation"])
start_logits, end_logits = predictions.predictions
metrics_sparse = compute_metrics(
start_logits,
end_logits,
tokenized_squad_dataset["validation"],
squad_dataset["validation"],
)
print("sparse eval metrics: ", metrics_sparse)
sparse_perf = measure_execution_time(
model,
batch_sizes,
tokenized_squad_dataset["validation"],
)
print("sparse perf metrics: ", sparse_perf)
# sparse eval metrics: {'exact_match': 78.43897824030275, 'f1': 86.48718950090766}
# sparse perf metrics: {4: 12.621004460379481, 16: 15.368514601141214, 64: 58.702805917710066, 256: 244.19364519417286}
幅度剪枝后重新训练模型已恢复了模型剪枝时损失的几乎所有 F1 值。同时,对于批量大小为 16 (bs=16),我们实现了 1.28 倍的加速比。请注意,并非所有形状都适合性能提升。当批量大小较小且计算时间有限时,稀疏核函数可能比相应的密集核函数慢。
指标 |
fp16 |
2:4 稀疏 |
变化量 / 加速比 |
---|---|---|---|
精确匹配 (%) |
78.53 |
78.44 |
-0.09 |
F1 (%) |
86.93 |
86.49 |
-0.44 |
时间 (bs=4) |
10.93 |
12.62 |
0.87x |
时间 (bs=16) |
19.61 |
15.37 |
1.28x |
时间 (bs=64) |
73.19 |
58.70 |
1.25x |
时间 (bs=256) |
286.91 |
244.19 |
1.18x |
结论¶
在本教程中,我们展示了如何将 BERT 剪枝为 2:4 稀疏以及如何加速 2:4 稀疏模型用于推理。通过利用我们的 SparseSemiStructuredTensor 子类,我们能够实现相对于 fp16 基线 1.3 倍的加速比。我们还通过微调 BERT 来恢复损失的 F1(密集模型 86.92 vs 稀疏模型 86.48),展示了 2:4 稀疏的好处。