Note
Click here to download the full example code
(beta) Accelerating BERT with semi-structured (2:4) sparsity¶
**Author**: Jesse Cai
Overview¶
Like other forms of sparsity, **semi-structured sparsity** is a model optimization technique that seeks to reduce the memory overhead and latency of a neural network at the expense of some model accuracy. It is also known as **fine-grained structured sparsity** or **2:4 structured sparsity**.
Semi-structured sparsity derives its name from its unique sparsity pattern, where n out of every 2n elements are pruned. We most often see n=2, hence 2:4 sparsity. Semi-structured sparsity is particularly interesting because it can be efficiently accelerated on GPUs and does not degrade model accuracy as much as other sparsity patterns.
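To make the pattern concrete, here is a minimal sketch (not part of the tutorial's own code; `is_24_sparse` is a hypothetical helper) that checks whether a tensor obeys the 2:4 constraint:

```python
import torch

def is_24_sparse(w: torch.Tensor) -> bool:
    # View the tensor as contiguous groups of 4 elements;
    # 2:4 sparsity means at most 2 nonzeros per group.
    groups = w.reshape(-1, 4)
    return bool(((groups != 0).sum(dim=-1) <= 2).all())

print(is_24_sparse(torch.tensor([[0., 0., 1., 2.], [3., 0., 4., 0.]])))  # True
print(is_24_sparse(torch.ones(4, 8)))  # False: every group has 4 nonzeros
```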
With the introduction of semi-structured sparsity support, it is possible to prune and accelerate a semi-structured sparse model without leaving PyTorch. We will explain this process in this tutorial.
By the end of this tutorial, we will have sparsified a BERT question-answering model to be 2:4 sparse, and fine-tuned it to recover nearly all of the F1 loss (86.92 dense vs 86.48 sparse). Finally, we will accelerate this 2:4 sparse model for inference, yielding a 1.3x speedup.
Requirements¶
PyTorch >= 2.1.
An NVIDIA GPU with semi-structured sparsity support (compute capability 8.0+).
This tutorial is designed for beginners to semi-structured sparsity and sparsity in general. For users with an existing 2:4 sparse model, accelerating `nn.Linear` layers for inference with `to_sparse_semi_structured` is quite straightforward. Here is an example:
import torch
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from torch.utils.benchmark import Timer
SparseSemiStructuredTensor._FORCE_CUTLASS = True

# mask Linear weight to be 2:4 sparse
mask = torch.Tensor([0, 0, 1, 1]).tile((3072, 2560)).cuda().bool()
linear = torch.nn.Linear(10240, 3072).half().cuda().eval()
linear.weight = torch.nn.Parameter(mask * linear.weight)

x = torch.rand(3072, 10240).half().cuda()

with torch.inference_mode():
    dense_output = linear(x)
    dense_t = Timer(stmt="linear(x)",
                    globals={"linear": linear,
                             "x": x}).blocked_autorange().median * 1e3

    # accelerate via SparseSemiStructuredTensor
    linear.weight = torch.nn.Parameter(to_sparse_semi_structured(linear.weight))

    sparse_output = linear(x)
    sparse_t = Timer(stmt="linear(x)",
                     globals={"linear": linear,
                              "x": x}).blocked_autorange().median * 1e3

    # sparse and dense matmul are numerically equivalent
    # On an A100 80GB, we see: `Dense: 0.870ms Sparse: 0.630ms | Speedup: 1.382x`
    assert torch.allclose(sparse_output, dense_output, atol=1e-3)
    print(f"Dense: {dense_t:.3f}ms Sparse: {sparse_t:.3f}ms | Speedup: {(dense_t / sparse_t):.3f}x")
Dense: 2.920ms Sparse: 1.646ms | Speedup: 1.773x
What problem does semi-structured sparsity solve?¶
The general motivation behind sparsity is simple: if there are zeros in your network, you can optimize efficiency by not storing or computing those parameters. However, the specifics of sparsity are tricky. Zeroing out parameters does not, by itself, change the latency or memory overhead of a model.
This is because the dense tensor still contains the pruned (zero) elements, and the dense matrix multiplication kernel still operates on them. To realize performance gains, we need to swap dense kernels for sparse kernels that skip calculations involving pruned elements.
To do this, sparse kernels work on sparse matrices, which do not store the pruned elements and instead keep the specified elements in a compressed format.
For semi-structured sparsity, we store exactly half of the original parameters, along with some compressed metadata about how the elements were arranged.
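As a back-of-the-envelope sketch of what that buys you: assuming fp16 values and a 2-bit index per kept element to record its position within each group of 4 (a common description of the hardware format; treat the 2-bit figure as an assumption here), the compressed representation is a bit over half the dense size:

```python
# dense: every element stored at 16 bits, 4 elements per group
bits_dense_per_group = 4 * 16

# 2:4 sparse: keep 2 of every 4 elements, plus a 2-bit position index per kept element
bits_sparse_per_group = 2 * (16 + 2)

ratio = bits_sparse_per_group / bits_dense_per_group
print(f"compressed size: {ratio:.2%} of dense")  # 56.25%
```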
There are many different sparse layouts, each with its own benefits and drawbacks. The 2:4 semi-structured sparse layout is particularly interesting for two reasons:
Unlike previous sparse formats, semi-structured sparsity was designed to be efficiently accelerated on GPUs. In 2020, NVIDIA introduced hardware support for semi-structured sparsity with their Ampere architecture, and released fast sparse kernels via CUTLASS and cuSPARSELt.
At the same time, semi-structured sparsity tends to have a milder impact on model accuracy than other sparse formats, especially when more advanced pruning / fine-tuning methods are taken into account. NVIDIA has shown in their white paper that a simple paradigm of magnitude pruning once to 2:4 sparsity and then retraining the model yields nearly identical accuracy.
Semi-structured sparsity sits in a sweet spot: it provides a 2x (theoretical) speedup at a much lower sparsity level (50%), while still being granular enough to preserve model accuracy.
Network | Dataset | Metric | Dense FP16 | Sparse FP16
---|---|---|---|---
ResNet-50 | ImageNet | Top-1 | 76.1 | 76.2
ResNeXt-101_32x8d | ImageNet | Top-1 | 79.3 | 79.3
Xception | ImageNet | Top-1 | 79.2 | 79.2
SSD-RN50 | COCO2017 | bbAP | 24.8 | 24.8
MaskRCNN-RN50 | COCO2017 | bbAP | 37.9 | 37.9
FairSeq Transformer | EN-DE WMT14 | BLEU | 28.2 | 28.5
BERT-Large | SQuAD v1.1 | F1 | 91.9 | 91.9
From a workflow perspective, semi-structured sparsity has an additional advantage. Since the sparsity level is fixed at 50%, it is easier to decompose the problem of sparsifying a model into two distinct subproblems:
1. Accuracy - How can we find a set of 2:4 sparse weights that minimize the accuracy degradation of our model?
2. Performance - How can we accelerate our 2:4 sparse weights for inference and reduce memory overhead?
The natural handoff point between these two problems is the zeroed-out dense tensor. Our inference solution is designed to compress and accelerate tensors in this format. We anticipate many users coming up with custom masking solutions, as this is an active area of research.
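As a sketch of the accuracy side, one simple way to produce such a zeroed-out dense tensor is one-shot magnitude pruning: keep the 2 largest-magnitude elements in every group of 4. This is a minimal stand-in for the more sophisticated sparsifier used later in the tutorial, not the tutorial's actual method:

```python
import torch

def magnitude_prune_24(w: torch.Tensor) -> torch.Tensor:
    # Keep the 2 largest-magnitude entries in each contiguous group of 4,
    # zeroing the rest, so the result obeys the 2:4 pattern.
    groups = w.reshape(-1, 4)
    keep = groups.abs().topk(2, dim=-1).indices
    mask = torch.zeros_like(groups, dtype=torch.bool).scatter_(-1, keep, True)
    return (groups * mask).reshape(w.shape)

w = torch.randn(8, 16)
pruned = magnitude_prune_24(w)
# every group of 4 now has at most 2 nonzeros
print(((pruned.reshape(-1, 4) != 0).sum(-1) <= 2).all())
```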
Now that we've learned a bit more about semi-structured sparsity, let's apply it to a BERT model trained on a question answering task, SQuAD.
Introduction and setup¶
Let's start by importing all the packages we need.
# If you are running this in Google Colab, run:
#
#    !pip install datasets transformers evaluate accelerate pandas
#
import os
os.environ["WANDB_DISABLED"] = "true"

import collections
import datasets
import evaluate
import numpy as np
import torch
import torch.utils.benchmark as benchmark
from torch import nn
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from torch.ao.pruning import WeightNormSparsifier
import transformers

# force CUTLASS use if ``cuSPARSELt`` is not available
SparseSemiStructuredTensor._FORCE_CUTLASS = True
torch.manual_seed(100)
<torch._C.Generator object at 0x7efbfdcd3950>
We'll also need to define some helper functions that are specific to the dataset / task at hand. These were adapted from this Hugging Face course as a reference.
def preprocess_validation_function(examples, tokenizer):
    inputs = tokenizer(
        [q.strip() for q in examples["question"]],
        examples["context"],
        max_length=384,
        truncation="only_second",
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    sample_map = inputs.pop("overflow_to_sample_mapping")
    example_ids = []

    for i in range(len(inputs["input_ids"])):
        sample_idx = sample_map[i]
        example_ids.append(examples["id"][sample_idx])
        sequence_ids = inputs.sequence_ids(i)
        offset = inputs["offset_mapping"][i]
        inputs["offset_mapping"][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

    inputs["example_id"] = example_ids
    return inputs
def preprocess_train_function(examples, tokenizer):
    inputs = tokenizer(
        [q.strip() for q in examples["question"]],
        examples["context"],
        max_length=384,
        truncation="only_second",
        return_offsets_mapping=True,
        padding="max_length",
    )

    offset_mapping = inputs["offset_mapping"]
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, (offset, answer) in enumerate(zip(offset_mapping, answers)):
        start_char = answer["answer_start"][0]
        end_char = start_char + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        # Find the start and end of the context
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label it (0, 0)
        if offset[context_start][0] > end_char or offset[context_end][1] < start_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)
            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs
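To see what the inner span search above does, here is a toy run on made-up offset mappings (token index to character span), with the answer covering characters 6 through 15:

```python
# token -> (start_char, end_char) for 4 hypothetical context tokens
offsets = [(0, 5), (6, 10), (11, 15), (16, 20)]
context_start, context_end = 0, 3
start_char, end_char = 6, 15

# walk right until we pass the answer start, then step back one token
idx = context_start
while idx <= context_end and offsets[idx][0] <= start_char:
    idx += 1
start_position = idx - 1

# walk left until we pass the answer end, then step forward one token
idx = context_end
while idx >= context_start and offsets[idx][1] >= end_char:
    idx -= 1
end_position = idx + 1

print(start_position, end_position)  # 1 2
```

The answer span (characters 6 through 15) maps to tokens 1 and 2, as expected.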
def compute_metrics(start_logits, end_logits, features, examples):
    n_best = 20
    max_answer_length = 30
    metric = evaluate.load("squad")

    example_to_features = collections.defaultdict(list)
    for idx, feature in enumerate(features):
        example_to_features[feature["example_id"]].append(idx)

    predicted_answers = []
    # for example in ``tqdm`` (examples):
    for example in examples:
        example_id = example["id"]
        context = example["context"]
        answers = []

        # Loop through all features associated with that example
        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = features[feature_index]["offset_mapping"]

            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip answers that are not fully in the context
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # Skip answers with a length that is either < 0
                    # or > max_answer_length
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue

                    answer = {
                        "text": context[
                            offsets[start_index][0] : offsets[end_index][1]
                        ],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                    answers.append(answer)

        # Select the answer with the best score
        if len(answers) > 0:
            best_answer = max(answers, key=lambda x: x["logit_score"])
            predicted_answers.append(
                {"id": example_id, "prediction_text": best_answer["text"]}
            )
        else:
            predicted_answers.append({"id": example_id, "prediction_text": ""})

    theoretical_answers = [
        {"id": ex["id"], "answers": ex["answers"]} for ex in examples
    ]
    return metric.compute(predictions=predicted_answers, references=theoretical_answers)
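The `np.argsort(...)[-1 : -n_best - 1 : -1]` slice above pulls out the indices of the n_best highest logits, best first. A tiny illustration:

```python
import numpy as np

logits = np.array([0.1, 0.9, 0.3, 0.7])
n_best = 2
# argsort sorts ascending; the reversed tail slice gives the top-n indices, descending
top = np.argsort(logits)[-1 : -n_best - 1 : -1].tolist()
print(top)  # [1, 3]
```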
With those defined, we just need one additional helper function, which will help us benchmark our model.
def measure_execution_time(model, batch_sizes, dataset):
    dataset_for_model = dataset.remove_columns(["example_id", "offset_mapping"])
    dataset_for_model.set_format("torch")
    batch_size_to_time_sec = {}
    for batch_size in batch_sizes:
        batch = {
            k: dataset_for_model[k][:batch_size].cuda()
            for k in dataset_for_model.column_names
        }

        with torch.no_grad():
            baseline_predictions = model(**batch)
            timer = benchmark.Timer(
                stmt="model(**batch)", globals={"model": model, "batch": batch}
            )
            p50 = timer.blocked_autorange().median * 1000
            batch_size_to_time_sec[batch_size] = p50

            model_c = torch.compile(model, fullgraph=True)
            timer = benchmark.Timer(
                stmt="model(**batch)", globals={"model": model_c, "batch": batch}
            )
            p50 = timer.blocked_autorange().median * 1000
            batch_size_to_time_sec[f"{batch_size}_compile"] = p50
            new_predictions = model_c(**batch)

    return batch_size_to_time_sec
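`benchmark.Timer` works outside CUDA too; here is a minimal sketch of the same `blocked_autorange` pattern applied to a CPU matmul:

```python
import torch
import torch.utils.benchmark as benchmark

x = torch.rand(256, 256)
timer = benchmark.Timer(stmt="x @ x", globals={"x": x})
# median over automatically sized measurement blocks, converted to milliseconds
p50_ms = timer.blocked_autorange(min_run_time=0.2).median * 1000
print(f"256x256 matmul p50: {p50_ms:.3f} ms")
```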
We will get started by loading our model and tokenizer, and then setting up our dataset.
# load model
model_name = "bert-base-cased"
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.AutoModelForQuestionAnswering.from_pretrained(model_name)
print(f"Loading tokenizer: {model_name}")
print(f"Loading model: {model_name}")

# set up train and val dataset
squad_dataset = datasets.load_dataset("squad")
tokenized_squad_dataset = {}
tokenized_squad_dataset["train"] = squad_dataset["train"].map(
    lambda x: preprocess_train_function(x, tokenizer), batched=True
)
tokenized_squad_dataset["validation"] = squad_dataset["validation"].map(
    lambda x: preprocess_validation_function(x, tokenizer),
    batched=True,
    remove_columns=squad_dataset["train"].column_names,
)
data_collator = transformers.DataCollatorWithPadding(tokenizer=tokenizer)
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Loading tokenizer: bert-base-cased
Loading model: bert-base-cased
Downloading readme: 100%|##########| 7.62k/7.62k [00:00<00:00, 49.7MB/s]
Downloading data: 100%|##########| 14.5M/14.5M [00:00<00:00, 87.3MB/s]
Downloading data: 100%|##########| 1.82M/1.82M [00:00<00:00, 40.4MB/s]
Generating train split: 100%|##########| 87599/87599 [00:00<00:00, 627386.25 examples/s]
Generating validation split: 100%|##########| 10570/10570 [00:00<00:00, 588480.85 examples/s]
Map: 100%|##########| 87599/87599 [00:44<00:00, 1978.32 examples/s]
Map: 100%|##########| 10570/10570 [00:04<00:00, 2634.93 examples/s]
Establishing a baseline¶
Next, we'll train a quick baseline of our model on SQuAD. This task asks our model to identify spans, or segments of text, in a given context (a Wikipedia article) that answer a given question. Running the following code gives an F1 score of 86.9. This is quite close to the reported NVIDIA score; the difference is likely due to BERT-base vs. BERT-large, or the fine-tuning hyperparameters.
training_args = transformers.TrainingArguments(
    "trainer",
    num_train_epochs=1,
    lr_scheduler_type="constant",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=256,
    logging_steps=50,
    # Limit max steps for tutorial runners. Delete the below line to see the reported accuracy numbers.
    max_steps=500,
    report_to=None,
)

trainer = transformers.Trainer(
    model,
    training_args,
    train_dataset=tokenized_squad_dataset["train"],
    eval_dataset=tokenized_squad_dataset["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)

trainer.train()

# batch sizes to compare for eval
batch_sizes = [4, 16, 64, 256]

# 2:4 sparsity requires fp16, so we cast here for a fair comparison
with torch.autocast("cuda"):
    with torch.no_grad():
        predictions = trainer.predict(tokenized_squad_dataset["validation"])
        start_logits, end_logits = predictions.predictions
        fp16_baseline = compute_metrics(
            start_logits,
            end_logits,
            tokenized_squad_dataset["validation"],
            squad_dataset["validation"],
        )
        fp16_time = measure_execution_time(
            model,
            batch_sizes,
            tokenized_squad_dataset["validation"],
        )

print("fp16", fp16_baseline)
print("cuda_fp16 time", fp16_time)

import pandas as pd
df = pd.DataFrame(trainer.state.log_history)
df.plot.line(x='step', y='loss', title="Loss vs. # steps", ylabel="loss")
torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. In order to use Torch DDP, launch your script with `python -m torch.distributed.launch
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
max_steps is given, it will override any value given in num_train_epochs
{'loss': 3.8379, 'grad_norm': 14.342127799987793, 'learning_rate': 5e-05, 'epoch': 0.02}
{'loss': 2.3803, 'grad_norm': 15.05634880065918, 'learning_rate': 5e-05, 'epoch': 0.04}
{'loss': 1.8675, 'grad_norm': 11.2459716796875, 'learning_rate': 5e-05, 'epoch': 0.05}
{'loss': 1.7281, 'grad_norm': 12.751557350158691, 'learning_rate': 5e-05, 'epoch': 0.07}
{'loss': 1.5849, 'grad_norm': 12.03045654296875, 'learning_rate': 5e-05, 'epoch': 0.09}
{'loss': 1.5287, 'grad_norm': 12.408611297607422, 'learning_rate': 5e-05, 'epoch': 0.11}
68%|######8 | 340/500 [03:00<01:25, 1.88it/s]
68%|######8 | 341/500 [03:01<01:24, 1.88it/s]
68%|######8 | 342/500 [03:01<01:24, 1.87it/s]
69%|######8 | 343/500 [03:02<01:23, 1.88it/s]
69%|######8 | 344/500 [03:02<01:23, 1.88it/s]
69%|######9 | 345/500 [03:03<01:22, 1.88it/s]
69%|######9 | 346/500 [03:04<01:22, 1.88it/s]
69%|######9 | 347/500 [03:04<01:21, 1.88it/s]
70%|######9 | 348/500 [03:05<01:21, 1.88it/s]
70%|######9 | 349/500 [03:05<01:20, 1.88it/s]
70%|####### | 350/500 [03:06<01:19, 1.88it/s]
{'loss': 1.4997, 'grad_norm': 9.444414138793945, 'learning_rate': 5e-05, 'epoch': 0.13}
70%|####### | 350/500 [03:06<01:19, 1.88it/s]
70%|####### | 351/500 [03:06<01:19, 1.88it/s]
70%|####### | 352/500 [03:07<01:18, 1.88it/s]
71%|####### | 353/500 [03:07<01:18, 1.88it/s]
71%|####### | 354/500 [03:08<01:17, 1.88it/s]
71%|#######1 | 355/500 [03:08<01:17, 1.88it/s]
71%|#######1 | 356/500 [03:09<01:16, 1.88it/s]
71%|#######1 | 357/500 [03:09<01:16, 1.88it/s]
72%|#######1 | 358/500 [03:10<01:15, 1.88it/s]
72%|#######1 | 359/500 [03:10<01:15, 1.88it/s]
72%|#######2 | 360/500 [03:11<01:14, 1.88it/s]
72%|#######2 | 361/500 [03:12<01:14, 1.88it/s]
72%|#######2 | 362/500 [03:12<01:13, 1.88it/s]
73%|#######2 | 363/500 [03:13<01:13, 1.88it/s]
73%|#######2 | 364/500 [03:13<01:12, 1.88it/s]
73%|#######3 | 365/500 [03:14<01:11, 1.88it/s]
73%|#######3 | 366/500 [03:14<01:11, 1.88it/s]
73%|#######3 | 367/500 [03:15<01:10, 1.88it/s]
74%|#######3 | 368/500 [03:15<01:10, 1.88it/s]
74%|#######3 | 369/500 [03:16<01:09, 1.88it/s]
74%|#######4 | 370/500 [03:16<01:09, 1.88it/s]
74%|#######4 | 371/500 [03:17<01:08, 1.88it/s]
74%|#######4 | 372/500 [03:17<01:08, 1.88it/s]
75%|#######4 | 373/500 [03:18<01:07, 1.88it/s]
75%|#######4 | 374/500 [03:18<01:07, 1.88it/s]
75%|#######5 | 375/500 [03:19<01:06, 1.88it/s]
75%|#######5 | 376/500 [03:20<01:06, 1.88it/s]
75%|#######5 | 377/500 [03:20<01:05, 1.88it/s]
76%|#######5 | 378/500 [03:21<01:05, 1.88it/s]
76%|#######5 | 379/500 [03:21<01:04, 1.88it/s]
76%|#######6 | 380/500 [03:22<01:03, 1.88it/s]
76%|#######6 | 381/500 [03:22<01:03, 1.88it/s]
76%|#######6 | 382/500 [03:23<01:02, 1.88it/s]
77%|#######6 | 383/500 [03:23<01:02, 1.88it/s]
77%|#######6 | 384/500 [03:24<01:01, 1.88it/s]
77%|#######7 | 385/500 [03:24<01:01, 1.88it/s]
77%|#######7 | 386/500 [03:25<01:00, 1.88it/s]
77%|#######7 | 387/500 [03:25<01:00, 1.87it/s]
78%|#######7 | 388/500 [03:26<00:59, 1.88it/s]
78%|#######7 | 389/500 [03:26<00:59, 1.88it/s]
78%|#######8 | 390/500 [03:27<00:58, 1.88it/s]
78%|#######8 | 391/500 [03:28<00:58, 1.87it/s]
78%|#######8 | 392/500 [03:28<00:57, 1.88it/s]
79%|#######8 | 393/500 [03:29<00:57, 1.88it/s]
79%|#######8 | 394/500 [03:29<00:56, 1.88it/s]
79%|#######9 | 395/500 [03:30<00:55, 1.88it/s]
79%|#######9 | 396/500 [03:30<00:55, 1.88it/s]
79%|#######9 | 397/500 [03:31<00:54, 1.88it/s]
80%|#######9 | 398/500 [03:31<00:54, 1.88it/s]
80%|#######9 | 399/500 [03:32<00:53, 1.88it/s]
80%|######## | 400/500 [03:32<00:53, 1.87it/s]
{'loss': 1.3912, 'grad_norm': 10.076640129089355, 'learning_rate': 5e-05, 'epoch': 0.15}
80%|######## | 400/500 [03:32<00:53, 1.87it/s]
80%|######## | 401/500 [03:33<00:52, 1.87it/s]
80%|######## | 402/500 [03:33<00:52, 1.88it/s]
81%|######## | 403/500 [03:34<00:51, 1.87it/s]
81%|######## | 404/500 [03:34<00:51, 1.88it/s]
81%|########1 | 405/500 [03:35<00:50, 1.88it/s]
81%|########1 | 406/500 [03:36<00:50, 1.88it/s]
81%|########1 | 407/500 [03:36<00:49, 1.88it/s]
82%|########1 | 408/500 [03:37<00:49, 1.88it/s]
82%|########1 | 409/500 [03:37<00:48, 1.88it/s]
82%|########2 | 410/500 [03:38<00:47, 1.88it/s]
82%|########2 | 411/500 [03:38<00:47, 1.88it/s]
82%|########2 | 412/500 [03:39<00:46, 1.88it/s]
83%|########2 | 413/500 [03:39<00:46, 1.88it/s]
83%|########2 | 414/500 [03:40<00:45, 1.88it/s]
83%|########2 | 415/500 [03:40<00:45, 1.87it/s]
83%|########3 | 416/500 [03:41<00:44, 1.87it/s]
83%|########3 | 417/500 [03:41<00:44, 1.87it/s]
84%|########3 | 418/500 [03:42<00:43, 1.88it/s]
84%|########3 | 419/500 [03:42<00:43, 1.88it/s]
84%|########4 | 420/500 [03:43<00:42, 1.88it/s]
84%|########4 | 421/500 [03:44<00:42, 1.88it/s]
84%|########4 | 422/500 [03:44<00:41, 1.88it/s]
85%|########4 | 423/500 [03:45<00:41, 1.88it/s]
85%|########4 | 424/500 [03:45<00:40, 1.88it/s]
85%|########5 | 425/500 [03:46<00:39, 1.88it/s]
85%|########5 | 426/500 [03:46<00:39, 1.88it/s]
85%|########5 | 427/500 [03:47<00:38, 1.88it/s]
86%|########5 | 428/500 [03:47<00:38, 1.88it/s]
86%|########5 | 429/500 [03:48<00:37, 1.88it/s]
86%|########6 | 430/500 [03:48<00:37, 1.88it/s]
86%|########6 | 431/500 [03:49<00:36, 1.88it/s]
86%|########6 | 432/500 [03:49<00:36, 1.88it/s]
87%|########6 | 433/500 [03:50<00:35, 1.87it/s]
87%|########6 | 434/500 [03:50<00:35, 1.88it/s]
87%|########7 | 435/500 [03:51<00:34, 1.88it/s]
87%|########7 | 436/500 [03:52<00:34, 1.88it/s]
87%|########7 | 437/500 [03:52<00:33, 1.87it/s]
88%|########7 | 438/500 [03:53<00:33, 1.87it/s]
88%|########7 | 439/500 [03:53<00:32, 1.87it/s]
88%|########8 | 440/500 [03:54<00:31, 1.88it/s]
88%|########8 | 441/500 [03:54<00:31, 1.88it/s]
88%|########8 | 442/500 [03:55<00:30, 1.88it/s]
89%|########8 | 443/500 [03:55<00:30, 1.88it/s]
89%|########8 | 444/500 [03:56<00:29, 1.88it/s]
89%|########9 | 445/500 [03:56<00:29, 1.87it/s]
89%|########9 | 446/500 [03:57<00:28, 1.88it/s]
89%|########9 | 447/500 [03:57<00:28, 1.88it/s]
90%|########9 | 448/500 [03:58<00:27, 1.88it/s]
90%|########9 | 449/500 [03:58<00:27, 1.88it/s]
90%|######### | 450/500 [03:59<00:26, 1.88it/s]
{'loss': 1.3439, 'grad_norm': 11.830737113952637, 'learning_rate': 5e-05, 'epoch': 0.16}
90%|######### | 450/500 [03:59<00:26, 1.88it/s]
90%|######### | 451/500 [04:00<00:26, 1.87it/s]
90%|######### | 452/500 [04:00<00:25, 1.88it/s]
91%|######### | 453/500 [04:01<00:25, 1.88it/s]
91%|######### | 454/500 [04:01<00:24, 1.87it/s]
91%|#########1| 455/500 [04:02<00:23, 1.88it/s]
91%|#########1| 456/500 [04:02<00:23, 1.88it/s]
91%|#########1| 457/500 [04:03<00:22, 1.87it/s]
92%|#########1| 458/500 [04:03<00:22, 1.87it/s]
92%|#########1| 459/500 [04:04<00:21, 1.88it/s]
92%|#########2| 460/500 [04:04<00:21, 1.87it/s]
92%|#########2| 461/500 [04:05<00:20, 1.88it/s]
92%|#########2| 462/500 [04:05<00:20, 1.88it/s]
93%|#########2| 463/500 [04:06<00:19, 1.88it/s]
93%|#########2| 464/500 [04:06<00:19, 1.88it/s]
93%|#########3| 465/500 [04:07<00:18, 1.88it/s]
93%|#########3| 466/500 [04:08<00:18, 1.88it/s]
93%|#########3| 467/500 [04:08<00:17, 1.87it/s]
94%|#########3| 468/500 [04:09<00:17, 1.88it/s]
94%|#########3| 469/500 [04:09<00:16, 1.88it/s]
94%|#########3| 470/500 [04:10<00:15, 1.88it/s]
94%|#########4| 471/500 [04:10<00:15, 1.88it/s]
94%|#########4| 472/500 [04:11<00:14, 1.88it/s]
95%|#########4| 473/500 [04:11<00:14, 1.88it/s]
95%|#########4| 474/500 [04:12<00:13, 1.88it/s]
95%|#########5| 475/500 [04:12<00:13, 1.88it/s]
95%|#########5| 476/500 [04:13<00:12, 1.88it/s]
95%|#########5| 477/500 [04:13<00:12, 1.88it/s]
96%|#########5| 478/500 [04:14<00:11, 1.88it/s]
96%|#########5| 479/500 [04:14<00:11, 1.88it/s]
96%|#########6| 480/500 [04:15<00:10, 1.88it/s]
96%|#########6| 481/500 [04:15<00:10, 1.89it/s]
96%|#########6| 482/500 [04:16<00:09, 1.89it/s]
97%|#########6| 483/500 [04:17<00:09, 1.89it/s]
97%|#########6| 484/500 [04:17<00:08, 1.89it/s]
97%|#########7| 485/500 [04:18<00:07, 1.89it/s]
97%|#########7| 486/500 [04:18<00:07, 1.89it/s]
97%|#########7| 487/500 [04:19<00:06, 1.89it/s]
98%|#########7| 488/500 [04:19<00:06, 1.89it/s]
98%|#########7| 489/500 [04:20<00:05, 1.89it/s]
98%|#########8| 490/500 [04:20<00:05, 1.89it/s]
98%|#########8| 491/500 [04:21<00:04, 1.89it/s]
98%|#########8| 492/500 [04:21<00:04, 1.89it/s]
99%|#########8| 493/500 [04:22<00:03, 1.89it/s]
99%|#########8| 494/500 [04:22<00:03, 1.89it/s]
99%|#########9| 495/500 [04:23<00:02, 1.89it/s]
99%|#########9| 496/500 [04:23<00:02, 1.89it/s]
99%|#########9| 497/500 [04:24<00:01, 1.89it/s]
100%|#########9| 498/500 [04:25<00:01, 1.89it/s]
100%|#########9| 499/500 [04:25<00:00, 1.89it/s]
100%|##########| 500/500 [04:26<00:00, 1.89it/s]
{'loss': 1.3419, 'grad_norm': 15.239953994750977, 'learning_rate': 5e-05, 'epoch': 0.18}
100%|##########| 500/500 [04:26<00:00, 1.89it/s]
{'train_runtime': 267.4795, 'train_samples_per_second': 59.818, 'train_steps_per_second': 1.869, 'train_loss': 1.850412796020508, 'epoch': 0.18}
100%|##########| 500/500 [04:27<00:00, 1.89it/s]
100%|##########| 500/500 [04:27<00:00, 1.87it/s]
fp16 {'exact_match': 71.28666035950805, 'f1': 80.62619387545674}
cuda_fp16 time {4: 9.563420739996218, '4_compile': 9.004673000163166, 16: 31.877686199914024, '16_compile': 30.932841000321787, 64: 123.92699649990391, '64_compile': 104.97432750071312, 256: 476.81639599977643, '256_compile': 396.38552500036894}
<Axes: title={'center': 'Loss vs. # steps'}, xlabel='step', ylabel='loss'>
Pruning BERT to be 2:4 sparse¶
Now that we have our baseline, it's time to prune BERT. There are many different pruning strategies, but one of the most common is **magnitude pruning**, which seeks to remove the weights with the lowest L1 norm. NVIDIA uses magnitude pruning in all of their results, and it is a common baseline.
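As a standalone illustration (not part of the tutorial's pipeline), 2:4 magnitude pruning can be sketched in a few lines: within every contiguous block of four weights, the two with the smallest absolute value are zeroed. The helper name below is ours, for illustration only.

```python
import torch

def prune_2_4(weight: torch.Tensor) -> torch.Tensor:
    """Zero out the 2 smallest-magnitude elements in every block of 4."""
    w = weight.reshape(-1, 4)
    # keep the indices of the 2 largest |w| per block; prune the rest
    keep = torch.topk(w.abs(), k=2, dim=1).indices
    mask = torch.zeros_like(w, dtype=torch.bool).scatter_(1, keep, True)
    return (w * mask).reshape(weight.shape)

w = torch.tensor([[0.9, -0.1, 0.4, 0.05, -0.7, 0.2, 0.3, -0.8]])
print(prune_2_4(w))  # each block of 4 now contains exactly 2 zeros
```

The same selection logic, applied per weight tensor and stored as a mask rather than applied destructively, is what the sparsifier below implements.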
To do this, we will use the torch.ao.pruning package, which contains a weight-norm (magnitude) sparsifier. These sparsifiers work by applying mask parametrizations to the weight tensors in a model, which lets them simulate sparsity by masking out the pruned weights.
We will also have to decide which layers of the model to apply sparsity to, which in this case is all of the nn.Linear layers, except for the task-specific head outputs. That's because semi-structured sparsity has shape constraints, and the task-specific nn.Linear layers do not satisfy them.
sparsifier = WeightNormSparsifier(
# apply sparsity to all blocks
sparsity_level=1.0,
# shape of 4 elements is a block
sparse_block_shape=(1, 4),
# two zeros for every block of 4
zeros_per_block=2
)
# add to config if ``nn.Linear`` and in the BERT model.
sparse_config = [
{"tensor_fqn": f"{fqn}.weight"}
for fqn, module in model.named_modules()
if isinstance(module, nn.Linear) and "layer" in fqn
]
The first step in pruning the model is to insert parametrizations for masking the model's weights. This is done by the prepare step. Afterwards, any time we try to access .weight we will get mask * weight instead of the original weight.
# Prepare the model, insert fake-sparsity parametrizations for training
sparsifier.prepare(model, sparse_config)
print(model.bert.encoder.layer[0].output)
BertOutput(
(dense): ParametrizedLinear(
in_features=3072, out_features=768, bias=True
(parametrizations): ModuleDict(
(weight): ParametrizationList(
(0): FakeSparsity()
)
)
)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
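Conceptually, the FakeSparsity parametrization printed above is just a module that multiplies the weight by a fixed mask on every access. A minimal sketch of the same mechanism using torch.nn.utils.parametrize, with a hypothetical MaskWeight module and mask of our own choosing:

```python
import torch
from torch import nn
from torch.nn.utils import parametrize

class MaskWeight(nn.Module):
    """Multiply the weight by a fixed 0/1 mask on every access."""
    def __init__(self, mask):
        super().__init__()
        self.register_buffer("mask", mask)

    def forward(self, w):
        return w * self.mask

linear = nn.Linear(8, 4)
mask = torch.tensor([0., 0., 1., 1.]).tile(4, 2)  # a 2:4 pattern, shape (4, 8)
parametrize.register_parametrization(linear, "weight", MaskWeight(mask))

# Accessing .weight now returns mask * weight, not the raw parameter
assert (linear.weight[:, :2] == 0).all()

# Analogous to sparsifier.squash_mask(): bake the mask into the weight
parametrize.remove_parametrizations(linear, "weight", leave_parametrized=True)
```

Because the mask is applied at access time, gradients flow only through the surviving weights during fine-tuning, while the underlying parameter tensor stays dense.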
Then, we will take a single pruning step. All pruners implement an update_mask() method, which updates the mask with the logic determined by the pruner implementation. The step method calls this update_mask function for the weights specified in the sparse config.
We will also evaluate the model to show the accuracy degradation of zero-shot pruning, that is, pruning without fine-tuning / retraining.
sparsifier.step()
with torch.autocast("cuda"):
with torch.no_grad():
predictions = trainer.predict(tokenized_squad_dataset["validation"])
pruned = compute_metrics(
*predictions.predictions,
tokenized_squad_dataset["validation"],
squad_dataset["validation"],
)
print("pruned eval metrics:", pruned)
pruned eval metrics: {'exact_match': 29.87701040681173, 'f1': 41.483943521142855}
In this state, we can start fine-tuning the model, updating the elements that won't be pruned to better compensate for the accuracy loss. Once we've reached a satisfactory state, we can call squash_mask to fuse the mask and the weight together. This removes the parametrizations and leaves us with a zeroed-out, 2:4-dense model.
trainer.train()
sparsifier.squash_mask()
torch.set_printoptions(edgeitems=4)
print(model.bert.encoder.layer[0].intermediate.dense.weight[:8, :8])
df["sparse_loss"] = pd.DataFrame(trainer.state.log_history)["loss"]
df.plot.line(x='step', y=["loss", "sparse_loss"], title="Loss vs. # steps", ylabel="loss")
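After squash_mask, every block of four weights in the pruned layers should contain at least two zeros, which is the precondition for converting them with to_sparse_semi_structured later. A small hedged check (the helper name is ours, not from the tutorial):

```python
import torch

def is_2_4_sparse(weight: torch.Tensor) -> bool:
    """True if every contiguous block of 4 along the last dim has >= 2 zeros."""
    blocks = weight.reshape(-1, weight.shape[-1] // 4, 4)
    return bool(((blocks == 0).sum(dim=-1) >= 2).all())

w = torch.tensor([[1., 0., 0., 2., 0., 3., 4., 0.]])
print(is_2_4_sparse(w))  # True
```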
{'loss': 1.8793, 'grad_norm': 11.104135513305664, 'learning_rate': 5e-05, 'epoch': 0.02}
{'loss': 1.4024, 'grad_norm': 8.995004653930664, 'learning_rate': 5e-05, 'epoch': 0.04}
{'loss': 1.1782, 'grad_norm': 9.000022888183594, 'learning_rate': 5e-05, 'epoch': 0.05}
{'loss': 1.2231, 'grad_norm': 8.50694465637207, 'learning_rate': 5e-05, 'epoch': 0.07}
{'loss': 1.1114, 'grad_norm': 7.177275657653809, 'learning_rate': 5e-05, 'epoch': 0.09}
{'loss': 1.1299, 'grad_norm': 7.932921409606934, 'learning_rate': 5e-05, 'epoch': 0.11}
{'loss': 1.1066, 'grad_norm': 8.763031959533691, 'learning_rate': 5e-05, 'epoch': 0.13}
76%|#######5 | 379/500 [03:22<01:04, 1.87it/s]
76%|#######6 | 380/500 [03:22<01:04, 1.87it/s]
76%|#######6 | 381/500 [03:23<01:03, 1.87it/s]
76%|#######6 | 382/500 [03:23<01:03, 1.87it/s]
77%|#######6 | 383/500 [03:24<01:02, 1.87it/s]
77%|#######6 | 384/500 [03:25<01:01, 1.87it/s]
77%|#######7 | 385/500 [03:25<01:01, 1.87it/s]
77%|#######7 | 386/500 [03:26<01:00, 1.87it/s]
77%|#######7 | 387/500 [03:26<01:00, 1.87it/s]
78%|#######7 | 388/500 [03:27<00:59, 1.87it/s]
78%|#######7 | 389/500 [03:27<00:59, 1.87it/s]
78%|#######8 | 390/500 [03:28<00:58, 1.87it/s]
78%|#######8 | 391/500 [03:28<00:58, 1.87it/s]
78%|#######8 | 392/500 [03:29<00:57, 1.87it/s]
79%|#######8 | 393/500 [03:29<00:57, 1.87it/s]
79%|#######8 | 394/500 [03:30<00:56, 1.87it/s]
79%|#######9 | 395/500 [03:30<00:56, 1.87it/s]
79%|#######9 | 396/500 [03:31<00:55, 1.87it/s]
79%|#######9 | 397/500 [03:31<00:55, 1.87it/s]
80%|#######9 | 398/500 [03:32<00:54, 1.87it/s]
80%|#######9 | 399/500 [03:33<00:53, 1.87it/s]
80%|######## | 400/500 [03:33<00:53, 1.87it/s]
{'loss': 1.0148, 'grad_norm': 7.89182186126709, 'learning_rate': 5e-05, 'epoch': 0.15}
80%|######## | 400/500 [03:33<00:53, 1.87it/s]
80%|######## | 401/500 [03:34<00:52, 1.87it/s]
80%|######## | 402/500 [03:34<00:52, 1.87it/s]
81%|######## | 403/500 [03:35<00:51, 1.87it/s]
81%|######## | 404/500 [03:35<00:51, 1.87it/s]
81%|########1 | 405/500 [03:36<00:50, 1.87it/s]
81%|########1 | 406/500 [03:36<00:50, 1.87it/s]
81%|########1 | 407/500 [03:37<00:49, 1.87it/s]
82%|########1 | 408/500 [03:37<00:49, 1.87it/s]
82%|########1 | 409/500 [03:38<00:48, 1.87it/s]
82%|########2 | 410/500 [03:38<00:48, 1.87it/s]
82%|########2 | 411/500 [03:39<00:47, 1.87it/s]
82%|########2 | 412/500 [03:40<00:47, 1.87it/s]
83%|########2 | 413/500 [03:40<00:46, 1.87it/s]
83%|########2 | 414/500 [03:41<00:45, 1.87it/s]
83%|########2 | 415/500 [03:41<00:45, 1.87it/s]
83%|########3 | 416/500 [03:42<00:44, 1.87it/s]
83%|########3 | 417/500 [03:42<00:44, 1.87it/s]
84%|########3 | 418/500 [03:43<00:43, 1.87it/s]
84%|########3 | 419/500 [03:43<00:43, 1.87it/s]
84%|########4 | 420/500 [03:44<00:42, 1.87it/s]
84%|########4 | 421/500 [03:44<00:42, 1.87it/s]
84%|########4 | 422/500 [03:45<00:41, 1.87it/s]
85%|########4 | 423/500 [03:45<00:41, 1.87it/s]
85%|########4 | 424/500 [03:46<00:40, 1.87it/s]
85%|########5 | 425/500 [03:46<00:40, 1.87it/s]
85%|########5 | 426/500 [03:47<00:39, 1.87it/s]
85%|########5 | 427/500 [03:48<00:39, 1.87it/s]
86%|########5 | 428/500 [03:48<00:38, 1.87it/s]
86%|########5 | 429/500 [03:49<00:37, 1.87it/s]
86%|########6 | 430/500 [03:49<00:37, 1.87it/s]
86%|########6 | 431/500 [03:50<00:36, 1.87it/s]
86%|########6 | 432/500 [03:50<00:36, 1.87it/s]
87%|########6 | 433/500 [03:51<00:35, 1.87it/s]
87%|########6 | 434/500 [03:51<00:35, 1.87it/s]
87%|########7 | 435/500 [03:52<00:34, 1.87it/s]
87%|########7 | 436/500 [03:52<00:34, 1.87it/s]
87%|########7 | 437/500 [03:53<00:33, 1.87it/s]
88%|########7 | 438/500 [03:53<00:33, 1.87it/s]
88%|########7 | 439/500 [03:54<00:32, 1.87it/s]
88%|########8 | 440/500 [03:54<00:32, 1.87it/s]
88%|########8 | 441/500 [03:55<00:31, 1.87it/s]
88%|########8 | 442/500 [03:56<00:30, 1.87it/s]
89%|########8 | 443/500 [03:56<00:30, 1.87it/s]
89%|########8 | 444/500 [03:57<00:29, 1.87it/s]
89%|########9 | 445/500 [03:57<00:29, 1.87it/s]
89%|########9 | 446/500 [03:58<00:28, 1.87it/s]
89%|########9 | 447/500 [03:58<00:28, 1.87it/s]
90%|########9 | 448/500 [03:59<00:27, 1.87it/s]
90%|########9 | 449/500 [03:59<00:27, 1.87it/s]
90%|######### | 450/500 [04:00<00:26, 1.87it/s]
{'loss': 1.0004, 'grad_norm': 8.658374786376953, 'learning_rate': 5e-05, 'epoch': 0.16}
90%|######### | 450/500 [04:00<00:26, 1.87it/s]
90%|######### | 451/500 [04:00<00:26, 1.87it/s]
90%|######### | 452/500 [04:01<00:25, 1.87it/s]
91%|######### | 453/500 [04:01<00:25, 1.87it/s]
91%|######### | 454/500 [04:02<00:24, 1.87it/s]
91%|#########1| 455/500 [04:02<00:24, 1.87it/s]
91%|#########1| 456/500 [04:03<00:23, 1.87it/s]
91%|#########1| 457/500 [04:04<00:22, 1.87it/s]
92%|#########1| 458/500 [04:04<00:22, 1.87it/s]
92%|#########1| 459/500 [04:05<00:21, 1.87it/s]
92%|#########2| 460/500 [04:05<00:21, 1.87it/s]
92%|#########2| 461/500 [04:06<00:20, 1.87it/s]
92%|#########2| 462/500 [04:06<00:20, 1.87it/s]
93%|#########2| 463/500 [04:07<00:19, 1.87it/s]
93%|#########2| 464/500 [04:07<00:19, 1.87it/s]
93%|#########3| 465/500 [04:08<00:18, 1.87it/s]
93%|#########3| 466/500 [04:08<00:18, 1.87it/s]
93%|#########3| 467/500 [04:09<00:17, 1.87it/s]
94%|#########3| 468/500 [04:09<00:17, 1.87it/s]
94%|#########3| 469/500 [04:10<00:16, 1.87it/s]
94%|#########3| 470/500 [04:11<00:16, 1.87it/s]
94%|#########4| 471/500 [04:11<00:15, 1.87it/s]
94%|#########4| 472/500 [04:12<00:14, 1.87it/s]
95%|#########4| 473/500 [04:12<00:14, 1.87it/s]
95%|#########4| 474/500 [04:13<00:13, 1.87it/s]
95%|#########5| 475/500 [04:13<00:13, 1.87it/s]
95%|#########5| 476/500 [04:14<00:12, 1.87it/s]
95%|#########5| 477/500 [04:14<00:12, 1.87it/s]
96%|#########5| 478/500 [04:15<00:11, 1.87it/s]
96%|#########5| 479/500 [04:15<00:11, 1.87it/s]
96%|#########6| 480/500 [04:16<00:10, 1.87it/s]
96%|#########6| 481/500 [04:16<00:10, 1.87it/s]
96%|#########6| 482/500 [04:17<00:09, 1.87it/s]
97%|#########6| 483/500 [04:17<00:09, 1.87it/s]
97%|#########6| 484/500 [04:18<00:08, 1.87it/s]
97%|#########7| 485/500 [04:19<00:08, 1.87it/s]
97%|#########7| 486/500 [04:19<00:07, 1.87it/s]
97%|#########7| 487/500 [04:20<00:06, 1.87it/s]
98%|#########7| 488/500 [04:20<00:06, 1.87it/s]
98%|#########7| 489/500 [04:21<00:05, 1.87it/s]
98%|#########8| 490/500 [04:21<00:05, 1.87it/s]
98%|#########8| 491/500 [04:22<00:04, 1.87it/s]
98%|#########8| 492/500 [04:22<00:04, 1.87it/s]
99%|#########8| 493/500 [04:23<00:03, 1.87it/s]
99%|#########8| 494/500 [04:23<00:03, 1.87it/s]
99%|#########9| 495/500 [04:24<00:02, 1.87it/s]
99%|#########9| 496/500 [04:24<00:02, 1.87it/s]
99%|#########9| 497/500 [04:25<00:01, 1.87it/s]
100%|#########9| 498/500 [04:25<00:01, 1.87it/s]
100%|#########9| 499/500 [04:26<00:00, 1.87it/s]
100%|##########| 500/500 [04:27<00:00, 1.87it/s]
{'loss': 1.0101, 'grad_norm': 10.829127311706543, 'learning_rate': 5e-05, 'epoch': 0.18}
100%|##########| 500/500 [04:27<00:00, 1.87it/s]
{'train_runtime': 276.4621, 'train_samples_per_second': 57.874, 'train_steps_per_second': 1.809, 'train_loss': 1.2056114730834961, 'epoch': 0.18}
100%|##########| 500/500 [04:36<00:00, 1.87it/s]
100%|##########| 500/500 [04:36<00:00, 1.81it/s]
tensor([[ 0.0000, -0.0181, 0.0000, 0.0115, 0.0000, 0.0313, 0.0000, -0.0736],
[ 0.0412, -0.0000, -0.0000, 0.0499, 0.0346, -0.0000, -0.0000, -0.0172],
[-0.0336, -0.0328, 0.0000, 0.0000, -0.0000, -0.0000, 0.0572, -0.0192],
[ 0.0000, -0.0000, -0.0506, 0.0289, 0.0965, 0.0000, -0.0000, -0.0766],
[ 0.0088, -0.0000, 0.0000, -0.0406, -0.0000, 0.0185, 0.0657, -0.0000],
[-0.0000, -0.0311, 0.0000, 0.0308, 0.0394, -0.0172, 0.0000, -0.0000],
[-0.1069, -0.0296, 0.0000, -0.0000, 0.0425, 0.0661, -0.0000, -0.0000],
[-0.0000, 0.0298, -0.0655, -0.0000, 0.0316, -0.0000, 0.0000, -0.0602]],
device='cuda:0', grad_fn=<SliceBackward0>)
<Axes: title={'center': 'Loss vs. # steps'}, xlabel='step', ylabel='loss'>
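The zeros in the weight slice printed above follow the 2:4 pattern: every contiguous group of four elements along a row contains at most two nonzeros. As a standalone illustration (not the tutorial's actual pruning flow), a magnitude-based 2:4 mask can be built directly in PyTorch; the helper `semi_structured_mask` below is our own name, and it runs on CPU:

```python
import torch

def semi_structured_mask(weight: torch.Tensor) -> torch.Tensor:
    """Keep the 2 largest-magnitude entries in every group of 4 (illustrative helper)."""
    groups = weight.reshape(-1, 4)
    keep = groups.abs().topk(2, dim=-1).indices
    mask = torch.zeros_like(groups, dtype=torch.bool).scatter_(-1, keep, True)
    return mask.reshape(weight.shape)

w = torch.randn(8, 8)
m = semi_structured_mask(w)
# every group of 4 keeps exactly 2 elements -> 50% sparsity
assert (m.reshape(-1, 4).sum(dim=-1) == 2).all()
```

Applying `w * m` reproduces the zero pattern seen in the output above; the tutorial's pruner performs this selection (plus masked fine-tuning) for every eligible `nn.Linear` weight.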
Accelerating 2:4 sparse models for inference¶
Now that we have a model in this format, we can accelerate it for inference just as in the QuickStart Guide.
model = model.cuda().half()
# accelerate for sparsity
for fqn, module in model.named_modules():
if isinstance(module, nn.Linear) and "layer" in fqn:
module.weight = nn.Parameter(to_sparse_semi_structured(module.weight))
with torch.no_grad():
predictions = trainer.predict(tokenized_squad_dataset["validation"])
start_logits, end_logits = predictions.predictions
metrics_sparse = compute_metrics(
start_logits,
end_logits,
tokenized_squad_dataset["validation"],
squad_dataset["validation"],
)
print("sparse eval metrics: ", metrics_sparse)
sparse_perf = measure_execution_time(
model,
batch_sizes,
tokenized_squad_dataset["validation"],
)
print("sparse perf metrics: ", sparse_perf)
100%|##########| 43/43 [00:55<00:00, 1.30s/it]
sparse eval metrics: {'exact_match': 71.36234626300852, 'f1': 80.46835071907935}
sparse perf metrics: {4: 16.749242000059894, '4_compile': 9.570522999638342, 16: 62.150992899842095, '16_compile': 34.265135300120164, 64: 243.14898499869741, '64_compile': 133.18610300029832, 256: 1195.5928600000334, '256_compile': 542.8887739999482}
Retraining our model after magnitude pruning has recovered nearly all of the F1 score lost when the model was pruned. At the same time, we achieve a 1.28x speedup for bs=16. Note that not all shapes are amenable to performance improvements: when batch sizes are small and limited time is spent in compute, the sparse kernels may be slower than their dense counterparts.
Because semi-structured sparsity is implemented as a tensor subclass, it is compatible with torch.compile. When composed with to_sparse_semi_structured, we are able to achieve a total 2x speedup on BERT.
| Metric | fp16 | 2:4 sparse | delta / speedup | compiled |
|---|---|---|---|---|
| Exact Match (%) | 78.53 | 78.44 | -0.09 | |
| F1 (%) | 86.93 | 86.49 | -0.44 | |
| Time (bs=4) | 11.10 | 15.54 | 0.71x | no |
| Time (bs=16) | 19.35 | 15.74 | 1.23x | no |
| Time (bs=64) | 72.71 | 59.41 | 1.22x | no |
| Time (bs=256) | 286.65 | 247.63 | 1.14x | no |
| Time (bs=4) | 7.59 | 7.46 | 1.02x | yes |
| Time (bs=16) | 11.47 | 9.68 | 1.18x | yes |
| Time (bs=64) | 41.57 | 36.92 | 1.13x | yes |
| Time (bs=256) | 159.22 | 142.23 | 1.12x | yes |
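The speedup column can be reproduced from the two timing columns (speedup = fp16 time / sparse time). For example, for a few of the uncompiled rows, with the millisecond values copied from the table:

```python
# Timings (ms) taken from the uncompiled rows of the table above.
fp16_ms = {4: 11.10, 16: 19.35, 64: 72.71}
sparse_ms = {4: 15.54, 16: 15.74, 64: 59.41}

speedup = {bs: round(fp16_ms[bs] / sparse_ms[bs], 2) for bs in fp16_ms}
print(speedup)  # {4: 0.71, 16: 1.23, 64: 1.22}
```

Values below 1.0x (here, bs=4) mark the shapes where the sparse kernel is slower than its dense counterpart.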
Conclusion¶
In this tutorial, we have shown how to prune BERT to be 2:4 sparse and how to accelerate a 2:4 sparse model for inference. By taking advantage of our SparseSemiStructuredTensor subclass, we were able to achieve a 1.3x speedup over the fp16 baseline, and up to 2x with torch.compile. We also demonstrated the benefits of 2:4 sparsity by fine-tuning BERT to recover nearly all of the lost F1 (86.92 dense vs 86.48 sparse).
Total running time of the script: (15 minutes 30.187 seconds)