
(beta) Accelerating BERT with semi-structured (2:4) sparsity

**Author**: Jesse Cai

Overview

Like other forms of sparsity, **semi-structured sparsity** is a model optimization technique that seeks to reduce the memory overhead and latency of a neural network at the expense of some model accuracy. It is also known as **fine-grained structured sparsity** or **2:4 structured sparsity**.

Semi-structured sparsity derives its name from its unique sparsity pattern, in which n out of every 2n elements are pruned. We most often see n=2, hence 2:4 sparsity. Semi-structured sparsity is particularly interesting because it can be efficiently accelerated on GPUs and does not degrade model accuracy as much as other sparsity patterns.
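The 2:4 constraint is easy to check directly: reshape the tensor into contiguous groups of four and count the non-zeros per group. A minimal sketch (the helper name `is_24_sparse` is our own, not a PyTorch API):

```python
import torch

def is_24_sparse(t: torch.Tensor) -> bool:
    # Every contiguous group of 4 elements may hold at most 2 non-zeros.
    groups = t.reshape(-1, 4)
    return bool((groups != 0).sum(dim=1).le(2).all())

# A [1, 1, 0, 0] tile satisfies the pattern; an all-ones tensor does not.
print(is_24_sparse(torch.tensor([1.0, 1.0, 0.0, 0.0]).tile((4, 4))))  # True
print(is_24_sparse(torch.ones(4, 16)))                                # False
```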

With the introduction of semi-structured sparsity support, it is possible to prune and accelerate a semi-structured sparse model without leaving PyTorch. We will explain this process in this tutorial.

(Figure: the semi-structured sparsity pruning flow, ../_static/img/pruning_flow.jpg)

By the end of this tutorial, we will have sparsified a BERT question-answering model to be 2:4 sparse and fine-tuned it to recover nearly all of the lost F1 (86.92 dense vs. 86.48 sparse). Finally, we will accelerate this 2:4 sparse model for inference, yielding a 1.3x speedup.

Requirements

  • PyTorch >= 2.1.

  • An NVIDIA GPU with semi-structured sparsity support (compute capability 8.0+).

This tutorial is designed for beginners to semi-structured sparsity and to sparsity in general. For users with an existing 2:4 sparse model, accelerating nn.Linear layers for inference with to_sparse_semi_structured is quite straightforward. Here is an example:

import torch
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from torch.utils.benchmark import Timer
SparseSemiStructuredTensor._FORCE_CUTLASS = True

# mask Linear weight to be 2:4 sparse
mask = torch.Tensor([0, 0, 1, 1]).tile((3072, 2560)).cuda().bool()
linear = torch.nn.Linear(10240, 3072).half().cuda().eval()
linear.weight = torch.nn.Parameter(mask * linear.weight)

x = torch.rand(3072, 10240).half().cuda()

with torch.inference_mode():
    dense_output = linear(x)
    dense_t = Timer(stmt="linear(x)",
                    globals={"linear": linear,
                             "x": x}).blocked_autorange().median * 1e3

    # accelerate via SparseSemiStructuredTensor
    linear.weight = torch.nn.Parameter(to_sparse_semi_structured(linear.weight))

    sparse_output = linear(x)
    sparse_t = Timer(stmt="linear(x)",
                    globals={"linear": linear,
                             "x": x}).blocked_autorange().median * 1e3

    # sparse and dense matmul are numerically equivalent
    # On an A100 80GB, we see: `Dense: 0.870ms Sparse: 0.630ms | Speedup: 1.382x`
    assert torch.allclose(sparse_output, dense_output, atol=1e-3)
    print(f"Dense: {dense_t:.3f}ms Sparse: {sparse_t:.3f}ms | Speedup: {(dense_t / sparse_t):.3f}x")
Dense: 2.920ms Sparse: 1.646ms | Speedup: 1.773x

What problem does semi-structured sparsity solve?

The general motivation behind sparsity is simple: if there are zeros in your network, you can optimize efficiency by not storing or computing those parameters. However, the specifics of exploiting sparsity are tricky. Zeroing out parameters doesn't reduce a model's latency or memory overhead out of the box.

This is because the dense tensor still contains the pruned (zeroed) elements, and dense matrix multiplication kernels still operate on them. To realize performance gains, we need to swap dense kernels for sparse kernels, which skip the computation involving pruned elements.

To do this, sparse kernels work on sparse matrices, which do not store the pruned elements and instead keep the specified elements in a compressed format.

For semi-structured sparsity, we store exactly half of the original parameters, along with some compressed metadata describing how the elements were arranged.
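As a back-of-the-envelope illustration (assuming the commonly cited layout of 2 index bits per kept element; exact metadata sizes vary by implementation), the compressed representation of an fp16 weight matrix works out to roughly 56% of the dense size:

```python
# Theoretical storage for an M x K fp16 weight matrix under 2:4 sparsity,
# assuming 2 bits of index metadata per kept element.
M, K = 3072, 10240
dense_bits = M * K * 16              # all elements stored in fp16
kept = M * K // 2                    # 2 of every 4 elements survive pruning
sparse_bits = kept * 16 + kept * 2   # fp16 values + 2-bit position indices
print(sparse_bits / dense_bits)      # 0.5625
```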

There are many different sparse layouts, each with its own benefits and drawbacks. The 2:4 semi-structured sparse layout is particularly interesting for two reasons:

  • Unlike earlier sparse formats, semi-structured sparsity was designed to be efficiently accelerated on GPUs. In 2020, NVIDIA introduced hardware support for semi-structured sparsity with its Ampere architecture and released fast sparse kernels via CUTLASS and cuSPARSELt.

  • At the same time, semi-structured sparsity tends to have a milder impact on model accuracy than other sparse formats, especially when more advanced pruning and fine-tuning methods are taken into account. NVIDIA showed in its white paper that simply magnitude-pruning once to 2:4 sparsity and then retraining the model yields nearly identical accuracy.

Semi-structured sparsity sits in a sweet spot: at a comparatively low sparsity level (50%) it provides a 2x theoretical speedup, while remaining granular enough to preserve model accuracy.

| Network             | Dataset     | Metric | Dense FP16 | Sparse FP16 |
|---------------------|-------------|--------|------------|-------------|
| ResNet-50           | ImageNet    | Top-1  | 76.1       | 76.2        |
| ResNeXt-101_32x8d   | ImageNet    | Top-1  | 79.3       | 79.3        |
| Xception            | ImageNet    | Top-1  | 79.2       | 79.2        |
| SSD-RN50            | COCO2017    | bbAP   | 24.8       | 24.8        |
| MaskRCNN-RN50       | COCO2017    | bbAP   | 37.9       | 37.9        |
| FairSeq Transformer | EN-DE WMT14 | BLEU   | 28.2       | 28.5        |
| BERT-Large          | SQuAD v1.1  | F1     | 91.9       | 91.9        |

Semi-structured sparsity has an additional advantage from a workflow perspective. Because the sparsity level is fixed at 50%, it is easier to decompose the problem of sparsifying a model into two distinct subproblems:

  • Accuracy - How can we find a set of 2:4 sparse weights that minimizes the accuracy degradation of our model?

  • Performance - How can we accelerate our 2:4 sparse weights for inference and reduce memory overhead?

\[\begin{bmatrix} 1 & 1 & 0 & 0 \\ 0 & 0 & 1 & 1 \\ 1 & 0 & 0 & 0 \\ 0 & 0 & 1 & 1 \\ \end{bmatrix}\]

The natural handoff point between these two problems is a zeroed-out dense tensor like the one shown above. Our inference solution is designed to compress and accelerate tensors in this format. We anticipate that many users will come up with custom masking solutions, as this is an active area of research.
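As one example of the accuracy subproblem, one-shot magnitude pruning produces exactly such a zeroed-out dense tensor: in every group of four consecutive weights, keep the two with the largest magnitude. A minimal sketch (`magnitude_prune_24` is our own illustrative helper; this tutorial itself uses torch.ao.pruning.WeightNormSparsifier for this step):

```python
import torch

def magnitude_prune_24(weight: torch.Tensor) -> torch.Tensor:
    # In every group of 4 consecutive elements, zero the 2 smallest magnitudes.
    groups = weight.reshape(-1, 4)
    keep = groups.abs().topk(2, dim=1).indices
    mask = torch.zeros_like(groups, dtype=torch.bool).scatter_(1, keep, True)
    return (groups * mask).reshape(weight.shape)

w = torch.randn(128, 256)
pruned = magnitude_prune_24(w)
print((pruned != 0).float().mean().item())  # 0.5 -- exactly half survive
```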

Now that we've learned a little more about semi-structured sparsity, let's apply it to a BERT model trained on the question-answering task SQuAD.

Intro & Setup

Let's start by importing all the packages we need.

# If you are running this in Google Colab, run:
#     !pip install datasets transformers evaluate accelerate pandas
import os
os.environ["WANDB_DISABLED"] = "true"

import collections
import datasets
import evaluate
import numpy as np
import torch
import torch.utils.benchmark as benchmark
from torch import nn
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from torch.ao.pruning import WeightNormSparsifier
import transformers

# force CUTLASS use if ``cuSPARSELt`` is not available
SparseSemiStructuredTensor._FORCE_CUTLASS = True
torch.manual_seed(100)
<torch._C.Generator object at 0x7efbfdcd3950>

We'll also need to define some helper functions specific to the dataset and task at hand. These were adapted from the Hugging Face course as a reference.

def preprocess_validation_function(examples, tokenizer):
    inputs = tokenizer(
        [q.strip() for q in examples["question"]],
        examples["context"],
        max_length=384,
        truncation="only_second",
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    sample_map = inputs.pop("overflow_to_sample_mapping")
    example_ids = []

    for i in range(len(inputs["input_ids"])):
        sample_idx = sample_map[i]
        example_ids.append(examples["id"][sample_idx])
        sequence_ids = inputs.sequence_ids(i)
        offset = inputs["offset_mapping"][i]
        inputs["offset_mapping"][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

    inputs["example_id"] = example_ids
    return inputs


def preprocess_train_function(examples, tokenizer):
    inputs = tokenizer(
        [q.strip() for q in examples["question"]],
        examples["context"],
        max_length=384,
        truncation="only_second",
        return_offsets_mapping=True,
        padding="max_length",
    )

    offset_mapping = inputs["offset_mapping"]
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, (offset, answer) in enumerate(zip(offset_mapping, answers)):
        start_char = answer["answer_start"][0]
        end_char = start_char + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        # Find the start and end of the context
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label it (0, 0)
        if offset[context_start][0] > end_char or offset[context_end][1] < start_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs


def compute_metrics(start_logits, end_logits, features, examples):
    n_best = 20
    max_answer_length = 30
    metric = evaluate.load("squad")

    example_to_features = collections.defaultdict(list)
    for idx, feature in enumerate(features):
        example_to_features[feature["example_id"]].append(idx)

    predicted_answers = []
    # for example in tqdm(examples):
    for example in examples:
        example_id = example["id"]
        context = example["context"]
        answers = []

        # Loop through all features associated with that example
        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = features[feature_index]["offset_mapping"]

            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip answers that are not fully in the context
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # Skip answers with a length that is either < 0
                    # or > max_answer_length
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue

                    answer = {
                        "text": context[
                            offsets[start_index][0] : offsets[end_index][1]
                        ],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                    answers.append(answer)

        # Select the answer with the best score
        if len(answers) > 0:
            best_answer = max(answers, key=lambda x: x["logit_score"])
            predicted_answers.append(
                {"id": example_id, "prediction_text": best_answer["text"]}
            )
        else:
            predicted_answers.append({"id": example_id, "prediction_text": ""})

    theoretical_answers = [
        {"id": ex["id"], "answers": ex["answers"]} for ex in examples
    ]
    return metric.compute(predictions=predicted_answers, references=theoretical_answers)

With those defined, we just need one additional helper function, which will help us benchmark our model.

def measure_execution_time(model, batch_sizes, dataset):
    dataset_for_model = dataset.remove_columns(["example_id", "offset_mapping"])
    dataset_for_model.set_format("torch")
    batch_size_to_time_sec = {}
    for batch_size in batch_sizes:
        batch = {
            k: dataset_for_model[k][:batch_size].cuda()
            for k in dataset_for_model.column_names
        }

        with torch.no_grad():
            baseline_predictions = model(**batch)
            timer = benchmark.Timer(
                stmt="model(**batch)", globals={"model": model, "batch": batch}
            )
            p50 = timer.blocked_autorange().median * 1000
            batch_size_to_time_sec[batch_size] = p50

            model_c = torch.compile(model, fullgraph=True)
            timer = benchmark.Timer(
                stmt="model(**batch)", globals={"model": model_c, "batch": batch}
            )
            p50 = timer.blocked_autorange().median * 1000
            batch_size_to_time_sec[f"{batch_size}_compile"] = p50
            new_predictions = model_c(**batch)

    return batch_size_to_time_sec

We will start by loading our model and tokenizer, and then setting up our dataset.

# load model
model_name = "bert-base-cased"
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.AutoModelForQuestionAnswering.from_pretrained(model_name)
print(f"Loading tokenizer: {model_name}")
print(f"Loading model: {model_name}")

# set up train and val dataset
squad_dataset = datasets.load_dataset("squad")
tokenized_squad_dataset = {}
tokenized_squad_dataset["train"] = squad_dataset["train"].map(
    lambda x: preprocess_train_function(x, tokenizer), batched=True
)
tokenized_squad_dataset["validation"] = squad_dataset["validation"].map(
    lambda x: preprocess_validation_function(x, tokenizer),
    batched=True,
    remove_columns=squad_dataset["train"].column_names,
)
data_collator = transformers.DataCollatorWithPadding(tokenizer=tokenizer)
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Loading tokenizer: bert-base-cased
Loading model: bert-base-cased

Downloading readme: 100%|##########| 7.62k/7.62k [00:00<00:00, 49.7MB/s]
Downloading data: 100%|##########| 14.5M/14.5M [00:00<00:00, 87.3MB/s]
Downloading data: 100%|##########| 1.82M/1.82M [00:00<00:00, 40.4MB/s]
Generating train split: 100%|##########| 87599/87599 [00:00<00:00, 627386.25 examples/s]
Generating validation split: 100%|##########| 10570/10570 [00:00<00:00, 588480.85 examples/s]
Map: 100%|##########| 87599/87599 [00:44<00:00, 1978.32 examples/s]
Map: 100%|##########| 10570/10570 [00:04<00:00, 2634.93 examples/s]

Establishing a baseline

Next, we'll train a quick baseline of our model on SQuAD. This task asks our model to identify spans, or segments of text, in a given context (a Wikipedia article) that answer a given question. Running the following code gives an F1 score of 86.9. This is quite close to the score NVIDIA reported; the difference is likely due to BERT-base vs. BERT-large and the fine-tuning hyperparameters.

training_args = transformers.TrainingArguments(
    "trainer",
    num_train_epochs=1,
    lr_scheduler_type="constant",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=256,
    logging_steps=50,
    # Limit max steps for tutorial runners. Delete the below line to see the reported accuracy numbers.
    max_steps=500,
    report_to=None,
)

trainer = transformers.Trainer(
    model,
    training_args,
    train_dataset=tokenized_squad_dataset["train"],
    eval_dataset=tokenized_squad_dataset["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)

trainer.train()

# batch sizes to compare for eval
batch_sizes = [4, 16, 64, 256]
# 2:4 sparsity requires fp16, so we cast here for a fair comparison
with torch.autocast("cuda"):
    with torch.no_grad():
        predictions = trainer.predict(tokenized_squad_dataset["validation"])
        start_logits, end_logits = predictions.predictions
        fp16_baseline = compute_metrics(
            start_logits,
            end_logits,
            tokenized_squad_dataset["validation"],
            squad_dataset["validation"],
        )
        fp16_time = measure_execution_time(
            model,
            batch_sizes,
            tokenized_squad_dataset["validation"],
        )

print("fp16", fp16_baseline)
print("cuda_fp16 time", fp16_time)

import pandas as pd
df = pd.DataFrame(trainer.state.log_history)
df.plot.line(x='step', y='loss', title="Loss vs. # steps", ylabel="loss")
(Plot: Loss vs. # steps)
torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. In order to use Torch DDP, launch your script with `python -m torch.distributed.launch
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
max_steps is given, it will override any value given in num_train_epochs

{'loss': 3.8379, 'grad_norm': 14.342127799987793, 'learning_rate': 5e-05, 'epoch': 0.02}
{'loss': 2.3803, 'grad_norm': 15.05634880065918, 'learning_rate': 5e-05, 'epoch': 0.04}
{'loss': 1.8675, 'grad_norm': 11.2459716796875, 'learning_rate': 5e-05, 'epoch': 0.05}
{'loss': 1.7281, 'grad_norm': 12.751557350158691, 'learning_rate': 5e-05, 'epoch': 0.07}
 49%|####8     | 244/500 [02:09<02:16,  1.88it/s]
 49%|####9     | 245/500 [02:10<02:15,  1.88it/s]
 49%|####9     | 246/500 [02:10<02:15,  1.88it/s]
 49%|####9     | 247/500 [02:11<02:14,  1.88it/s]
 50%|####9     | 248/500 [02:11<02:14,  1.88it/s]
 50%|####9     | 249/500 [02:12<02:13,  1.88it/s]
 50%|#####     | 250/500 [02:12<02:13,  1.88it/s]

{'loss': 1.5849, 'grad_norm': 12.03045654296875, 'learning_rate': 5e-05, 'epoch': 0.09}

 50%|#####     | 250/500 [02:12<02:13,  1.88it/s]
 50%|#####     | 251/500 [02:13<02:12,  1.88it/s]
 50%|#####     | 252/500 [02:13<02:12,  1.88it/s]
 51%|#####     | 253/500 [02:14<02:11,  1.88it/s]
 51%|#####     | 254/500 [02:15<02:11,  1.88it/s]
 51%|#####1    | 255/500 [02:15<02:10,  1.88it/s]
 51%|#####1    | 256/500 [02:16<02:10,  1.88it/s]
 51%|#####1    | 257/500 [02:16<02:09,  1.88it/s]
 52%|#####1    | 258/500 [02:17<02:09,  1.88it/s]
 52%|#####1    | 259/500 [02:17<02:08,  1.88it/s]
 52%|#####2    | 260/500 [02:18<02:07,  1.88it/s]
 52%|#####2    | 261/500 [02:18<02:07,  1.88it/s]
 52%|#####2    | 262/500 [02:19<02:06,  1.88it/s]
 53%|#####2    | 263/500 [02:19<02:06,  1.88it/s]
 53%|#####2    | 264/500 [02:20<02:05,  1.88it/s]
 53%|#####3    | 265/500 [02:20<02:05,  1.88it/s]
 53%|#####3    | 266/500 [02:21<02:04,  1.88it/s]
 53%|#####3    | 267/500 [02:21<02:04,  1.88it/s]
 54%|#####3    | 268/500 [02:22<02:03,  1.88it/s]
 54%|#####3    | 269/500 [02:23<02:03,  1.88it/s]
 54%|#####4    | 270/500 [02:23<02:02,  1.88it/s]
 54%|#####4    | 271/500 [02:24<02:01,  1.88it/s]
 54%|#####4    | 272/500 [02:24<02:01,  1.88it/s]
 55%|#####4    | 273/500 [02:25<02:00,  1.88it/s]
 55%|#####4    | 274/500 [02:25<02:00,  1.88it/s]
 55%|#####5    | 275/500 [02:26<01:59,  1.88it/s]
 55%|#####5    | 276/500 [02:26<01:59,  1.88it/s]
 55%|#####5    | 277/500 [02:27<01:58,  1.88it/s]
 56%|#####5    | 278/500 [02:27<01:58,  1.88it/s]
 56%|#####5    | 279/500 [02:28<01:57,  1.88it/s]
 56%|#####6    | 280/500 [02:28<01:57,  1.88it/s]
 56%|#####6    | 281/500 [02:29<01:56,  1.88it/s]
 56%|#####6    | 282/500 [02:29<01:56,  1.88it/s]
 57%|#####6    | 283/500 [02:30<01:55,  1.88it/s]
 57%|#####6    | 284/500 [02:30<01:55,  1.88it/s]
 57%|#####6    | 285/500 [02:31<01:54,  1.88it/s]
 57%|#####7    | 286/500 [02:32<01:54,  1.88it/s]
 57%|#####7    | 287/500 [02:32<01:53,  1.88it/s]
 58%|#####7    | 288/500 [02:33<01:52,  1.88it/s]
 58%|#####7    | 289/500 [02:33<01:52,  1.88it/s]
 58%|#####8    | 290/500 [02:34<01:51,  1.88it/s]
 58%|#####8    | 291/500 [02:34<01:51,  1.88it/s]
 58%|#####8    | 292/500 [02:35<01:50,  1.88it/s]
 59%|#####8    | 293/500 [02:35<01:50,  1.88it/s]
 59%|#####8    | 294/500 [02:36<01:49,  1.88it/s]
 59%|#####8    | 295/500 [02:36<01:49,  1.88it/s]
 59%|#####9    | 296/500 [02:37<01:48,  1.88it/s]
 59%|#####9    | 297/500 [02:37<01:48,  1.88it/s]
 60%|#####9    | 298/500 [02:38<01:47,  1.88it/s]
 60%|#####9    | 299/500 [02:38<01:47,  1.88it/s]
 60%|######    | 300/500 [02:39<01:46,  1.88it/s]

{'loss': 1.5287, 'grad_norm': 12.408611297607422, 'learning_rate': 5e-05, 'epoch': 0.11}

 60%|######    | 300/500 [02:39<01:46,  1.88it/s]
 60%|######    | 301/500 [02:40<01:46,  1.88it/s]
 60%|######    | 302/500 [02:40<01:45,  1.88it/s]
 61%|######    | 303/500 [02:41<01:45,  1.88it/s]
 61%|######    | 304/500 [02:41<01:44,  1.88it/s]
 61%|######1   | 305/500 [02:42<01:43,  1.88it/s]
 61%|######1   | 306/500 [02:42<01:43,  1.88it/s]
 61%|######1   | 307/500 [02:43<01:42,  1.88it/s]
 62%|######1   | 308/500 [02:43<01:42,  1.88it/s]
 62%|######1   | 309/500 [02:44<01:41,  1.88it/s]
 62%|######2   | 310/500 [02:44<01:41,  1.88it/s]
 62%|######2   | 311/500 [02:45<01:40,  1.88it/s]
 62%|######2   | 312/500 [02:45<01:40,  1.88it/s]
 63%|######2   | 313/500 [02:46<01:39,  1.88it/s]
 63%|######2   | 314/500 [02:46<01:39,  1.88it/s]
 63%|######3   | 315/500 [02:47<01:38,  1.88it/s]
 63%|######3   | 316/500 [02:48<01:38,  1.88it/s]
 63%|######3   | 317/500 [02:48<01:37,  1.88it/s]
 64%|######3   | 318/500 [02:49<01:36,  1.88it/s]
 64%|######3   | 319/500 [02:49<01:36,  1.88it/s]
 64%|######4   | 320/500 [02:50<01:35,  1.88it/s]
 64%|######4   | 321/500 [02:50<01:35,  1.88it/s]
 64%|######4   | 322/500 [02:51<01:34,  1.88it/s]
 65%|######4   | 323/500 [02:51<01:34,  1.88it/s]
 65%|######4   | 324/500 [02:52<01:33,  1.88it/s]
 65%|######5   | 325/500 [02:52<01:33,  1.88it/s]
 65%|######5   | 326/500 [02:53<01:32,  1.88it/s]
 65%|######5   | 327/500 [02:53<01:32,  1.88it/s]
 66%|######5   | 328/500 [02:54<01:31,  1.88it/s]
 66%|######5   | 329/500 [02:54<01:31,  1.88it/s]
 66%|######6   | 330/500 [02:55<01:30,  1.88it/s]
 66%|######6   | 331/500 [02:56<01:30,  1.87it/s]
 66%|######6   | 332/500 [02:56<01:29,  1.88it/s]
 67%|######6   | 333/500 [02:57<01:29,  1.88it/s]
 67%|######6   | 334/500 [02:57<01:28,  1.88it/s]
 67%|######7   | 335/500 [02:58<01:27,  1.88it/s]
 67%|######7   | 336/500 [02:58<01:27,  1.88it/s]
 67%|######7   | 337/500 [02:59<01:26,  1.88it/s]
 68%|######7   | 338/500 [02:59<01:26,  1.88it/s]
 68%|######7   | 339/500 [03:00<01:25,  1.88it/s]
 68%|######8   | 340/500 [03:00<01:25,  1.88it/s]
 68%|######8   | 341/500 [03:01<01:24,  1.88it/s]
 68%|######8   | 342/500 [03:01<01:24,  1.87it/s]
 69%|######8   | 343/500 [03:02<01:23,  1.88it/s]
 69%|######8   | 344/500 [03:02<01:23,  1.88it/s]
 69%|######9   | 345/500 [03:03<01:22,  1.88it/s]
 69%|######9   | 346/500 [03:04<01:22,  1.88it/s]
 69%|######9   | 347/500 [03:04<01:21,  1.88it/s]
 70%|######9   | 348/500 [03:05<01:21,  1.88it/s]
 70%|######9   | 349/500 [03:05<01:20,  1.88it/s]
 70%|#######   | 350/500 [03:06<01:19,  1.88it/s]

{'loss': 1.4997, 'grad_norm': 9.444414138793945, 'learning_rate': 5e-05, 'epoch': 0.13}

 70%|#######   | 350/500 [03:06<01:19,  1.88it/s]
 70%|#######   | 351/500 [03:06<01:19,  1.88it/s]
 70%|#######   | 352/500 [03:07<01:18,  1.88it/s]
 71%|#######   | 353/500 [03:07<01:18,  1.88it/s]
 71%|#######   | 354/500 [03:08<01:17,  1.88it/s]
 71%|#######1  | 355/500 [03:08<01:17,  1.88it/s]
 71%|#######1  | 356/500 [03:09<01:16,  1.88it/s]
 71%|#######1  | 357/500 [03:09<01:16,  1.88it/s]
 72%|#######1  | 358/500 [03:10<01:15,  1.88it/s]
 72%|#######1  | 359/500 [03:10<01:15,  1.88it/s]
 72%|#######2  | 360/500 [03:11<01:14,  1.88it/s]
 72%|#######2  | 361/500 [03:12<01:14,  1.88it/s]
 72%|#######2  | 362/500 [03:12<01:13,  1.88it/s]
 73%|#######2  | 363/500 [03:13<01:13,  1.88it/s]
 73%|#######2  | 364/500 [03:13<01:12,  1.88it/s]
 73%|#######3  | 365/500 [03:14<01:11,  1.88it/s]
 73%|#######3  | 366/500 [03:14<01:11,  1.88it/s]
 73%|#######3  | 367/500 [03:15<01:10,  1.88it/s]
 74%|#######3  | 368/500 [03:15<01:10,  1.88it/s]
 74%|#######3  | 369/500 [03:16<01:09,  1.88it/s]
 74%|#######4  | 370/500 [03:16<01:09,  1.88it/s]
 74%|#######4  | 371/500 [03:17<01:08,  1.88it/s]
 74%|#######4  | 372/500 [03:17<01:08,  1.88it/s]
 75%|#######4  | 373/500 [03:18<01:07,  1.88it/s]
 75%|#######4  | 374/500 [03:18<01:07,  1.88it/s]
 75%|#######5  | 375/500 [03:19<01:06,  1.88it/s]
 75%|#######5  | 376/500 [03:20<01:06,  1.88it/s]
 75%|#######5  | 377/500 [03:20<01:05,  1.88it/s]
 76%|#######5  | 378/500 [03:21<01:05,  1.88it/s]
 76%|#######5  | 379/500 [03:21<01:04,  1.88it/s]
 76%|#######6  | 380/500 [03:22<01:03,  1.88it/s]
 76%|#######6  | 381/500 [03:22<01:03,  1.88it/s]
 76%|#######6  | 382/500 [03:23<01:02,  1.88it/s]
 77%|#######6  | 383/500 [03:23<01:02,  1.88it/s]
 77%|#######6  | 384/500 [03:24<01:01,  1.88it/s]
 77%|#######7  | 385/500 [03:24<01:01,  1.88it/s]
 77%|#######7  | 386/500 [03:25<01:00,  1.88it/s]
 77%|#######7  | 387/500 [03:25<01:00,  1.87it/s]
 78%|#######7  | 388/500 [03:26<00:59,  1.88it/s]
 78%|#######7  | 389/500 [03:26<00:59,  1.88it/s]
 78%|#######8  | 390/500 [03:27<00:58,  1.88it/s]
 78%|#######8  | 391/500 [03:28<00:58,  1.87it/s]
 78%|#######8  | 392/500 [03:28<00:57,  1.88it/s]
 79%|#######8  | 393/500 [03:29<00:57,  1.88it/s]
 79%|#######8  | 394/500 [03:29<00:56,  1.88it/s]
 79%|#######9  | 395/500 [03:30<00:55,  1.88it/s]
 79%|#######9  | 396/500 [03:30<00:55,  1.88it/s]
 79%|#######9  | 397/500 [03:31<00:54,  1.88it/s]
 80%|#######9  | 398/500 [03:31<00:54,  1.88it/s]
 80%|#######9  | 399/500 [03:32<00:53,  1.88it/s]
 80%|########  | 400/500 [03:32<00:53,  1.87it/s]

{'loss': 1.3912, 'grad_norm': 10.076640129089355, 'learning_rate': 5e-05, 'epoch': 0.15}

 80%|########  | 400/500 [03:32<00:53,  1.87it/s]
 80%|########  | 401/500 [03:33<00:52,  1.87it/s]
 80%|########  | 402/500 [03:33<00:52,  1.88it/s]
 81%|########  | 403/500 [03:34<00:51,  1.87it/s]
 81%|########  | 404/500 [03:34<00:51,  1.88it/s]
 81%|########1 | 405/500 [03:35<00:50,  1.88it/s]
 81%|########1 | 406/500 [03:36<00:50,  1.88it/s]
 81%|########1 | 407/500 [03:36<00:49,  1.88it/s]
 82%|########1 | 408/500 [03:37<00:49,  1.88it/s]
 82%|########1 | 409/500 [03:37<00:48,  1.88it/s]
 82%|########2 | 410/500 [03:38<00:47,  1.88it/s]
 82%|########2 | 411/500 [03:38<00:47,  1.88it/s]
 82%|########2 | 412/500 [03:39<00:46,  1.88it/s]
 83%|########2 | 413/500 [03:39<00:46,  1.88it/s]
 83%|########2 | 414/500 [03:40<00:45,  1.88it/s]
 83%|########2 | 415/500 [03:40<00:45,  1.87it/s]
 83%|########3 | 416/500 [03:41<00:44,  1.87it/s]
 83%|########3 | 417/500 [03:41<00:44,  1.87it/s]
 84%|########3 | 418/500 [03:42<00:43,  1.88it/s]
 84%|########3 | 419/500 [03:42<00:43,  1.88it/s]
 84%|########4 | 420/500 [03:43<00:42,  1.88it/s]
 84%|########4 | 421/500 [03:44<00:42,  1.88it/s]
 84%|########4 | 422/500 [03:44<00:41,  1.88it/s]
 85%|########4 | 423/500 [03:45<00:41,  1.88it/s]
 85%|########4 | 424/500 [03:45<00:40,  1.88it/s]
 85%|########5 | 425/500 [03:46<00:39,  1.88it/s]
 85%|########5 | 426/500 [03:46<00:39,  1.88it/s]
 85%|########5 | 427/500 [03:47<00:38,  1.88it/s]
 86%|########5 | 428/500 [03:47<00:38,  1.88it/s]
 86%|########5 | 429/500 [03:48<00:37,  1.88it/s]
 86%|########6 | 430/500 [03:48<00:37,  1.88it/s]
 86%|########6 | 431/500 [03:49<00:36,  1.88it/s]
 86%|########6 | 432/500 [03:49<00:36,  1.88it/s]
 87%|########6 | 433/500 [03:50<00:35,  1.87it/s]
 87%|########6 | 434/500 [03:50<00:35,  1.88it/s]
 87%|########7 | 435/500 [03:51<00:34,  1.88it/s]
 87%|########7 | 436/500 [03:52<00:34,  1.88it/s]
 87%|########7 | 437/500 [03:52<00:33,  1.87it/s]
 88%|########7 | 438/500 [03:53<00:33,  1.87it/s]
 88%|########7 | 439/500 [03:53<00:32,  1.87it/s]
 88%|########8 | 440/500 [03:54<00:31,  1.88it/s]
 88%|########8 | 441/500 [03:54<00:31,  1.88it/s]
 88%|########8 | 442/500 [03:55<00:30,  1.88it/s]
 89%|########8 | 443/500 [03:55<00:30,  1.88it/s]
 89%|########8 | 444/500 [03:56<00:29,  1.88it/s]
 89%|########9 | 445/500 [03:56<00:29,  1.87it/s]
 89%|########9 | 446/500 [03:57<00:28,  1.88it/s]
 89%|########9 | 447/500 [03:57<00:28,  1.88it/s]
 90%|########9 | 448/500 [03:58<00:27,  1.88it/s]
 90%|########9 | 449/500 [03:58<00:27,  1.88it/s]
 90%|######### | 450/500 [03:59<00:26,  1.88it/s]

{'loss': 1.3439, 'grad_norm': 11.830737113952637, 'learning_rate': 5e-05, 'epoch': 0.16}

 90%|######### | 450/500 [03:59<00:26,  1.88it/s]
 90%|######### | 451/500 [04:00<00:26,  1.87it/s]
 90%|######### | 452/500 [04:00<00:25,  1.88it/s]
 91%|######### | 453/500 [04:01<00:25,  1.88it/s]
 91%|######### | 454/500 [04:01<00:24,  1.87it/s]
 91%|#########1| 455/500 [04:02<00:23,  1.88it/s]
 91%|#########1| 456/500 [04:02<00:23,  1.88it/s]
 91%|#########1| 457/500 [04:03<00:22,  1.87it/s]
 92%|#########1| 458/500 [04:03<00:22,  1.87it/s]
 92%|#########1| 459/500 [04:04<00:21,  1.88it/s]
 92%|#########2| 460/500 [04:04<00:21,  1.87it/s]
 92%|#########2| 461/500 [04:05<00:20,  1.88it/s]
 92%|#########2| 462/500 [04:05<00:20,  1.88it/s]
 93%|#########2| 463/500 [04:06<00:19,  1.88it/s]
 93%|#########2| 464/500 [04:06<00:19,  1.88it/s]
 93%|#########3| 465/500 [04:07<00:18,  1.88it/s]
 93%|#########3| 466/500 [04:08<00:18,  1.88it/s]
 93%|#########3| 467/500 [04:08<00:17,  1.87it/s]
 94%|#########3| 468/500 [04:09<00:17,  1.88it/s]
 94%|#########3| 469/500 [04:09<00:16,  1.88it/s]
 94%|#########3| 470/500 [04:10<00:15,  1.88it/s]
 94%|#########4| 471/500 [04:10<00:15,  1.88it/s]
 94%|#########4| 472/500 [04:11<00:14,  1.88it/s]
 95%|#########4| 473/500 [04:11<00:14,  1.88it/s]
 95%|#########4| 474/500 [04:12<00:13,  1.88it/s]
 95%|#########5| 475/500 [04:12<00:13,  1.88it/s]
 95%|#########5| 476/500 [04:13<00:12,  1.88it/s]
 95%|#########5| 477/500 [04:13<00:12,  1.88it/s]
 96%|#########5| 478/500 [04:14<00:11,  1.88it/s]
 96%|#########5| 479/500 [04:14<00:11,  1.88it/s]
 96%|#########6| 480/500 [04:15<00:10,  1.88it/s]
 96%|#########6| 481/500 [04:15<00:10,  1.89it/s]
 96%|#########6| 482/500 [04:16<00:09,  1.89it/s]
 97%|#########6| 483/500 [04:17<00:09,  1.89it/s]
 97%|#########6| 484/500 [04:17<00:08,  1.89it/s]
 97%|#########7| 485/500 [04:18<00:07,  1.89it/s]
 97%|#########7| 486/500 [04:18<00:07,  1.89it/s]
 97%|#########7| 487/500 [04:19<00:06,  1.89it/s]
 98%|#########7| 488/500 [04:19<00:06,  1.89it/s]
 98%|#########7| 489/500 [04:20<00:05,  1.89it/s]
 98%|#########8| 490/500 [04:20<00:05,  1.89it/s]
 98%|#########8| 491/500 [04:21<00:04,  1.89it/s]
 98%|#########8| 492/500 [04:21<00:04,  1.89it/s]
 99%|#########8| 493/500 [04:22<00:03,  1.89it/s]
 99%|#########8| 494/500 [04:22<00:03,  1.89it/s]
 99%|#########9| 495/500 [04:23<00:02,  1.89it/s]
 99%|#########9| 496/500 [04:23<00:02,  1.89it/s]
 99%|#########9| 497/500 [04:24<00:01,  1.89it/s]
100%|#########9| 498/500 [04:25<00:01,  1.89it/s]
100%|#########9| 499/500 [04:25<00:00,  1.89it/s]
100%|##########| 500/500 [04:26<00:00,  1.89it/s]

{'loss': 1.3419, 'grad_norm': 15.239953994750977, 'learning_rate': 5e-05, 'epoch': 0.18}

100%|##########| 500/500 [04:26<00:00,  1.89it/s]

{'train_runtime': 267.4795, 'train_samples_per_second': 59.818, 'train_steps_per_second': 1.869, 'train_loss': 1.850412796020508, 'epoch': 0.18}

100%|##########| 500/500 [04:27<00:00,  1.89it/s]
100%|##########| 500/500 [04:27<00:00,  1.87it/s]

100%|##########| 43/43 [00:26<00:00,  1.60it/s]

fp16 {'exact_match': 71.28666035950805, 'f1': 80.62619387545674}
cuda_fp16 time {4: 9.563420739996218, '4_compile': 9.004673000163166, 16: 31.877686199914024, '16_compile': 30.932841000321787, 64: 123.92699649990391, '64_compile': 104.97432750071312, 256: 476.81639599977643, '256_compile': 396.38552500036894}

<Axes: title={'center': 'Loss vs. # steps'}, xlabel='step', ylabel='loss'>

Pruning BERT to be 2:4 sparse

Now that we have our baseline, it's time to prune BERT. There are many different pruning strategies, but one of the most common is **magnitude pruning**, which seeks to remove the weights with the lowest L1 norm. Magnitude pruning was used by NVIDIA in all of their results and is a common baseline.
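
As an illustration only (not the `torch.ao` API used below), 2:4 magnitude pruning of a single tensor can be sketched like this, assuming the last dimension is a multiple of 4:

```python
import torch

def prune_2_4(weight: torch.Tensor) -> torch.Tensor:
    """Keep the 2 largest-magnitude elements in every contiguous group of 4,
    zeroing out the rest -- a toy sketch of 2:4 magnitude pruning."""
    groups = weight.reshape(-1, 4)
    # indices of the 2 largest |w| entries in each group of 4
    keep = groups.abs().topk(2, dim=1).indices
    mask = torch.zeros_like(groups, dtype=torch.bool).scatter_(1, keep, True)
    return (groups * mask).reshape(weight.shape)

w = torch.randn(8, 8)
pruned = prune_2_4(w)
# every group of 4 now contains at least 2 zeros
assert (pruned.reshape(-1, 4) == 0).sum(dim=1).ge(2).all()
```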

To do that, we'll use the torch.ao.pruning package, which contains a weight-norm (magnitude) sparsifier. These sparsifiers work by applying mask parametrizations to the weight tensors in a model. This lets them simulate sparsity by masking out the pruned weights.
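
The mechanism can be illustrated with torch.nn.utils.parametrize directly; the `Mask` module below is a hypothetical stand-in for the `FakeSparsity` parametrization the sparsifier installs:

```python
import torch
import torch.nn.utils.parametrize as parametrize

class Mask(torch.nn.Module):
    """A toy mask parametrization, conceptually similar to FakeSparsity."""
    def __init__(self, mask):
        super().__init__()
        self.register_buffer("mask", mask)

    def forward(self, w):
        # accessing the parametrized .weight returns mask * original_weight
        return self.mask * w

linear = torch.nn.Linear(4, 4)
mask = torch.rand(4, 4) > 0.5
parametrize.register_parametrization(linear, "weight", Mask(mask))

# masked entries read back as zero; the original weight is kept underneath
assert torch.equal(linear.weight, mask * linear.parametrizations.weight.original)
```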

We'll also need to decide which layers of the model to apply sparsity to, which in this case is all of the nn.Linear layers, except for the task-specific head outputs. That's because semi-structured sparsity has shape constraints, and the task-specific nn.Linear layers do not satisfy them.

from torch.ao.pruning import WeightNormSparsifier

sparsifier = WeightNormSparsifier(
    # apply sparsity to all blocks
    sparsity_level=1.0,
    # shape of 4 elements is a block
    sparse_block_shape=(1, 4),
    # two zeros for every block of 4
    zeros_per_block=2
)

# add to config if ``nn.Linear`` and in the BERT model.
sparse_config = [
    {"tensor_fqn": f"{fqn}.weight"}
    for fqn, module in model.named_modules()
    if isinstance(module, nn.Linear) and "layer" in fqn
]

The first step for pruning the model is to insert parametrizations for masking the weights of the model. This is done by the prepare step. Whenever we try to access .weight, we will get mask * weight instead of the original weight.

# Prepare the model, insert fake-sparsity parametrizations for training
sparsifier.prepare(model, sparse_config)
print(model.bert.encoder.layer[0].output)
BertOutput(
  (dense): ParametrizedLinear(
    in_features=3072, out_features=768, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): FakeSparsity()
      )
    )
  )
  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

Then, we'll take a single pruning step. All pruners implement an update_mask() method that updates the mask with the logic determined by the pruner implementation. The step method calls this update_mask function for the weights specified in the sparse config.

We will also evaluate the model to show the accuracy degradation of zero-shot pruning, that is, pruning without any fine-tuning or retraining.

sparsifier.step()
with torch.autocast("cuda"):
    with torch.no_grad():
        predictions = trainer.predict(tokenized_squad_dataset["validation"])
    pruned = compute_metrics(
        *predictions.predictions,
        tokenized_squad_dataset["validation"],
        squad_dataset["validation"],
    )
print("pruned eval metrics:", pruned)
100%|##########| 43/43 [00:26<00:00,  1.63it/s]
pruned eval metrics: {'exact_match': 29.87701040681173, 'f1': 41.483943521142855}

In this state, we can start fine-tuning the model, updating the elements that won't be pruned to better compensate for the accuracy loss. Once we've reached a satisfactory state, we can call squash_mask to fuse the mask and the weight together. This removes the parametrizations, leaving us with a zeroed-out 2:4 dense model.

trainer.train()
sparsifier.squash_mask()
torch.set_printoptions(edgeitems=4)
print(model.bert.encoder.layer[0].intermediate.dense.weight[:8, :8])

df["sparse_loss"] = pd.DataFrame(trainer.state.log_history)["loss"]
df.plot.line(x='step', y=["loss", "sparse_loss"], title="Loss vs. # steps", ylabel="loss")
Loss vs. # steps
{'loss': 1.8793, 'grad_norm': 11.104135513305664, 'learning_rate': 5e-05, 'epoch': 0.02}

{'loss': 1.4024, 'grad_norm': 8.995004653930664, 'learning_rate': 5e-05, 'epoch': 0.04}

{'loss': 1.1782, 'grad_norm': 9.000022888183594, 'learning_rate': 5e-05, 'epoch': 0.05}

{'loss': 1.2231, 'grad_norm': 8.50694465637207, 'learning_rate': 5e-05, 'epoch': 0.07}

 42%|####2     | 210/500 [01:52<02:35,  1.87it/s]
 42%|####2     | 211/500 [01:52<02:34,  1.87it/s]
 42%|####2     | 212/500 [01:53<02:33,  1.87it/s]
 43%|####2     | 213/500 [01:53<02:33,  1.87it/s]
 43%|####2     | 214/500 [01:54<02:32,  1.87it/s]
 43%|####3     | 215/500 [01:54<02:32,  1.87it/s]
 43%|####3     | 216/500 [01:55<02:31,  1.87it/s]
 43%|####3     | 217/500 [01:55<02:31,  1.87it/s]
 44%|####3     | 218/500 [01:56<02:30,  1.87it/s]
 44%|####3     | 219/500 [01:56<02:30,  1.87it/s]
 44%|####4     | 220/500 [01:57<02:29,  1.87it/s]
 44%|####4     | 221/500 [01:57<02:29,  1.87it/s]
 44%|####4     | 222/500 [01:58<02:28,  1.87it/s]
 45%|####4     | 223/500 [01:58<02:28,  1.87it/s]
 45%|####4     | 224/500 [01:59<02:27,  1.87it/s]
 45%|####5     | 225/500 [02:00<02:27,  1.87it/s]
 45%|####5     | 226/500 [02:00<02:26,  1.87it/s]
 45%|####5     | 227/500 [02:01<02:26,  1.87it/s]
 46%|####5     | 228/500 [02:01<02:25,  1.87it/s]
 46%|####5     | 229/500 [02:02<02:24,  1.87it/s]
 46%|####6     | 230/500 [02:02<02:24,  1.87it/s]
 46%|####6     | 231/500 [02:03<02:23,  1.87it/s]
 46%|####6     | 232/500 [02:03<02:23,  1.87it/s]
 47%|####6     | 233/500 [02:04<02:22,  1.87it/s]
 47%|####6     | 234/500 [02:04<02:22,  1.87it/s]
 47%|####6     | 235/500 [02:05<02:21,  1.87it/s]
 47%|####7     | 236/500 [02:05<02:21,  1.87it/s]
 47%|####7     | 237/500 [02:06<02:20,  1.87it/s]
 48%|####7     | 238/500 [02:07<02:20,  1.87it/s]
 48%|####7     | 239/500 [02:07<02:19,  1.87it/s]
 48%|####8     | 240/500 [02:08<02:18,  1.87it/s]
 48%|####8     | 241/500 [02:08<02:18,  1.87it/s]
 48%|####8     | 242/500 [02:09<02:17,  1.87it/s]
 49%|####8     | 243/500 [02:09<02:17,  1.87it/s]
 49%|####8     | 244/500 [02:10<02:16,  1.87it/s]
 49%|####9     | 245/500 [02:10<02:16,  1.87it/s]
 49%|####9     | 246/500 [02:11<02:15,  1.87it/s]
 49%|####9     | 247/500 [02:11<02:15,  1.87it/s]
 50%|####9     | 248/500 [02:12<02:14,  1.87it/s]
 50%|####9     | 249/500 [02:12<02:14,  1.87it/s]
 50%|#####     | 250/500 [02:13<02:13,  1.87it/s]

{'loss': 1.1114, 'grad_norm': 7.177275657653809, 'learning_rate': 5e-05, 'epoch': 0.09}

 50%|#####     | 250/500 [02:13<02:13,  1.87it/s]
 50%|#####     | 251/500 [02:13<02:13,  1.87it/s]
 50%|#####     | 252/500 [02:14<02:12,  1.87it/s]
 51%|#####     | 253/500 [02:15<02:12,  1.87it/s]
 51%|#####     | 254/500 [02:15<02:11,  1.87it/s]
 51%|#####1    | 255/500 [02:16<02:10,  1.87it/s]
 51%|#####1    | 256/500 [02:16<02:10,  1.87it/s]
 51%|#####1    | 257/500 [02:17<02:09,  1.87it/s]
 52%|#####1    | 258/500 [02:17<02:09,  1.87it/s]
 52%|#####1    | 259/500 [02:18<02:08,  1.87it/s]
 52%|#####2    | 260/500 [02:18<02:08,  1.87it/s]
 52%|#####2    | 261/500 [02:19<02:07,  1.87it/s]
 52%|#####2    | 262/500 [02:19<02:07,  1.87it/s]
 53%|#####2    | 263/500 [02:20<02:06,  1.87it/s]
 53%|#####2    | 264/500 [02:20<02:06,  1.87it/s]
 53%|#####3    | 265/500 [02:21<02:05,  1.87it/s]
 53%|#####3    | 266/500 [02:21<02:05,  1.87it/s]
 53%|#####3    | 267/500 [02:22<02:04,  1.87it/s]
 54%|#####3    | 268/500 [02:23<02:03,  1.87it/s]
 54%|#####3    | 269/500 [02:23<02:03,  1.87it/s]
 54%|#####4    | 270/500 [02:24<02:02,  1.87it/s]
 54%|#####4    | 271/500 [02:24<02:02,  1.87it/s]
 54%|#####4    | 272/500 [02:25<02:01,  1.87it/s]
 55%|#####4    | 273/500 [02:25<02:01,  1.87it/s]
 55%|#####4    | 274/500 [02:26<02:00,  1.87it/s]
 55%|#####5    | 275/500 [02:26<02:00,  1.87it/s]
 55%|#####5    | 276/500 [02:27<01:59,  1.87it/s]
 55%|#####5    | 277/500 [02:27<01:59,  1.87it/s]
 56%|#####5    | 278/500 [02:28<01:58,  1.87it/s]
 56%|#####5    | 279/500 [02:28<01:58,  1.87it/s]
 56%|#####6    | 280/500 [02:29<01:57,  1.87it/s]
 56%|#####6    | 281/500 [02:29<01:57,  1.87it/s]
 56%|#####6    | 282/500 [02:30<01:56,  1.87it/s]
 57%|#####6    | 283/500 [02:31<01:55,  1.87it/s]
 57%|#####6    | 284/500 [02:31<01:55,  1.87it/s]
 57%|#####6    | 285/500 [02:32<01:54,  1.87it/s]
 57%|#####7    | 286/500 [02:32<01:54,  1.87it/s]
 57%|#####7    | 287/500 [02:33<01:53,  1.87it/s]
 58%|#####7    | 288/500 [02:33<01:53,  1.87it/s]
 58%|#####7    | 289/500 [02:34<01:52,  1.87it/s]
 58%|#####8    | 290/500 [02:34<01:52,  1.87it/s]
 58%|#####8    | 291/500 [02:35<01:51,  1.87it/s]
 58%|#####8    | 292/500 [02:35<01:51,  1.87it/s]
 59%|#####8    | 293/500 [02:36<01:50,  1.87it/s]
 59%|#####8    | 294/500 [02:36<01:50,  1.87it/s]
 59%|#####8    | 295/500 [02:37<01:49,  1.87it/s]
 59%|#####9    | 296/500 [02:37<01:48,  1.87it/s]
 59%|#####9    | 297/500 [02:38<01:48,  1.87it/s]
 60%|#####9    | 298/500 [02:39<01:47,  1.87it/s]
 60%|#####9    | 299/500 [02:39<01:47,  1.87it/s]
 60%|######    | 300/500 [02:40<01:46,  1.87it/s]

{'loss': 1.1299, 'grad_norm': 7.932921409606934, 'learning_rate': 5e-05, 'epoch': 0.11}

 60%|######    | 300/500 [02:40<01:46,  1.87it/s]
 60%|######    | 301/500 [02:40<01:46,  1.87it/s]
 60%|######    | 302/500 [02:41<01:45,  1.87it/s]
 61%|######    | 303/500 [02:41<01:45,  1.87it/s]
 61%|######    | 304/500 [02:42<01:44,  1.87it/s]
 61%|######1   | 305/500 [02:42<01:44,  1.87it/s]
 61%|######1   | 306/500 [02:43<01:43,  1.87it/s]
 61%|######1   | 307/500 [02:43<01:43,  1.87it/s]
 62%|######1   | 308/500 [02:44<01:42,  1.87it/s]
 62%|######1   | 309/500 [02:44<01:42,  1.87it/s]
 62%|######2   | 310/500 [02:45<01:41,  1.87it/s]
 62%|######2   | 311/500 [02:46<01:40,  1.87it/s]
 62%|######2   | 312/500 [02:46<01:40,  1.87it/s]
 63%|######2   | 313/500 [02:47<01:39,  1.87it/s]
 63%|######2   | 314/500 [02:47<01:39,  1.87it/s]
 63%|######3   | 315/500 [02:48<01:38,  1.87it/s]
 63%|######3   | 316/500 [02:48<01:38,  1.87it/s]
 63%|######3   | 317/500 [02:49<01:37,  1.87it/s]
 64%|######3   | 318/500 [02:49<01:37,  1.87it/s]
 64%|######3   | 319/500 [02:50<01:36,  1.87it/s]
 64%|######4   | 320/500 [02:50<01:36,  1.87it/s]
 64%|######4   | 321/500 [02:51<01:35,  1.87it/s]
 64%|######4   | 322/500 [02:51<01:35,  1.87it/s]
 65%|######4   | 323/500 [02:52<01:34,  1.87it/s]
 65%|######4   | 324/500 [02:52<01:34,  1.87it/s]
 65%|######5   | 325/500 [02:53<01:33,  1.87it/s]
 65%|######5   | 326/500 [02:54<01:32,  1.87it/s]
 65%|######5   | 327/500 [02:54<01:32,  1.87it/s]
 66%|######5   | 328/500 [02:55<01:31,  1.87it/s]
 66%|######5   | 329/500 [02:55<01:31,  1.87it/s]
 66%|######6   | 330/500 [02:56<01:30,  1.87it/s]
 66%|######6   | 331/500 [02:56<01:30,  1.87it/s]
 66%|######6   | 332/500 [02:57<01:29,  1.87it/s]
 67%|######6   | 333/500 [02:57<01:29,  1.87it/s]
 67%|######6   | 334/500 [02:58<01:28,  1.87it/s]
 67%|######7   | 335/500 [02:58<01:28,  1.87it/s]
 67%|######7   | 336/500 [02:59<01:27,  1.87it/s]
 67%|######7   | 337/500 [02:59<01:27,  1.87it/s]
 68%|######7   | 338/500 [03:00<01:26,  1.87it/s]
 68%|######7   | 339/500 [03:00<01:26,  1.87it/s]
 68%|######8   | 340/500 [03:01<01:25,  1.87it/s]
 68%|######8   | 341/500 [03:02<01:25,  1.87it/s]
 68%|######8   | 342/500 [03:02<01:24,  1.87it/s]
 69%|######8   | 343/500 [03:03<01:23,  1.87it/s]
 69%|######8   | 344/500 [03:03<01:23,  1.87it/s]
 69%|######9   | 345/500 [03:04<01:22,  1.87it/s]
 69%|######9   | 346/500 [03:04<01:22,  1.87it/s]
 69%|######9   | 347/500 [03:05<01:21,  1.87it/s]
 70%|######9   | 348/500 [03:05<01:21,  1.87it/s]
 70%|######9   | 349/500 [03:06<01:20,  1.87it/s]
 70%|#######   | 350/500 [03:06<01:20,  1.87it/s]

{'loss': 1.1066, 'grad_norm': 8.763031959533691, 'learning_rate': 5e-05, 'epoch': 0.13}

 70%|#######   | 350/500 [03:06<01:20,  1.87it/s]
 70%|#######   | 351/500 [03:07<01:19,  1.87it/s]
 70%|#######   | 352/500 [03:07<01:19,  1.87it/s]
 71%|#######   | 353/500 [03:08<01:18,  1.87it/s]
 71%|#######   | 354/500 [03:09<01:18,  1.87it/s]
 71%|#######1  | 355/500 [03:09<01:17,  1.87it/s]
 71%|#######1  | 356/500 [03:10<01:16,  1.87it/s]
 71%|#######1  | 357/500 [03:10<01:16,  1.87it/s]
 72%|#######1  | 358/500 [03:11<01:15,  1.87it/s]
 72%|#######1  | 359/500 [03:11<01:15,  1.87it/s]
 72%|#######2  | 360/500 [03:12<01:14,  1.87it/s]
 72%|#######2  | 361/500 [03:12<01:14,  1.87it/s]
 72%|#######2  | 362/500 [03:13<01:13,  1.87it/s]
 73%|#######2  | 363/500 [03:13<01:13,  1.87it/s]
 73%|#######2  | 364/500 [03:14<01:12,  1.87it/s]
 73%|#######3  | 365/500 [03:14<01:12,  1.87it/s]
 73%|#######3  | 366/500 [03:15<01:11,  1.87it/s]
 73%|#######3  | 367/500 [03:15<01:11,  1.87it/s]
 74%|#######3  | 368/500 [03:16<01:10,  1.87it/s]
 74%|#######3  | 369/500 [03:17<01:10,  1.87it/s]
 74%|#######4  | 370/500 [03:17<01:09,  1.87it/s]
 74%|#######4  | 371/500 [03:18<01:08,  1.87it/s]
 74%|#######4  | 372/500 [03:18<01:08,  1.87it/s]
 75%|#######4  | 373/500 [03:19<01:07,  1.87it/s]
 75%|#######4  | 374/500 [03:19<01:07,  1.87it/s]
 75%|#######5  | 375/500 [03:20<01:06,  1.87it/s]
 75%|#######5  | 376/500 [03:20<01:06,  1.87it/s]
 75%|#######5  | 377/500 [03:21<01:05,  1.87it/s]
 76%|#######5  | 378/500 [03:21<01:05,  1.87it/s]
 76%|#######5  | 379/500 [03:22<01:04,  1.87it/s]
 76%|#######6  | 380/500 [03:22<01:04,  1.87it/s]
 76%|#######6  | 381/500 [03:23<01:03,  1.87it/s]
 76%|#######6  | 382/500 [03:23<01:03,  1.87it/s]
 77%|#######6  | 383/500 [03:24<01:02,  1.87it/s]
 77%|#######6  | 384/500 [03:25<01:01,  1.87it/s]
 77%|#######7  | 385/500 [03:25<01:01,  1.87it/s]
 77%|#######7  | 386/500 [03:26<01:00,  1.87it/s]
 77%|#######7  | 387/500 [03:26<01:00,  1.87it/s]
 78%|#######7  | 388/500 [03:27<00:59,  1.87it/s]
 78%|#######7  | 389/500 [03:27<00:59,  1.87it/s]
 78%|#######8  | 390/500 [03:28<00:58,  1.87it/s]
 78%|#######8  | 391/500 [03:28<00:58,  1.87it/s]
 78%|#######8  | 392/500 [03:29<00:57,  1.87it/s]
 79%|#######8  | 393/500 [03:29<00:57,  1.87it/s]
 79%|#######8  | 394/500 [03:30<00:56,  1.87it/s]
 79%|#######9  | 395/500 [03:30<00:56,  1.87it/s]
 79%|#######9  | 396/500 [03:31<00:55,  1.87it/s]
 79%|#######9  | 397/500 [03:31<00:55,  1.87it/s]
 80%|#######9  | 398/500 [03:32<00:54,  1.87it/s]
 80%|#######9  | 399/500 [03:33<00:53,  1.87it/s]
 80%|########  | 400/500 [03:33<00:53,  1.87it/s]

{'loss': 1.0148, 'grad_norm': 7.89182186126709, 'learning_rate': 5e-05, 'epoch': 0.15}

 80%|########  | 400/500 [03:33<00:53,  1.87it/s]
 80%|########  | 401/500 [03:34<00:52,  1.87it/s]
 80%|########  | 402/500 [03:34<00:52,  1.87it/s]
 81%|########  | 403/500 [03:35<00:51,  1.87it/s]
 81%|########  | 404/500 [03:35<00:51,  1.87it/s]
 81%|########1 | 405/500 [03:36<00:50,  1.87it/s]
 81%|########1 | 406/500 [03:36<00:50,  1.87it/s]
 81%|########1 | 407/500 [03:37<00:49,  1.87it/s]
 82%|########1 | 408/500 [03:37<00:49,  1.87it/s]
 82%|########1 | 409/500 [03:38<00:48,  1.87it/s]
 82%|########2 | 410/500 [03:38<00:48,  1.87it/s]
 82%|########2 | 411/500 [03:39<00:47,  1.87it/s]
 82%|########2 | 412/500 [03:40<00:47,  1.87it/s]
 83%|########2 | 413/500 [03:40<00:46,  1.87it/s]
 83%|########2 | 414/500 [03:41<00:45,  1.87it/s]
 83%|########2 | 415/500 [03:41<00:45,  1.87it/s]
 83%|########3 | 416/500 [03:42<00:44,  1.87it/s]
 83%|########3 | 417/500 [03:42<00:44,  1.87it/s]
 84%|########3 | 418/500 [03:43<00:43,  1.87it/s]
 84%|########3 | 419/500 [03:43<00:43,  1.87it/s]
 84%|########4 | 420/500 [03:44<00:42,  1.87it/s]
 84%|########4 | 421/500 [03:44<00:42,  1.87it/s]
 84%|########4 | 422/500 [03:45<00:41,  1.87it/s]
 85%|########4 | 423/500 [03:45<00:41,  1.87it/s]
 85%|########4 | 424/500 [03:46<00:40,  1.87it/s]
 85%|########5 | 425/500 [03:46<00:40,  1.87it/s]
 85%|########5 | 426/500 [03:47<00:39,  1.87it/s]
 85%|########5 | 427/500 [03:48<00:39,  1.87it/s]
 86%|########5 | 428/500 [03:48<00:38,  1.87it/s]
 86%|########5 | 429/500 [03:49<00:37,  1.87it/s]
 86%|########6 | 430/500 [03:49<00:37,  1.87it/s]
 86%|########6 | 431/500 [03:50<00:36,  1.87it/s]
 86%|########6 | 432/500 [03:50<00:36,  1.87it/s]
 87%|########6 | 433/500 [03:51<00:35,  1.87it/s]
 87%|########6 | 434/500 [03:51<00:35,  1.87it/s]
 87%|########7 | 435/500 [03:52<00:34,  1.87it/s]
 87%|########7 | 436/500 [03:52<00:34,  1.87it/s]
 87%|########7 | 437/500 [03:53<00:33,  1.87it/s]
 88%|########7 | 438/500 [03:53<00:33,  1.87it/s]
 88%|########7 | 439/500 [03:54<00:32,  1.87it/s]
 88%|########8 | 440/500 [03:54<00:32,  1.87it/s]
 88%|########8 | 441/500 [03:55<00:31,  1.87it/s]
 88%|########8 | 442/500 [03:56<00:30,  1.87it/s]
 89%|########8 | 443/500 [03:56<00:30,  1.87it/s]
 89%|########8 | 444/500 [03:57<00:29,  1.87it/s]
 89%|########9 | 445/500 [03:57<00:29,  1.87it/s]
 89%|########9 | 446/500 [03:58<00:28,  1.87it/s]
 89%|########9 | 447/500 [03:58<00:28,  1.87it/s]
 90%|########9 | 448/500 [03:59<00:27,  1.87it/s]
 90%|########9 | 449/500 [03:59<00:27,  1.87it/s]
 90%|######### | 450/500 [04:00<00:26,  1.87it/s]

{'loss': 1.0004, 'grad_norm': 8.658374786376953, 'learning_rate': 5e-05, 'epoch': 0.16}

 90%|######### | 450/500 [04:00<00:26,  1.87it/s]
 90%|######### | 451/500 [04:00<00:26,  1.87it/s]
 90%|######### | 452/500 [04:01<00:25,  1.87it/s]
 91%|######### | 453/500 [04:01<00:25,  1.87it/s]
 91%|######### | 454/500 [04:02<00:24,  1.87it/s]
 91%|#########1| 455/500 [04:02<00:24,  1.87it/s]
 91%|#########1| 456/500 [04:03<00:23,  1.87it/s]
 91%|#########1| 457/500 [04:04<00:22,  1.87it/s]
 92%|#########1| 458/500 [04:04<00:22,  1.87it/s]
 92%|#########1| 459/500 [04:05<00:21,  1.87it/s]
 92%|#########2| 460/500 [04:05<00:21,  1.87it/s]
 92%|#########2| 461/500 [04:06<00:20,  1.87it/s]
 92%|#########2| 462/500 [04:06<00:20,  1.87it/s]
 93%|#########2| 463/500 [04:07<00:19,  1.87it/s]
 93%|#########2| 464/500 [04:07<00:19,  1.87it/s]
 93%|#########3| 465/500 [04:08<00:18,  1.87it/s]
 93%|#########3| 466/500 [04:08<00:18,  1.87it/s]
 93%|#########3| 467/500 [04:09<00:17,  1.87it/s]
 94%|#########3| 468/500 [04:09<00:17,  1.87it/s]
 94%|#########3| 469/500 [04:10<00:16,  1.87it/s]
 94%|#########3| 470/500 [04:11<00:16,  1.87it/s]
 94%|#########4| 471/500 [04:11<00:15,  1.87it/s]
 94%|#########4| 472/500 [04:12<00:14,  1.87it/s]
 95%|#########4| 473/500 [04:12<00:14,  1.87it/s]
 95%|#########4| 474/500 [04:13<00:13,  1.87it/s]
 95%|#########5| 475/500 [04:13<00:13,  1.87it/s]
 95%|#########5| 476/500 [04:14<00:12,  1.87it/s]
 95%|#########5| 477/500 [04:14<00:12,  1.87it/s]
 96%|#########5| 478/500 [04:15<00:11,  1.87it/s]
 96%|#########5| 479/500 [04:15<00:11,  1.87it/s]
 96%|#########6| 480/500 [04:16<00:10,  1.87it/s]
 96%|#########6| 481/500 [04:16<00:10,  1.87it/s]
 96%|#########6| 482/500 [04:17<00:09,  1.87it/s]
 97%|#########6| 483/500 [04:17<00:09,  1.87it/s]
 97%|#########6| 484/500 [04:18<00:08,  1.87it/s]
 97%|#########7| 485/500 [04:19<00:08,  1.87it/s]
 97%|#########7| 486/500 [04:19<00:07,  1.87it/s]
 97%|#########7| 487/500 [04:20<00:06,  1.87it/s]
 98%|#########7| 488/500 [04:20<00:06,  1.87it/s]
 98%|#########7| 489/500 [04:21<00:05,  1.87it/s]
 98%|#########8| 490/500 [04:21<00:05,  1.87it/s]
 98%|#########8| 491/500 [04:22<00:04,  1.87it/s]
 98%|#########8| 492/500 [04:22<00:04,  1.87it/s]
 99%|#########8| 493/500 [04:23<00:03,  1.87it/s]
 99%|#########8| 494/500 [04:23<00:03,  1.87it/s]
 99%|#########9| 495/500 [04:24<00:02,  1.87it/s]
 99%|#########9| 496/500 [04:24<00:02,  1.87it/s]
 99%|#########9| 497/500 [04:25<00:01,  1.87it/s]
100%|#########9| 498/500 [04:25<00:01,  1.87it/s]
100%|#########9| 499/500 [04:26<00:00,  1.87it/s]
100%|##########| 500/500 [04:27<00:00,  1.87it/s]

{'loss': 1.0101, 'grad_norm': 10.829127311706543, 'learning_rate': 5e-05, 'epoch': 0.18}

100%|##########| 500/500 [04:27<00:00,  1.87it/s]

{'train_runtime': 276.4621, 'train_samples_per_second': 57.874, 'train_steps_per_second': 1.809, 'train_loss': 1.2056114730834961, 'epoch': 0.18}

100%|##########| 500/500 [04:36<00:00,  1.87it/s]
100%|##########| 500/500 [04:36<00:00,  1.81it/s]
tensor([[ 0.0000, -0.0181,  0.0000,  0.0115,  0.0000,  0.0313,  0.0000, -0.0736],
        [ 0.0412, -0.0000, -0.0000,  0.0499,  0.0346, -0.0000, -0.0000, -0.0172],
        [-0.0336, -0.0328,  0.0000,  0.0000, -0.0000, -0.0000,  0.0572, -0.0192],
        [ 0.0000, -0.0000, -0.0506,  0.0289,  0.0965,  0.0000, -0.0000, -0.0766],
        [ 0.0088, -0.0000,  0.0000, -0.0406, -0.0000,  0.0185,  0.0657, -0.0000],
        [-0.0000, -0.0311,  0.0000,  0.0308,  0.0394, -0.0172,  0.0000, -0.0000],
        [-0.1069, -0.0296,  0.0000, -0.0000,  0.0425,  0.0661, -0.0000, -0.0000],
        [-0.0000,  0.0298, -0.0655, -0.0000,  0.0316, -0.0000,  0.0000, -0.0602]],
       device='cuda:0', grad_fn=<SliceBackward0>)

<Axes: title={'center': 'Loss vs. # steps'}, xlabel='step', ylabel='loss'>
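The weight slice printed above shows the 2:4 pattern: in each contiguous group of four elements along a row, at most two weights are nonzero, with the smallest-magnitude entries zeroed out. A minimal pure-Python sketch of this magnitude-based 2:4 pruning step (the helper `prune_24` is illustrative only, not part of the PyTorch API):

```python
def prune_24(row):
    """Magnitude-based 2:4 pruning sketch: in every contiguous group of
    four weights, zero out the two with the smallest absolute value."""
    out = list(row)
    for i in range(0, len(out), 4):
        group = out[i:i + 4]
        # indices of the two smallest-magnitude entries in this group
        drop = sorted(range(len(group)), key=lambda j: abs(group[j]))[:2]
        for j in drop:
            out[i + j] = 0.0
    return out

weights = [0.01, -0.20, 0.05, 0.30, -0.40, 0.02, -0.03, 0.10]
print(prune_24(weights))  # -> [0.0, -0.2, 0.0, 0.3, -0.4, 0.0, 0.0, 0.1]
```

Weights pruned this way satisfy the 2:4 constraint that `to_sparse_semi_structured` compresses and accelerates.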

Accelerating the 2:4 sparse model for inference

Now that we have a model in this format, we can accelerate it for inference just as we did in the QuickStart guide.

model = model.cuda().half()
# accelerate for sparsity
for fqn, module in model.named_modules():
    if isinstance(module, nn.Linear) and "layer" in fqn:
        module.weight = nn.Parameter(to_sparse_semi_structured(module.weight))

with torch.no_grad():
    predictions = trainer.predict(tokenized_squad_dataset["validation"])
start_logits, end_logits = predictions.predictions
metrics_sparse = compute_metrics(
    start_logits,
    end_logits,
    tokenized_squad_dataset["validation"],
    squad_dataset["validation"],
)
print("sparse eval metrics: ", metrics_sparse)
sparse_perf = measure_execution_time(
    model,
    batch_sizes,
    tokenized_squad_dataset["validation"],
)
print("sparse perf metrics: ", sparse_perf)
100%|##########| 43/43 [00:55<00:00,  1.30s/it]
sparse eval metrics:  {'exact_match': 71.36234626300852, 'f1': 80.46835071907935}
sparse perf metrics:  {4: 16.749242000059894, '4_compile': 9.570522999638342, 16: 62.150992899842095, '16_compile': 34.265135300120164, 64: 243.14898499869741, '64_compile': 133.18610300029832, 256: 1195.5928600000334, '256_compile': 542.8887739999482}

Retraining our model after magnitude pruning has recovered nearly all of the F1 score that was lost when the model was pruned. At the same time, we achieve a 1.28x speedup for bs=16. Note that not all shapes are amenable to performance improvements: when batch sizes are small and little time is spent in compute, the sparse kernels may be slower than their dense counterparts.
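The speedup at each batch size is simply the ratio of dense to sparse latency. A quick sketch, with timing values (in ms) copied from this tutorial's summary table; your own numbers will vary by hardware:

```python
# fp16 dense and 2:4 sparse latencies in ms (values assumed from the
# summary table in this tutorial)
dense_perf = {4: 11.10, 16: 19.35, 64: 72.71}
sparse_perf = {4: 15.54, 16: 15.74, 64: 59.41}

speedups = {bs: round(dense_perf[bs] / sparse_perf[bs], 2) for bs in dense_perf}
print(speedups)  # -> {4: 0.71, 16: 1.23, 64: 1.22}
```

Note the bs=4 ratio is below 1.0: at that size the sparse kernel is slower than the dense one.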

Because semi-structured sparsity is implemented as a tensor subclass, it is compatible with torch.compile. When combined with to_sparse_semi_structured, we are able to achieve a total 2x speedup on BERT.
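The extra gain from torch.compile is visible in the sparse timing dict printed above: keys like `'16_compile'` are the compiled variants of the same batch size. Dividing the eager latency by the compiled latency gives the compile speedup (timing values copied from the printed output above):

```python
# sparse timings in ms, as printed by measure_execution_time above
sparse_perf = {4: 16.749242000059894, '4_compile': 9.570522999638342,
               16: 62.150992899842095, '16_compile': 34.265135300120164,
               64: 243.14898499869741, '64_compile': 133.18610300029832,
               256: 1195.5928600000334, '256_compile': 542.8887739999482}

compile_gain = {bs: round(sparse_perf[bs] / sparse_perf[f"{bs}_compile"], 2)
                for bs in (4, 16, 64, 256)}
print(compile_gain)
```

The gain grows with batch size, approaching the ~2x figure quoted above at bs=256.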

Metric            fp16     2:4 sparse   delta / speedup   compiled
----------------  -------  -----------  ----------------  --------
Exact Match (%)   78.53    78.44        -0.09             False
F1 (%)            86.93    86.49        -0.44             False
Time (bs=4)       11.10    15.54        0.71x             False
Time (bs=16)      19.35    15.74        1.23x             False
Time (bs=64)      72.71    59.41        1.22x             False
Time (bs=256)     286.65   247.63       1.14x             False
Time (bs=4)       7.59     7.46         1.02x             True
Time (bs=16)      11.47    9.68         1.18x             True
Time (bs=64)      41.57    36.92        1.13x             True
Time (bs=256)     159.22   142.23       1.12x             True

All times are in milliseconds (ms).

Conclusion

In this tutorial, we have shown how to prune BERT to be 2:4 sparse and how to accelerate a 2:4 sparse model for inference. By taking advantage of our SparseSemiStructuredTensor subclass, we were able to achieve a 1.3x speedup over the fp16 baseline, and up to 2x with torch.compile. We also demonstrated the benefits of 2:4 sparsity by fine-tuning BERT to recover nearly all of the F1 lost to pruning (86.92 dense vs 86.48 sparse).

Total running time of the script: (15 minutes 30.187 seconds)

Gallery generated by Sphinx-Gallery
