
(beta) Accelerating BERT with semi-structured (2:4) sparsity

Created on: Apr 22, 2024 | Last updated: Apr 22, 2024 | Last verified: Nov 05, 2024

Author: Jesse Cai

Overview

Like other forms of sparsity, semi-structured sparsity is a model optimization technique that seeks to reduce the memory overhead and latency of a neural network, at the expense of some model accuracy. It is also known as fine-grained structured sparsity or 2:4 structured sparsity.

Semi-structured sparsity derives its name from its unique sparsity pattern, where n out of every 2n elements are pruned. We most often see n = 2, hence 2:4 sparsity. Semi-structured sparsity is particularly interesting because it can be efficiently accelerated on GPUs and does not degrade model accuracy as severely as other sparsity patterns.
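To make the pattern concrete, here is a small pure-Python check (our own illustrative sketch, not part of the tutorial's code) that verifies a row of weights obeys the 2:4 constraint, i.e. at most 2 nonzero values in every contiguous group of 4:

```python
def is_24_sparse(row, n=2, m=4):
    """Return True if every contiguous group of ``m`` elements
    contains at most ``n`` nonzero values (the n:2n pattern)."""
    assert len(row) % m == 0, "row length must be a multiple of the group size"
    return all(
        sum(1 for v in row[i:i + m] if v != 0) <= n
        for i in range(0, len(row), m)
    )

print(is_24_sparse([0.5, 0.0, -1.2, 0.0, 0.0, 0.3, 0.9, 0.0]))  # True
print(is_24_sparse([0.5, 0.1, -1.2, 0.0]))  # False: 3 nonzeros in one group of 4
```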

With the introduction of semi-structured sparsity support, it is possible to prune and accelerate a semi-structured sparse model without leaving PyTorch. We will explain this process in this tutorial.

[Figure: 2:4 sparsity pruning and acceleration flow (../_static/img/pruning_flow.jpg)]

By the end of this tutorial, we will have sparsified a BERT question-answering model to be 2:4 sparse, fine-tuning it to recover nearly all F1 loss (86.92 dense vs 86.48 sparse). Finally, we will accelerate this 2:4 sparse model for inference, yielding a 1.3x speedup.

Requirements

  • PyTorch >= 2.1.

  • An NVIDIA GPU with semi-structured sparsity support (compute capability 8.0+).

This tutorial is designed for beginners to semi-structured sparsity and sparsity in general. For users with existing 2:4 sparse models, accelerating nn.Linear layers for inference with to_sparse_semi_structured is quite straightforward. Here is an example:

import torch
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from torch.utils.benchmark import Timer
SparseSemiStructuredTensor._FORCE_CUTLASS = True

# mask Linear weight to be 2:4 sparse
mask = torch.Tensor([0, 0, 1, 1]).tile((3072, 2560)).cuda().bool()
linear = torch.nn.Linear(10240, 3072).half().cuda().eval()
linear.weight = torch.nn.Parameter(mask * linear.weight)

x = torch.rand(3072, 10240).half().cuda()

with torch.inference_mode():
    dense_output = linear(x)
    dense_t = Timer(stmt="linear(x)",
                    globals={"linear": linear,
                             "x": x}).blocked_autorange().median * 1e3

    # accelerate via SparseSemiStructuredTensor
    linear.weight = torch.nn.Parameter(to_sparse_semi_structured(linear.weight))

    sparse_output = linear(x)
    sparse_t = Timer(stmt="linear(x)",
                    globals={"linear": linear,
                             "x": x}).blocked_autorange().median * 1e3

    # sparse and dense matmul are numerically equivalent
    # On an A100 80GB, we see: `Dense: 0.870ms Sparse: 0.630ms | Speedup: 1.382x`
    assert torch.allclose(sparse_output, dense_output, atol=1e-3)
    print(f"Dense: {dense_t:.3f}ms Sparse: {sparse_t:.3f}ms | Speedup: {(dense_t / sparse_t):.3f}x")
Dense: 2.919ms Sparse: 1.631ms | Speedup: 1.790x

What problem does semi-structured sparsity solve?

The general motivation behind sparsity is simple: if there are zeros in your network, you can optimize efficiency by not storing or computing those parameters. However, the specifics of sparsity are tricky. Zeroing out parameters doesn't, on its own, reduce a model's latency or memory overhead.

This is because the dense tensor still contains the pruned (zero) elements, which the dense matrix multiplication kernel will still operate on. In order to realize performance gains, we need to swap out dense kernels for sparse kernels, which skip calculations involving pruned elements.

To do this, these kernels work on sparse matrices, which do not store the pruned elements and instead store the specified elements in a compressed format.

For semi-structured sparsity, we store exactly half of the original parameters, along with some compressed metadata about how the elements were arranged.
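As a rough back-of-the-envelope sketch (our own, not from the tutorial; `semi_structured_storage_bytes` is a hypothetical helper), assuming fp16 values and 2-bit metadata per kept element — the layout NVIDIA describes for Ampere, though the exact metadata format varies by backend (CUTLASS vs. cuSPARSELt):

```python
def semi_structured_storage_bytes(m, k, elem_bytes=2):
    """Estimate dense vs. 2:4-compressed storage for an m x k matrix.

    Assumes: fp16 values (2 bytes each) and 2 bits of index metadata
    per kept element (4 bits per group of 4). Illustrative only.
    """
    dense = m * k * elem_bytes
    values = (m * k // 2) * elem_bytes  # keep 2 of every 4 values
    metadata = m * k // 8               # 2 bits per kept element, in bytes
    return dense, values + metadata

dense, compressed = semi_structured_storage_bytes(3072, 10240)
print(f"dense: {dense} B, compressed: {compressed} B, ratio: {compressed / dense:.3f}")
```

Under these assumptions the compressed tensor is about 56% the size of the dense one, a little over half because of the metadata.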

There are many different sparse layouts, each with their own benefits and drawbacks. The 2:4 semi-structured sparse layout is particularly interesting for two reasons:

  • Unlike previous sparse formats, semi-structured sparsity was designed to be efficiently accelerated on GPUs. In 2020, NVIDIA introduced hardware support for semi-structured sparsity with their Ampere architecture, and has also released fast sparse kernels via CUTLASS and cuSPARSELt.

  • At the same time, semi-structured sparsity tends to have a milder impact on model accuracy compared to other sparse formats, especially when accounting for more advanced pruning / fine-tuning methods. NVIDIA has shown in their white paper that a simple paradigm of magnitude pruning once to be 2:4 sparse and then retraining the model yields nearly identical model accuracies.

Semi-structured sparsity sits in a sweet spot, providing a 2x (theoretical) speedup at a much lower sparsity level (50%), while still being granular enough to preserve model accuracy.

| Network             | Dataset     | Metric | Dense FP16 | Sparse FP16 |
|---------------------|-------------|--------|------------|-------------|
| ResNet-50           | ImageNet    | Top-1  | 76.1       | 76.2        |
| ResNeXt-101_32x8d   | ImageNet    | Top-1  | 79.3       | 79.3        |
| Xception            | ImageNet    | Top-1  | 79.2       | 79.2        |
| SSD-RN50            | COCO2017    | bbAP   | 24.8       | 24.8        |
| MaskRCNN-RN50       | COCO2017    | bbAP   | 37.9       | 37.9        |
| FairSeq Transformer | EN-DE WMT14 | BLEU   | 28.2       | 28.5        |
| BERT-Large          | SQuAD v1.1  | F1     | 91.9       | 91.9        |

From a workflow perspective, semi-structured sparsity has an additional advantage. Because the sparsity level is fixed at 50%, it is easier to decompose the problem of sparsifying a model into two distinct subproblems:

  • Accuracy - How can we find a set of 2:4 sparse weights that minimizes the accuracy degradation of our model?

  • Performance - How can we accelerate our 2:4 sparse weights for inference and reduce memory overhead?

\[\begin{bmatrix} 1 & 1 & 0 & 0 \\ 0 & 0 & 1 & 1 \\ 1 & 0 & 0 & 0 \\ 0 & 0 & 1 & 1 \\ \end{bmatrix}\]

The natural handoff point between these two problems is the zeroed-out dense tensor. Our inference solution is designed to compress and accelerate tensors in this format. We anticipate many users coming up with custom masking solutions, as this is an active area of research.
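As a toy example of one such masking strategy, here is a minimal pure-Python sketch of one-shot magnitude pruning to 2:4 (our own illustration; later in this tutorial we use `WeightNormSparsifier` from `torch.ao.pruning` rather than this helper):

```python
def prune_24_magnitude(row, m=4, n=2):
    """Zero out all but the n largest-magnitude values in each group of m."""
    out = list(row)
    for i in range(0, len(row), m):
        group = list(range(i, i + m))
        # Indices of the n largest |values| in this group survive.
        keep = set(sorted(group, key=lambda j: abs(row[j]), reverse=True)[:n])
        for j in group:
            if j not in keep:
                out[j] = 0.0
    return out

print(prune_24_magnitude([0.9, -0.1, 0.4, 0.05, -0.7, 0.2, 0.6, -0.8]))
# → [0.9, 0.0, 0.4, 0.0, -0.7, 0.0, 0.0, -0.8]
```

The output is a zeroed-out dense tensor in exactly the handoff format described above, ready to be compressed and accelerated.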

Now that we've learned a bit more about semi-structured sparsity, let's apply it to a BERT model trained on a question-answering task, SQuAD.

Intro & Setup

Let's start by importing all the packages we need.

# If you are running this in Google Colab, run:
#
#    !pip install datasets transformers evaluate accelerate pandas
#
import os
os.environ["WANDB_DISABLED"] = "true"

import collections
import datasets
import evaluate
import numpy as np
import torch
import torch.utils.benchmark as benchmark
from torch import nn
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from torch.ao.pruning import WeightNormSparsifier
import transformers

# force CUTLASS use if ``cuSPARSELt`` is not available
SparseSemiStructuredTensor._FORCE_CUTLASS = True
torch.manual_seed(100)
<torch._C.Generator object at 0x7f38f13863d0>

We will also need to define some helper functions that are specific to the dataset / task at hand. These were adapted from this Hugging Face course as a reference.

def preprocess_validation_function(examples, tokenizer):
    inputs = tokenizer(
        [q.strip() for q in examples["question"]],
        examples["context"],
        max_length=384,
        truncation="only_second",
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    sample_map = inputs.pop("overflow_to_sample_mapping")
    example_ids = []

    for i in range(len(inputs["input_ids"])):
        sample_idx = sample_map[i]
        example_ids.append(examples["id"][sample_idx])
        sequence_ids = inputs.sequence_ids(i)
        offset = inputs["offset_mapping"][i]
        inputs["offset_mapping"][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

    inputs["example_id"] = example_ids
    return inputs


def preprocess_train_function(examples, tokenizer):
    inputs = tokenizer(
        [q.strip() for q in examples["question"]],
        examples["context"],
        max_length=384,
        truncation="only_second",
        return_offsets_mapping=True,
        padding="max_length",
    )

    offset_mapping = inputs["offset_mapping"]
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, (offset, answer) in enumerate(zip(offset_mapping, answers)):
        start_char = answer["answer_start"][0]
        end_char = start_char + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        # Find the start and end of the context
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label it (0, 0)
        if offset[context_start][0] > end_char or offset[context_end][1] < start_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs


def compute_metrics(start_logits, end_logits, features, examples):
    n_best = 20
    max_answer_length = 30
    metric = evaluate.load("squad")

    example_to_features = collections.defaultdict(list)
    for idx, feature in enumerate(features):
        example_to_features[feature["example_id"]].append(idx)

    predicted_answers = []
    # for example in ``tqdm`` (examples):
    for example in examples:
        example_id = example["id"]
        context = example["context"]
        answers = []

        # Loop through all features associated with that example
        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = features[feature_index]["offset_mapping"]

            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip answers that are not fully in the context
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # Skip answers with a length that is either < 0
                    # or > max_answer_length
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue

                    answer = {
                        "text": context[
                            offsets[start_index][0] : offsets[end_index][1]
                        ],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                    answers.append(answer)

        # Select the answer with the best score
        if len(answers) > 0:
            best_answer = max(answers, key=lambda x: x["logit_score"])
            predicted_answers.append(
                {"id": example_id, "prediction_text": best_answer["text"]}
            )
        else:
            predicted_answers.append({"id": example_id, "prediction_text": ""})

    theoretical_answers = [
        {"id": ex["id"], "answers": ex["answers"]} for ex in examples
    ]
    return metric.compute(predictions=predicted_answers, references=theoretical_answers)

Now that those are defined, we just need one additional helper function, which will help us benchmark our model.

def measure_execution_time(model, batch_sizes, dataset):
    dataset_for_model = dataset.remove_columns(["example_id", "offset_mapping"])
    dataset_for_model.set_format("torch")
    batch_size_to_time_sec = {}
    for batch_size in batch_sizes:
        batch = {
            k: dataset_for_model[k][:batch_size].cuda()
            for k in dataset_for_model.column_names
        }

        with torch.no_grad():
            baseline_predictions = model(**batch)
            timer = benchmark.Timer(
                stmt="model(**batch)", globals={"model": model, "batch": batch}
            )
            p50 = timer.blocked_autorange().median * 1000
            batch_size_to_time_sec[batch_size] = p50

            model_c = torch.compile(model, fullgraph=True)
            timer = benchmark.Timer(
                stmt="model(**batch)", globals={"model": model_c, "batch": batch}
            )
            p50 = timer.blocked_autorange().median * 1000
            batch_size_to_time_sec[f"{batch_size}_compile"] = p50
            new_predictions = model_c(**batch)

    return batch_size_to_time_sec

We will start by loading our model and tokenizer, and then setting up our dataset.

# load model
model_name = "bert-base-cased"
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.AutoModelForQuestionAnswering.from_pretrained(model_name)
print(f"Loading tokenizer: {model_name}")
print(f"Loading model: {model_name}")

# set up train and val dataset
squad_dataset = datasets.load_dataset("squad")
tokenized_squad_dataset = {}
tokenized_squad_dataset["train"] = squad_dataset["train"].map(
    lambda x: preprocess_train_function(x, tokenizer), batched=True
)
tokenized_squad_dataset["validation"] = squad_dataset["validation"].map(
    lambda x: preprocess_validation_function(x, tokenizer),
    batched=True,
    remove_columns=squad_dataset["train"].column_names,
)
data_collator = transformers.DataCollatorWithPadding(tokenizer=tokenizer)
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Loading tokenizer: bert-base-cased
Loading model: bert-base-cased

Downloading readme: 100%|##########| 7.62k/7.62k [00:00<00:00, 54.9MB/s]
Downloading data: 100%|##########| 14.5M/14.5M [00:00<00:00, 92.7MB/s]
Downloading data: 100%|##########| 1.82M/1.82M [00:00<00:00, 20.4MB/s]
Generating train split: 100%|##########| 87599/87599 [00:00<00:00, 603976.19 examples/s]
Generating validation split: 100%|##########| 10570/10570 [00:00<00:00, 557903.39 examples/s]
Map: 100%|##########| 87599/87599 [00:45<00:00, 1924.05 examples/s]
Map: 100%|##########| 10570/10570 [00:06<00:00, 1567.49 examples/s]

Establishing a baseline

Next, we'll train a quick baseline of our model on SQuAD. This task asks our model to identify spans, or segments of text, in a given context (a Wikipedia article) that answer a given question. Running the following code gives me an F1 score of 86.9. This is quite close to the score NVIDIA reported, and the difference is likely due to BERT-base vs. BERT-large or the fine-tuning hyperparameters.

training_args = transformers.TrainingArguments(
    "trainer",
    num_train_epochs=1,
    lr_scheduler_type="constant",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=256,
    logging_steps=50,
    # Limit max steps for tutorial runners. Delete the below line to see the reported accuracy numbers.
    max_steps=500,
    report_to=None,
)

trainer = transformers.Trainer(
    model,
    training_args,
    train_dataset=tokenized_squad_dataset["train"],
    eval_dataset=tokenized_squad_dataset["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)

trainer.train()

# batch sizes to compare for eval
batch_sizes = [4, 16, 64, 256]
# 2:4 sparsity requires fp16, so we cast here for a fair comparison
with torch.autocast("cuda"):
    with torch.no_grad():
        predictions = trainer.predict(tokenized_squad_dataset["validation"])
        start_logits, end_logits = predictions.predictions
        fp16_baseline = compute_metrics(
            start_logits,
            end_logits,
            tokenized_squad_dataset["validation"],
            squad_dataset["validation"],
        )
        fp16_time = measure_execution_time(
            model,
            batch_sizes,
            tokenized_squad_dataset["validation"],
        )

print("fp16", fp16_baseline)
print("cuda_fp16 time", fp16_time)

import pandas as pd
df = pd.DataFrame(trainer.state.log_history)
df.plot.line(x='step', y='loss', title="Loss vs. # steps", ylabel="loss")
[Figure: Loss vs. # steps]
torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. In order to use Torch DDP, launch your script with `python -m torch.distributed.launch
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).

  0%|          | 0/500 [00:00<?, ?it/s]
  0%|          | 1/500 [00:00<03:59,  2.09it/s]
  0%|          | 2/500 [00:00<03:47,  2.19it/s]
  1%|          | 3/500 [00:01<03:44,  2.22it/s]
  1%|          | 4/500 [00:01<03:42,  2.23it/s]
  1%|1         | 5/500 [00:02<03:41,  2.24it/s]
  1%|1         | 6/500 [00:02<03:40,  2.24it/s]
  1%|1         | 7/500 [00:03<03:39,  2.24it/s]
  2%|1         | 8/500 [00:03<03:39,  2.24it/s]
  2%|1         | 9/500 [00:04<03:38,  2.24it/s]
  2%|2         | 10/500 [00:04<03:38,  2.25it/s]
  2%|2         | 11/500 [00:04<03:37,  2.25it/s]
  2%|2         | 12/500 [00:05<03:37,  2.25it/s]
  3%|2         | 13/500 [00:05<03:36,  2.25it/s]
  3%|2         | 14/500 [00:06<03:36,  2.25it/s]
  3%|3         | 15/500 [00:06<03:35,  2.25it/s]
  3%|3         | 16/500 [00:07<03:35,  2.25it/s]
  3%|3         | 17/500 [00:07<03:34,  2.25it/s]
  4%|3         | 18/500 [00:08<03:34,  2.25it/s]
  4%|3         | 19/500 [00:08<03:34,  2.25it/s]
  4%|4         | 20/500 [00:08<03:33,  2.25it/s]
  4%|4         | 21/500 [00:09<03:33,  2.25it/s]
  4%|4         | 22/500 [00:09<03:32,  2.25it/s]
  5%|4         | 23/500 [00:10<03:32,  2.25it/s]
  5%|4         | 24/500 [00:10<03:31,  2.25it/s]
  5%|5         | 25/500 [00:11<03:31,  2.25it/s]
  5%|5         | 26/500 [00:11<03:30,  2.25it/s]
  5%|5         | 27/500 [00:12<03:30,  2.25it/s]
  6%|5         | 28/500 [00:12<03:29,  2.25it/s]
  6%|5         | 29/500 [00:12<03:29,  2.25it/s]
  6%|6         | 30/500 [00:13<03:29,  2.25it/s]
  6%|6         | 31/500 [00:13<03:28,  2.25it/s]
  6%|6         | 32/500 [00:14<03:28,  2.25it/s]
  7%|6         | 33/500 [00:14<03:27,  2.25it/s]
  7%|6         | 34/500 [00:15<03:27,  2.25it/s]
  7%|7         | 35/500 [00:15<03:26,  2.25it/s]
  7%|7         | 36/500 [00:16<03:26,  2.25it/s]
  7%|7         | 37/500 [00:16<03:25,  2.25it/s]
  8%|7         | 38/500 [00:16<03:25,  2.25it/s]
  8%|7         | 39/500 [00:17<03:25,  2.25it/s]
  8%|8         | 40/500 [00:17<03:24,  2.25it/s]
  8%|8         | 41/500 [00:18<03:24,  2.25it/s]
  8%|8         | 42/500 [00:18<03:23,  2.25it/s]
  9%|8         | 43/500 [00:19<03:23,  2.25it/s]
  9%|8         | 44/500 [00:19<03:22,  2.25it/s]
  9%|9         | 45/500 [00:20<03:22,  2.25it/s]
  9%|9         | 46/500 [00:20<03:21,  2.25it/s]
  9%|9         | 47/500 [00:20<03:21,  2.25it/s]
 10%|9         | 48/500 [00:21<03:20,  2.25it/s]
 10%|9         | 49/500 [00:21<03:20,  2.25it/s]
 10%|#         | 50/500 [00:22<03:20,  2.25it/s]

{'loss': 3.8381, 'grad_norm': 14.371336936950684, 'learning_rate': 5e-05, 'epoch': 0.02}

 10%|#         | 50/500 [00:22<03:20,  2.25it/s]
 10%|#         | 51/500 [00:22<03:19,  2.25it/s]
 10%|#         | 52/500 [00:23<03:19,  2.25it/s]
 11%|#         | 53/500 [00:23<03:18,  2.25it/s]
 11%|#         | 54/500 [00:24<03:18,  2.25it/s]
 11%|#1        | 55/500 [00:24<03:17,  2.25it/s]
 11%|#1        | 56/500 [00:24<03:17,  2.25it/s]
 11%|#1        | 57/500 [00:25<03:17,  2.25it/s]
 12%|#1        | 58/500 [00:25<03:16,  2.25it/s]
 12%|#1        | 59/500 [00:26<03:16,  2.25it/s]
 12%|#2        | 60/500 [00:26<03:15,  2.25it/s]
 12%|#2        | 61/500 [00:27<03:15,  2.25it/s]
 12%|#2        | 62/500 [00:27<03:14,  2.25it/s]
 13%|#2        | 63/500 [00:28<03:14,  2.25it/s]
 13%|#2        | 64/500 [00:28<03:13,  2.25it/s]
 13%|#3        | 65/500 [00:28<03:13,  2.25it/s]
 13%|#3        | 66/500 [00:29<03:13,  2.25it/s]
 13%|#3        | 67/500 [00:29<03:12,  2.25it/s]
 14%|#3        | 68/500 [00:30<03:12,  2.25it/s]
 14%|#3        | 69/500 [00:30<03:11,  2.25it/s]
 14%|#4        | 70/500 [00:31<03:11,  2.25it/s]
 14%|#4        | 71/500 [00:31<03:10,  2.25it/s]
 14%|#4        | 72/500 [00:32<03:10,  2.25it/s]
 15%|#4        | 73/500 [00:32<03:09,  2.25it/s]
 15%|#4        | 74/500 [00:32<03:09,  2.25it/s]
 15%|#5        | 75/500 [00:33<03:09,  2.25it/s]
 15%|#5        | 76/500 [00:33<03:08,  2.25it/s]
 15%|#5        | 77/500 [00:34<03:08,  2.25it/s]
 16%|#5        | 78/500 [00:34<03:07,  2.25it/s]
 16%|#5        | 79/500 [00:35<03:07,  2.25it/s]
 16%|#6        | 80/500 [00:35<03:06,  2.25it/s]
 16%|#6        | 81/500 [00:36<03:06,  2.25it/s]
 16%|#6        | 82/500 [00:36<03:05,  2.25it/s]
 17%|#6        | 83/500 [00:36<03:05,  2.25it/s]
 17%|#6        | 84/500 [00:37<03:04,  2.25it/s]
 17%|#7        | 85/500 [00:37<03:04,  2.25it/s]
 17%|#7        | 86/500 [00:38<03:04,  2.25it/s]
 17%|#7        | 87/500 [00:38<03:03,  2.25it/s]
 18%|#7        | 88/500 [00:39<03:03,  2.25it/s]
 18%|#7        | 89/500 [00:39<03:02,  2.25it/s]
 18%|#8        | 90/500 [00:40<03:02,  2.25it/s]
 18%|#8        | 91/500 [00:40<03:01,  2.25it/s]
 18%|#8        | 92/500 [00:40<03:01,  2.25it/s]
 19%|#8        | 93/500 [00:41<03:00,  2.25it/s]
 19%|#8        | 94/500 [00:41<03:00,  2.25it/s]
 19%|#9        | 95/500 [00:42<03:00,  2.25it/s]
 19%|#9        | 96/500 [00:42<02:59,  2.25it/s]
 19%|#9        | 97/500 [00:43<02:59,  2.25it/s]
 20%|#9        | 98/500 [00:43<02:58,  2.25it/s]
 20%|#9        | 99/500 [00:44<02:58,  2.25it/s]
 20%|##        | 100/500 [00:44<02:57,  2.25it/s]

{'loss': 2.3804, 'grad_norm': 15.045374870300293, 'learning_rate': 5e-05, 'epoch': 0.04}

 20%|##        | 100/500 [00:44<02:57,  2.25it/s]
 20%|##        | 101/500 [00:44<02:57,  2.25it/s]
 20%|##        | 102/500 [00:45<02:57,  2.25it/s]
 21%|##        | 103/500 [00:45<02:56,  2.25it/s]
 21%|##        | 104/500 [00:46<02:56,  2.25it/s]
 21%|##1       | 105/500 [00:46<02:55,  2.25it/s]
 21%|##1       | 106/500 [00:47<02:55,  2.25it/s]
 21%|##1       | 107/500 [00:47<02:54,  2.25it/s]
 22%|##1       | 108/500 [00:48<02:54,  2.25it/s]
 22%|##1       | 109/500 [00:48<02:53,  2.25it/s]
 22%|##2       | 110/500 [00:48<02:53,  2.25it/s]
 22%|##2       | 111/500 [00:49<02:52,  2.25it/s]
 22%|##2       | 112/500 [00:49<02:52,  2.25it/s]
 23%|##2       | 113/500 [00:50<02:52,  2.25it/s]
 23%|##2       | 114/500 [00:50<02:51,  2.25it/s]
 23%|##3       | 115/500 [00:51<02:51,  2.25it/s]
 23%|##3       | 116/500 [00:51<02:50,  2.25it/s]
 23%|##3       | 117/500 [00:52<02:50,  2.25it/s]
 24%|##3       | 118/500 [00:52<02:49,  2.25it/s]
 24%|##3       | 119/500 [00:52<02:49,  2.25it/s]
 24%|##4       | 120/500 [00:53<02:48,  2.25it/s]
 24%|##4       | 121/500 [00:53<02:48,  2.25it/s]
 24%|##4       | 122/500 [00:54<02:48,  2.25it/s]
 25%|##4       | 123/500 [00:54<02:47,  2.25it/s]
 25%|##4       | 124/500 [00:55<02:47,  2.25it/s]
 25%|##5       | 125/500 [00:55<02:46,  2.25it/s]
 25%|##5       | 126/500 [00:56<02:46,  2.25it/s]
 25%|##5       | 127/500 [00:56<02:45,  2.25it/s]
 26%|##5       | 128/500 [00:56<02:45,  2.25it/s]
 26%|##5       | 129/500 [00:57<02:44,  2.25it/s]
 26%|##6       | 130/500 [00:57<02:44,  2.25it/s]
 26%|##6       | 131/500 [00:58<02:44,  2.25it/s]
 26%|##6       | 132/500 [00:58<02:43,  2.25it/s]
 27%|##6       | 133/500 [00:59<02:43,  2.25it/s]
 27%|##6       | 134/500 [00:59<02:42,  2.25it/s]
 27%|##7       | 135/500 [01:00<02:42,  2.25it/s]
 27%|##7       | 136/500 [01:00<02:41,  2.25it/s]
 27%|##7       | 137/500 [01:00<02:41,  2.25it/s]
 28%|##7       | 138/500 [01:01<02:40,  2.25it/s]
 28%|##7       | 139/500 [01:01<02:40,  2.25it/s]
 28%|##8       | 140/500 [01:02<02:40,  2.25it/s]
 28%|##8       | 141/500 [01:02<02:39,  2.25it/s]
 28%|##8       | 142/500 [01:03<02:39,  2.25it/s]
 29%|##8       | 143/500 [01:03<02:38,  2.25it/s]
 29%|##8       | 144/500 [01:04<02:38,  2.25it/s]
 29%|##9       | 145/500 [01:04<02:37,  2.25it/s]
 29%|##9       | 146/500 [01:04<02:37,  2.25it/s]
 29%|##9       | 147/500 [01:05<02:36,  2.25it/s]
 30%|##9       | 148/500 [01:05<02:36,  2.25it/s]
 30%|##9       | 149/500 [01:06<02:36,  2.25it/s]
 30%|###       | 150/500 [01:06<02:35,  2.25it/s]

{'loss': 1.8675, 'grad_norm': 11.23995590209961, 'learning_rate': 5e-05, 'epoch': 0.05}

 30%|###       | 150/500 [01:06<02:35,  2.25it/s]
 30%|###       | 151/500 [01:07<02:35,  2.25it/s]
 30%|###       | 152/500 [01:07<02:34,  2.25it/s]
 31%|###       | 153/500 [01:08<02:34,  2.25it/s]
 31%|###       | 154/500 [01:08<02:33,  2.25it/s]
 31%|###1      | 155/500 [01:08<02:33,  2.25it/s]
 31%|###1      | 156/500 [01:09<02:32,  2.25it/s]
 31%|###1      | 157/500 [01:09<02:32,  2.25it/s]
 32%|###1      | 158/500 [01:10<02:32,  2.25it/s]
 32%|###1      | 159/500 [01:10<02:31,  2.25it/s]
 32%|###2      | 160/500 [01:11<02:31,  2.25it/s]
 32%|###2      | 161/500 [01:11<02:30,  2.25it/s]
 32%|###2      | 162/500 [01:12<02:30,  2.25it/s]
 33%|###2      | 163/500 [01:12<02:29,  2.25it/s]
 33%|###2      | 164/500 [01:12<02:29,  2.25it/s]
 33%|###3      | 165/500 [01:13<02:28,  2.25it/s]
 33%|###3      | 166/500 [01:13<02:28,  2.25it/s]
 33%|###3      | 167/500 [01:14<02:28,  2.25it/s]
 34%|###3      | 168/500 [01:14<02:27,  2.25it/s]
 34%|###3      | 169/500 [01:15<02:27,  2.25it/s]
 34%|###4      | 170/500 [01:15<02:26,  2.25it/s]
 34%|###4      | 171/500 [01:16<02:26,  2.25it/s]
 34%|###4      | 172/500 [01:16<02:25,  2.25it/s]
 35%|###4      | 173/500 [01:16<02:25,  2.25it/s]
 35%|###4      | 174/500 [01:17<02:24,  2.25it/s]
 35%|###5      | 175/500 [01:17<02:24,  2.25it/s]
 35%|###5      | 176/500 [01:18<02:24,  2.25it/s]
 35%|###5      | 177/500 [01:18<02:23,  2.25it/s]
 36%|###5      | 178/500 [01:19<02:23,  2.25it/s]
 36%|###5      | 179/500 [01:19<02:22,  2.25it/s]
 36%|###6      | 180/500 [01:20<02:22,  2.25it/s]
 36%|###6      | 181/500 [01:20<02:21,  2.25it/s]
 36%|###6      | 182/500 [01:20<02:21,  2.25it/s]
 37%|###6      | 183/500 [01:21<02:20,  2.25it/s]
 37%|###6      | 184/500 [01:21<02:20,  2.25it/s]
 37%|###7      | 185/500 [01:22<02:20,  2.25it/s]
 37%|###7      | 186/500 [01:22<02:19,  2.25it/s]
 37%|###7      | 187/500 [01:23<02:19,  2.25it/s]
 38%|###7      | 188/500 [01:23<02:18,  2.25it/s]
 38%|###7      | 189/500 [01:24<02:18,  2.25it/s]
 38%|###8      | 190/500 [01:24<02:17,  2.25it/s]
 38%|###8      | 191/500 [01:24<02:17,  2.25it/s]
 38%|###8      | 192/500 [01:25<02:16,  2.25it/s]
 39%|###8      | 193/500 [01:25<02:16,  2.25it/s]
 39%|###8      | 194/500 [01:26<02:16,  2.25it/s]
 39%|###9      | 195/500 [01:26<02:15,  2.25it/s]
 39%|###9      | 196/500 [01:27<02:15,  2.25it/s]
 39%|###9      | 197/500 [01:27<02:14,  2.25it/s]
 40%|###9      | 198/500 [01:28<02:14,  2.25it/s]
 40%|###9      | 199/500 [01:28<02:13,  2.25it/s]
 40%|####      | 200/500 [01:28<02:13,  2.25it/s]

{'loss': 1.728, 'grad_norm': 12.760137557983398, 'learning_rate': 5e-05, 'epoch': 0.07}

 50%|#####     | 250/500 [01:51<01:51,  2.25it/s]

{'loss': 1.5848, 'grad_norm': 12.067267417907715, 'learning_rate': 5e-05, 'epoch': 0.09}

 60%|######    | 300/500 [02:13<01:28,  2.25it/s]

{'loss': 1.5276, 'grad_norm': 11.940499305725098, 'learning_rate': 5e-05, 'epoch': 0.11}

 70%|#######   | 350/500 [02:35<01:06,  2.25it/s]

{'loss': 1.4958, 'grad_norm': 9.553606033325195, 'learning_rate': 5e-05, 'epoch': 0.13}

 80%|########  | 400/500 [02:57<00:44,  2.25it/s]

{'loss': 1.3849, 'grad_norm': 10.62413501739502, 'learning_rate': 5e-05, 'epoch': 0.15}

 90%|######### | 450/500 [03:20<00:22,  2.25it/s]

{'loss': 1.3429, 'grad_norm': 13.479365348815918, 'learning_rate': 5e-05, 'epoch': 0.16}

100%|##########| 500/500 [03:42<00:00,  2.25it/s]

{'loss': 1.3365, 'grad_norm': 14.161598205566406, 'learning_rate': 5e-05, 'epoch': 0.18}


{'train_runtime': 223.7842, 'train_samples_per_second': 71.497, 'train_steps_per_second': 2.234, 'train_loss': 1.8486361694335938, 'epoch': 0.18}

100%|##########| 500/500 [03:43<00:00,  2.23it/s]

100%|##########| 43/43 [00:25<00:00,  1.68it/s]

Downloading builder script: 100%|##########| 4.53k/4.53k [00:00<00:00, 34.0MB/s]

Downloading extra modules: 100%|##########| 3.32k/3.32k [00:00<00:00, 35.4MB/s]
fp16 {'exact_match': 71.19205298013244, 'f1': 80.46294958468431}
cuda_fp16 time {4: 9.291296300002614, '4_compile': 9.214334000034796, 16: 30.575215200042294, '16_compile': 31.7835579999155, 64: 118.95021049986099, '64_compile': 109.09820949996174, 256: 457.8732750001109, '256_compile': 414.86550299987357}

<Axes: title={'center': 'Loss vs. # steps'}, xlabel='step', ylabel='loss'>

Pruning BERT to be 2:4 sparse

Now that we have our baseline, it is time to prune BERT. There are many different pruning strategies, but one of the most common is magnitude pruning, which seeks to remove the weights with the lowest L1 norm. Magnitude pruning was used by NVIDIA in all their results and is a common baseline.

To do that, we will use the torch.ao.pruning package, which contains a weight-norm (magnitude) sparsifier. These sparsifiers work by applying mask parametrizations to the weight tensors in a model. This lets them simulate sparsity by masking out the pruned weights.

We'll also have to decide what layers of the model to apply sparsity to, which in this case is all of the nn.Linear layers, except for the task-specific head outputs. That's because semi-structured sparsity has shape constraints, and the task-specific nn.Linear layers do not satisfy them.
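As a rough illustration of that constraint, the helper below screens weights by shape. This is a sketch, not part of the tutorial's code: the function name is hypothetical, and the assumed requirement (both weight dimensions divisible by 64 for fp16) depends on dtype and backend.

```python
import torch

def satisfies_24_shape_constraint(weight: torch.Tensor, multiple: int = 64) -> bool:
    # Hypothetical check: assume both dimensions must be multiples of 64
    # (the actual minimums depend on dtype and the sparse kernel backend).
    rows, cols = weight.shape
    return rows % multiple == 0 and cols % multiple == 0

# An encoder linear weight (3072 x 768) passes, but a tiny QA-head weight
# (2 x 768) does not, so the task-specific head is left dense.
assert satisfies_24_shape_constraint(torch.empty(3072, 768))
assert not satisfies_24_shape_constraint(torch.empty(2, 768))
```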

from torch import nn
from torch.ao.pruning import WeightNormSparsifier

sparsifier = WeightNormSparsifier(
    # apply sparsity to all blocks
    sparsity_level=1.0,
    # shape of 4 elements is a block
    sparse_block_shape=(1, 4),
    # two zeros for every block of 4
    zeros_per_block=2
)

# add to config if ``nn.Linear`` and in the BERT model.
sparse_config = [
    {"tensor_fqn": f"{fqn}.weight"}
    for fqn, module in model.named_modules()
    if isinstance(module, nn.Linear) and "layer" in fqn
]

The first step for pruning the model is to insert parametrizations for masking the weights of the model. This is done via the prepare step. Anytime we try to access the .weight, we will get mask * weight instead.

# Prepare the model, insert fake-sparsity parametrizations for training
sparsifier.prepare(model, sparse_config)
print(model.bert.encoder.layer[0].output)
BertOutput(
  (dense): ParametrizedLinear(
    in_features=3072, out_features=768, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): FakeSparsity()
      )
    )
  )
  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)
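The FakeSparsity parametrization shown in the printout can be mimicked with torch.nn.utils.parametrize. The ToyFakeSparsity class below is a toy stand-in for the real implementation, only meant to show that accessing .weight returns mask * weight:

```python
import torch
from torch import nn
import torch.nn.utils.parametrize as parametrize

class ToyFakeSparsity(nn.Module):
    """Toy stand-in for FakeSparsity: multiplies the weight by a fixed mask."""
    def __init__(self, mask):
        super().__init__()
        self.register_buffer("mask", mask)

    def forward(self, weight):
        return self.mask * weight

linear = nn.Linear(4, 4)
mask = torch.tensor([1., 1., 0., 0.]).repeat(4, 1)  # one 2:4 pattern per row
parametrize.register_parametrization(linear, "weight", ToyFakeSparsity(mask))

# ``linear.weight`` now returns ``mask * weight``, so half the entries are zero
assert torch.count_nonzero(linear.weight) == 8
```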

Then, we will take a single pruning step. All pruners implement an update_mask() method that updates the mask with the logic determined by the pruner implementation. The step method calls this update_mask function for the weights specified in the sparse config.

We will also evaluate the model to show the accuracy degradation of zero-shot pruning, or pruning without fine-tuning / retraining.
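For this configuration (sparse_block_shape=(1, 4), zeros_per_block=2), the mask update amounts to keeping the two largest-magnitude elements in every contiguous block of four. The function below is a minimal sketch of that computation, not the sparsifier's actual code:

```python
import torch

def magnitude_mask_2_4(weight: torch.Tensor) -> torch.Tensor:
    # Keep the 2 largest-|w| entries in each contiguous block of 4 elements.
    blocks = weight.reshape(-1, 4)
    keep = blocks.abs().topk(2, dim=-1).indices
    mask = torch.zeros_like(blocks)
    mask.scatter_(-1, keep, 1.0)
    return mask.reshape(weight.shape)

w = torch.randn(8, 8)
mask = magnitude_mask_2_4(w)
# Every block of 4 keeps exactly 2 elements -> 50% sparsity
assert torch.all(mask.reshape(-1, 4).sum(dim=-1) == 2)
```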

sparsifier.step()
with torch.autocast("cuda"):
    with torch.no_grad():
        predictions = trainer.predict(tokenized_squad_dataset["validation"])
    pruned = compute_metrics(
        *predictions.predictions,
        tokenized_squad_dataset["validation"],
        squad_dataset["validation"],
    )
print("pruned eval metrics:", pruned)
100%|##########| 43/43 [00:25<00:00,  1.67it/s]
pruned eval metrics: {'exact_match': 30.047303689687794, 'f1': 41.41343319253715}

In this state, we can start fine-tuning the model, updating the elements that wouldn't be pruned to better compensate for the accuracy loss. Once we've reached a satisfactory state, we can call squash_mask to fuse the mask and the weight together. This will remove the parametrizations and we are left with a zeroed-out 2:4 dense model.
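After squash_mask, a quick sanity check is to verify that the now-dense weight really follows the 2:4 pattern. This is a sketch (assuming the pattern is applied over each row's contiguous groups of four), not part of the tutorial's API:

```python
import torch

def is_2_4_sparse(weight: torch.Tensor) -> bool:
    # True if every contiguous block of 4 elements holds at least 2 zeros.
    blocks = weight.reshape(-1, 4)
    return bool(((blocks == 0).sum(dim=-1) >= 2).all())

assert is_2_4_sparse(torch.tensor([[0., 0., 1., 2.], [3., 0., 0., 4.]]))
assert not is_2_4_sparse(torch.ones(2, 4))
```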

trainer.train()
sparsifier.squash_mask()
torch.set_printoptions(edgeitems=4)
print(model.bert.encoder.layer[0].intermediate.dense.weight[:8, :8])

df["sparse_loss"] = pd.DataFrame(trainer.state.log_history)["loss"]
df.plot.line(x='step', y=["loss", "sparse_loss"], title="Loss vs. # steps", ylabel="loss")
 10%|#         | 50/500 [00:22<03:22,  2.22it/s]

{'loss': 1.8834, 'grad_norm': 11.247416496276855, 'learning_rate': 5e-05, 'epoch': 0.02}

 20%|##        | 100/500 [00:45<03:00,  2.22it/s]

{'loss': 1.4046, 'grad_norm': 9.31086254119873, 'learning_rate': 5e-05, 'epoch': 0.04}

 30%|###       | 150/500 [01:07<02:37,  2.22it/s]

{'loss': 1.1801, 'grad_norm': 9.222478866577148, 'learning_rate': 5e-05, 'epoch': 0.05}

 40%|####      | 200/500 [01:30<02:15,  2.22it/s]

{'loss': 1.2227, 'grad_norm': 7.899211406707764, 'learning_rate': 5e-05, 'epoch': 0.07}

 40%|####      | 200/500 [01:30<02:15,  2.22it/s]
 40%|####      | 201/500 [01:30<02:14,  2.22it/s]
 40%|####      | 202/500 [01:31<02:14,  2.22it/s]
 41%|####      | 203/500 [01:31<02:13,  2.22it/s]
 41%|####      | 204/500 [01:31<02:13,  2.22it/s]
 41%|####1     | 205/500 [01:32<02:12,  2.22it/s]
 41%|####1     | 206/500 [01:32<02:12,  2.22it/s]
 41%|####1     | 207/500 [01:33<02:12,  2.22it/s]
 42%|####1     | 208/500 [01:33<02:11,  2.22it/s]
 42%|####1     | 209/500 [01:34<02:11,  2.22it/s]
 42%|####2     | 210/500 [01:34<02:10,  2.22it/s]
 42%|####2     | 211/500 [01:35<02:10,  2.22it/s]
 42%|####2     | 212/500 [01:35<02:09,  2.22it/s]
 43%|####2     | 213/500 [01:36<02:09,  2.22it/s]
 43%|####2     | 214/500 [01:36<02:08,  2.22it/s]
 43%|####3     | 215/500 [01:36<02:08,  2.22it/s]
 43%|####3     | 216/500 [01:37<02:07,  2.22it/s]
 43%|####3     | 217/500 [01:37<02:07,  2.22it/s]
 44%|####3     | 218/500 [01:38<02:07,  2.22it/s]
 44%|####3     | 219/500 [01:38<02:06,  2.22it/s]
 44%|####4     | 220/500 [01:39<02:06,  2.22it/s]
 44%|####4     | 221/500 [01:39<02:05,  2.22it/s]
 44%|####4     | 222/500 [01:40<02:05,  2.22it/s]
 45%|####4     | 223/500 [01:40<02:04,  2.22it/s]
 45%|####4     | 224/500 [01:40<02:04,  2.22it/s]
 45%|####5     | 225/500 [01:41<02:03,  2.22it/s]
 45%|####5     | 226/500 [01:41<02:03,  2.22it/s]
 45%|####5     | 227/500 [01:42<02:02,  2.22it/s]
 46%|####5     | 228/500 [01:42<02:02,  2.22it/s]
 46%|####5     | 229/500 [01:43<02:02,  2.22it/s]
 46%|####6     | 230/500 [01:43<02:01,  2.22it/s]
 46%|####6     | 231/500 [01:44<02:01,  2.22it/s]
 46%|####6     | 232/500 [01:44<02:00,  2.22it/s]
 47%|####6     | 233/500 [01:45<02:00,  2.22it/s]
 47%|####6     | 234/500 [01:45<01:59,  2.22it/s]
 47%|####6     | 235/500 [01:45<01:59,  2.22it/s]
 47%|####7     | 236/500 [01:46<01:58,  2.22it/s]
 47%|####7     | 237/500 [01:46<01:58,  2.22it/s]
 48%|####7     | 238/500 [01:47<01:58,  2.22it/s]
 48%|####7     | 239/500 [01:47<01:57,  2.22it/s]
 48%|####8     | 240/500 [01:48<01:57,  2.22it/s]
 48%|####8     | 241/500 [01:48<01:56,  2.22it/s]
 48%|####8     | 242/500 [01:49<01:56,  2.22it/s]
 49%|####8     | 243/500 [01:49<01:55,  2.22it/s]
 49%|####8     | 244/500 [01:49<01:55,  2.22it/s]
 49%|####9     | 245/500 [01:50<01:54,  2.22it/s]
 49%|####9     | 246/500 [01:50<01:54,  2.22it/s]
 49%|####9     | 247/500 [01:51<01:53,  2.22it/s]
 50%|####9     | 248/500 [01:51<01:53,  2.22it/s]
 50%|####9     | 249/500 [01:52<01:53,  2.22it/s]
 50%|#####     | 250/500 [01:52<01:52,  2.22it/s]

{'loss': 1.1098, 'grad_norm': 7.073291778564453, 'learning_rate': 5e-05, 'epoch': 0.09}

 50%|#####     | 250/500 [01:52<01:52,  2.22it/s]
 50%|#####     | 251/500 [01:53<01:52,  2.22it/s]
 50%|#####     | 252/500 [01:53<01:51,  2.22it/s]
 51%|#####     | 253/500 [01:54<01:51,  2.22it/s]
 51%|#####     | 254/500 [01:54<01:50,  2.22it/s]
 51%|#####1    | 255/500 [01:54<01:50,  2.22it/s]
 51%|#####1    | 256/500 [01:55<01:49,  2.22it/s]
 51%|#####1    | 257/500 [01:55<01:49,  2.22it/s]
 52%|#####1    | 258/500 [01:56<01:48,  2.22it/s]
 52%|#####1    | 259/500 [01:56<01:48,  2.22it/s]
 52%|#####2    | 260/500 [01:57<01:48,  2.22it/s]
 52%|#####2    | 261/500 [01:57<01:47,  2.22it/s]
 52%|#####2    | 262/500 [01:58<01:47,  2.22it/s]
 53%|#####2    | 263/500 [01:58<01:46,  2.22it/s]
 53%|#####2    | 264/500 [01:58<01:46,  2.22it/s]
 53%|#####3    | 265/500 [01:59<01:45,  2.22it/s]
 53%|#####3    | 266/500 [01:59<01:45,  2.22it/s]
 53%|#####3    | 267/500 [02:00<01:45,  2.22it/s]
 54%|#####3    | 268/500 [02:00<01:44,  2.22it/s]
 54%|#####3    | 269/500 [02:01<01:44,  2.22it/s]
 54%|#####4    | 270/500 [02:01<01:43,  2.22it/s]
 54%|#####4    | 271/500 [02:02<01:43,  2.22it/s]
 54%|#####4    | 272/500 [02:02<01:42,  2.22it/s]
 55%|#####4    | 273/500 [02:03<01:42,  2.22it/s]
 55%|#####4    | 274/500 [02:03<01:41,  2.22it/s]
 55%|#####5    | 275/500 [02:03<01:41,  2.22it/s]
 55%|#####5    | 276/500 [02:04<01:40,  2.22it/s]
 55%|#####5    | 277/500 [02:04<01:40,  2.22it/s]
 56%|#####5    | 278/500 [02:05<01:40,  2.22it/s]
 56%|#####5    | 279/500 [02:05<01:39,  2.22it/s]
 56%|#####6    | 280/500 [02:06<01:39,  2.22it/s]
 56%|#####6    | 281/500 [02:06<01:38,  2.22it/s]
 56%|#####6    | 282/500 [02:07<01:38,  2.22it/s]
 57%|#####6    | 283/500 [02:07<01:37,  2.22it/s]
 57%|#####6    | 284/500 [02:07<01:37,  2.22it/s]
 57%|#####6    | 285/500 [02:08<01:36,  2.22it/s]
 57%|#####7    | 286/500 [02:08<01:36,  2.22it/s]
 57%|#####7    | 287/500 [02:09<01:35,  2.22it/s]
 58%|#####7    | 288/500 [02:09<01:35,  2.22it/s]
 58%|#####7    | 289/500 [02:10<01:35,  2.22it/s]
 58%|#####8    | 290/500 [02:10<01:34,  2.22it/s]
 58%|#####8    | 291/500 [02:11<01:34,  2.22it/s]
 58%|#####8    | 292/500 [02:11<01:33,  2.22it/s]
 59%|#####8    | 293/500 [02:12<01:33,  2.22it/s]
 59%|#####8    | 294/500 [02:12<01:32,  2.22it/s]
 59%|#####8    | 295/500 [02:12<01:32,  2.22it/s]
 59%|#####9    | 296/500 [02:13<01:31,  2.22it/s]
 59%|#####9    | 297/500 [02:13<01:31,  2.22it/s]
 60%|#####9    | 298/500 [02:14<01:30,  2.22it/s]
 60%|#####9    | 299/500 [02:14<01:30,  2.22it/s]
 60%|######    | 300/500 [02:15<01:30,  2.22it/s]

{'loss': 1.1338, 'grad_norm': 8.028770446777344, 'learning_rate': 5e-05, 'epoch': 0.11}

 60%|######    | 300/500 [02:15<01:30,  2.22it/s]
 60%|######    | 301/500 [02:15<01:29,  2.22it/s]
 60%|######    | 302/500 [02:16<01:29,  2.22it/s]
 61%|######    | 303/500 [02:16<01:28,  2.22it/s]
 61%|######    | 304/500 [02:17<01:28,  2.22it/s]
 61%|######1   | 305/500 [02:17<01:27,  2.22it/s]
 61%|######1   | 306/500 [02:17<01:27,  2.22it/s]
 61%|######1   | 307/500 [02:18<01:26,  2.22it/s]
 62%|######1   | 308/500 [02:18<01:26,  2.22it/s]
 62%|######1   | 309/500 [02:19<01:26,  2.22it/s]
 62%|######2   | 310/500 [02:19<01:25,  2.22it/s]
 62%|######2   | 311/500 [02:20<01:25,  2.22it/s]
 62%|######2   | 312/500 [02:20<01:24,  2.22it/s]
 63%|######2   | 313/500 [02:21<01:24,  2.22it/s]
 63%|######2   | 314/500 [02:21<01:23,  2.22it/s]
 63%|######3   | 315/500 [02:21<01:23,  2.22it/s]
 63%|######3   | 316/500 [02:22<01:22,  2.22it/s]
 63%|######3   | 317/500 [02:22<01:22,  2.22it/s]
 64%|######3   | 318/500 [02:23<01:21,  2.22it/s]
 64%|######3   | 319/500 [02:23<01:21,  2.22it/s]
 64%|######4   | 320/500 [02:24<01:21,  2.22it/s]
 64%|######4   | 321/500 [02:24<01:20,  2.22it/s]
 64%|######4   | 322/500 [02:25<01:20,  2.22it/s]
 65%|######4   | 323/500 [02:25<01:19,  2.22it/s]
 65%|######4   | 324/500 [02:26<01:19,  2.22it/s]
 65%|######5   | 325/500 [02:26<01:18,  2.22it/s]
 65%|######5   | 326/500 [02:26<01:18,  2.22it/s]
 65%|######5   | 327/500 [02:27<01:17,  2.22it/s]
 66%|######5   | 328/500 [02:27<01:17,  2.22it/s]
 66%|######5   | 329/500 [02:28<01:17,  2.22it/s]
 66%|######6   | 330/500 [02:28<01:16,  2.22it/s]
 66%|######6   | 331/500 [02:29<01:16,  2.22it/s]
 66%|######6   | 332/500 [02:29<01:15,  2.22it/s]
 67%|######6   | 333/500 [02:30<01:15,  2.22it/s]
 67%|######6   | 334/500 [02:30<01:14,  2.22it/s]
 67%|######7   | 335/500 [02:30<01:14,  2.22it/s]
 67%|######7   | 336/500 [02:31<01:13,  2.22it/s]
 67%|######7   | 337/500 [02:31<01:13,  2.22it/s]
 68%|######7   | 338/500 [02:32<01:12,  2.22it/s]
 68%|######7   | 339/500 [02:32<01:12,  2.22it/s]
 68%|######8   | 340/500 [02:33<01:12,  2.22it/s]
 68%|######8   | 341/500 [02:33<01:11,  2.22it/s]
 68%|######8   | 342/500 [02:34<01:11,  2.22it/s]
 69%|######8   | 343/500 [02:34<01:10,  2.22it/s]
 69%|######8   | 344/500 [02:35<01:10,  2.22it/s]
 69%|######9   | 345/500 [02:35<01:09,  2.22it/s]
 69%|######9   | 346/500 [02:35<01:09,  2.22it/s]
 69%|######9   | 347/500 [02:36<01:08,  2.22it/s]
 70%|######9   | 348/500 [02:36<01:08,  2.22it/s]
 70%|######9   | 349/500 [02:37<01:08,  2.22it/s]
 70%|#######   | 350/500 [02:37<01:07,  2.22it/s]

{'loss': 1.1054, 'grad_norm': 8.8486909866333, 'learning_rate': 5e-05, 'epoch': 0.13}

 70%|#######   | 350/500 [02:37<01:07,  2.22it/s]
 70%|#######   | 351/500 [02:38<01:07,  2.22it/s]
 70%|#######   | 352/500 [02:38<01:06,  2.22it/s]
 71%|#######   | 353/500 [02:39<01:06,  2.22it/s]
 71%|#######   | 354/500 [02:39<01:05,  2.22it/s]
 71%|#######1  | 355/500 [02:39<01:05,  2.22it/s]
 71%|#######1  | 356/500 [02:40<01:04,  2.22it/s]
 71%|#######1  | 357/500 [02:40<01:04,  2.22it/s]
 72%|#######1  | 358/500 [02:41<01:03,  2.22it/s]
 72%|#######1  | 359/500 [02:41<01:03,  2.22it/s]
 72%|#######2  | 360/500 [02:42<01:03,  2.22it/s]
 72%|#######2  | 361/500 [02:42<01:02,  2.22it/s]
 72%|#######2  | 362/500 [02:43<01:02,  2.22it/s]
 73%|#######2  | 363/500 [02:43<01:01,  2.22it/s]
 73%|#######2  | 364/500 [02:44<01:01,  2.22it/s]
 73%|#######3  | 365/500 [02:44<01:00,  2.22it/s]
 73%|#######3  | 366/500 [02:44<01:00,  2.22it/s]
 73%|#######3  | 367/500 [02:45<00:59,  2.22it/s]
 74%|#######3  | 368/500 [02:45<00:59,  2.22it/s]
 74%|#######3  | 369/500 [02:46<00:59,  2.22it/s]
 74%|#######4  | 370/500 [02:46<00:58,  2.22it/s]
 74%|#######4  | 371/500 [02:47<00:58,  2.22it/s]
 74%|#######4  | 372/500 [02:47<00:57,  2.22it/s]
 75%|#######4  | 373/500 [02:48<00:57,  2.22it/s]
 75%|#######4  | 374/500 [02:48<00:56,  2.22it/s]
 75%|#######5  | 375/500 [02:48<00:56,  2.22it/s]
 75%|#######5  | 376/500 [02:49<00:55,  2.22it/s]
 75%|#######5  | 377/500 [02:49<00:55,  2.22it/s]
 76%|#######5  | 378/500 [02:50<00:54,  2.22it/s]
 76%|#######5  | 379/500 [02:50<00:54,  2.22it/s]
 76%|#######6  | 380/500 [02:51<00:54,  2.22it/s]
 76%|#######6  | 381/500 [02:51<00:53,  2.22it/s]
 76%|#######6  | 382/500 [02:52<00:53,  2.22it/s]
 77%|#######6  | 383/500 [02:52<00:52,  2.22it/s]
 77%|#######6  | 384/500 [02:53<00:52,  2.22it/s]
 77%|#######7  | 385/500 [02:53<00:51,  2.22it/s]
 77%|#######7  | 386/500 [02:53<00:51,  2.22it/s]
 77%|#######7  | 387/500 [02:54<00:50,  2.22it/s]
 78%|#######7  | 388/500 [02:54<00:50,  2.22it/s]
 78%|#######7  | 389/500 [02:55<00:50,  2.22it/s]
 78%|#######8  | 390/500 [02:55<00:49,  2.22it/s]
 78%|#######8  | 391/500 [02:56<00:49,  2.22it/s]
 78%|#######8  | 392/500 [02:56<00:48,  2.22it/s]
 79%|#######8  | 393/500 [02:57<00:48,  2.22it/s]
 79%|#######8  | 394/500 [02:57<00:47,  2.22it/s]
 79%|#######9  | 395/500 [02:58<00:47,  2.22it/s]
 79%|#######9  | 396/500 [02:58<00:46,  2.22it/s]
 79%|#######9  | 397/500 [02:58<00:46,  2.22it/s]
 80%|#######9  | 398/500 [02:59<00:45,  2.22it/s]
 80%|#######9  | 399/500 [02:59<00:45,  2.22it/s]
 80%|########  | 400/500 [03:00<00:45,  2.22it/s]

{'loss': 1.0149, 'grad_norm': 8.551692008972168, 'learning_rate': 5e-05, 'epoch': 0.15}

 80%|########  | 400/500 [03:00<00:45,  2.22it/s]
 80%|########  | 401/500 [03:00<00:44,  2.22it/s]
 80%|########  | 402/500 [03:01<00:44,  2.22it/s]
 81%|########  | 403/500 [03:01<00:43,  2.22it/s]
 81%|########  | 404/500 [03:02<00:43,  2.22it/s]
 81%|########1 | 405/500 [03:02<00:42,  2.22it/s]
 81%|########1 | 406/500 [03:02<00:42,  2.22it/s]
 81%|########1 | 407/500 [03:03<00:41,  2.22it/s]
 82%|########1 | 408/500 [03:03<00:41,  2.22it/s]
 82%|########1 | 409/500 [03:04<00:41,  2.22it/s]
 82%|########2 | 410/500 [03:04<00:40,  2.22it/s]
 82%|########2 | 411/500 [03:05<00:40,  2.22it/s]
 82%|########2 | 412/500 [03:05<00:39,  2.22it/s]
 83%|########2 | 413/500 [03:06<00:39,  2.22it/s]
 83%|########2 | 414/500 [03:06<00:38,  2.22it/s]
 83%|########2 | 415/500 [03:07<00:38,  2.22it/s]
 83%|########3 | 416/500 [03:07<00:37,  2.22it/s]
 83%|########3 | 417/500 [03:07<00:37,  2.22it/s]
 84%|########3 | 418/500 [03:08<00:36,  2.22it/s]
 84%|########3 | 419/500 [03:08<00:36,  2.22it/s]
 84%|########4 | 420/500 [03:09<00:36,  2.22it/s]
 84%|########4 | 421/500 [03:09<00:35,  2.22it/s]
 84%|########4 | 422/500 [03:10<00:35,  2.22it/s]
 85%|########4 | 423/500 [03:10<00:34,  2.22it/s]
 85%|########4 | 424/500 [03:11<00:34,  2.22it/s]
 85%|########5 | 425/500 [03:11<00:33,  2.22it/s]
 85%|########5 | 426/500 [03:11<00:33,  2.22it/s]
 85%|########5 | 427/500 [03:12<00:32,  2.22it/s]
 86%|########5 | 428/500 [03:12<00:32,  2.22it/s]
 86%|########5 | 429/500 [03:13<00:31,  2.22it/s]
 86%|########6 | 430/500 [03:13<00:31,  2.22it/s]
 86%|########6 | 431/500 [03:14<00:31,  2.22it/s]
 86%|########6 | 432/500 [03:14<00:30,  2.22it/s]
 87%|########6 | 433/500 [03:15<00:30,  2.22it/s]
 87%|########6 | 434/500 [03:15<00:29,  2.22it/s]
 87%|########7 | 435/500 [03:16<00:29,  2.22it/s]
 87%|########7 | 436/500 [03:16<00:28,  2.22it/s]
 87%|########7 | 437/500 [03:16<00:28,  2.22it/s]
 88%|########7 | 438/500 [03:17<00:27,  2.22it/s]
 88%|########7 | 439/500 [03:17<00:27,  2.22it/s]
 88%|########8 | 440/500 [03:18<00:27,  2.22it/s]
 88%|########8 | 441/500 [03:18<00:26,  2.22it/s]
 88%|########8 | 442/500 [03:19<00:26,  2.22it/s]
 89%|########8 | 443/500 [03:19<00:25,  2.22it/s]
 89%|########8 | 444/500 [03:20<00:25,  2.22it/s]
 89%|########9 | 445/500 [03:20<00:24,  2.22it/s]
 89%|########9 | 446/500 [03:20<00:24,  2.22it/s]
 89%|########9 | 447/500 [03:21<00:23,  2.22it/s]
 90%|########9 | 448/500 [03:21<00:23,  2.22it/s]
 90%|########9 | 449/500 [03:22<00:22,  2.22it/s]
 90%|######### | 450/500 [03:22<00:22,  2.22it/s]

{'loss': 0.9959, 'grad_norm': 8.505718231201172, 'learning_rate': 5e-05, 'epoch': 0.16}

 90%|######### | 450/500 [03:22<00:22,  2.22it/s]
 90%|######### | 451/500 [03:23<00:22,  2.22it/s]
 90%|######### | 452/500 [03:23<00:21,  2.22it/s]
 91%|######### | 453/500 [03:24<00:21,  2.22it/s]
 91%|######### | 454/500 [03:24<00:20,  2.22it/s]
 91%|#########1| 455/500 [03:25<00:20,  2.22it/s]
 91%|#########1| 456/500 [03:25<00:19,  2.22it/s]
 91%|#########1| 457/500 [03:25<00:19,  2.22it/s]
 92%|#########1| 458/500 [03:26<00:18,  2.22it/s]
 92%|#########1| 459/500 [03:26<00:18,  2.22it/s]
 92%|#########2| 460/500 [03:27<00:18,  2.22it/s]
 92%|#########2| 461/500 [03:27<00:17,  2.22it/s]
 92%|#########2| 462/500 [03:28<00:17,  2.22it/s]
 93%|#########2| 463/500 [03:28<00:16,  2.22it/s]
 93%|#########2| 464/500 [03:29<00:16,  2.22it/s]
 93%|#########3| 465/500 [03:29<00:15,  2.22it/s]
 93%|#########3| 466/500 [03:29<00:15,  2.22it/s]
 93%|#########3| 467/500 [03:30<00:14,  2.22it/s]
 94%|#########3| 468/500 [03:30<00:14,  2.22it/s]
 94%|#########3| 469/500 [03:31<00:13,  2.22it/s]
 94%|#########3| 470/500 [03:31<00:13,  2.22it/s]
 94%|#########4| 471/500 [03:32<00:13,  2.22it/s]
 94%|#########4| 472/500 [03:32<00:12,  2.22it/s]
 95%|#########4| 473/500 [03:33<00:12,  2.22it/s]
 95%|#########4| 474/500 [03:33<00:11,  2.22it/s]
 95%|#########5| 475/500 [03:34<00:11,  2.22it/s]
 95%|#########5| 476/500 [03:34<00:10,  2.22it/s]
 95%|#########5| 477/500 [03:34<00:10,  2.22it/s]
 96%|#########5| 478/500 [03:35<00:09,  2.22it/s]
 96%|#########5| 479/500 [03:35<00:09,  2.22it/s]
 96%|#########6| 480/500 [03:36<00:09,  2.22it/s]
 96%|#########6| 481/500 [03:36<00:08,  2.22it/s]
 96%|#########6| 482/500 [03:37<00:08,  2.22it/s]
 97%|#########6| 483/500 [03:37<00:07,  2.22it/s]
 97%|#########6| 484/500 [03:38<00:07,  2.22it/s]
 97%|#########7| 485/500 [03:38<00:06,  2.22it/s]
 97%|#########7| 486/500 [03:39<00:06,  2.22it/s]
 97%|#########7| 487/500 [03:39<00:05,  2.22it/s]
 98%|#########7| 488/500 [03:39<00:05,  2.22it/s]
 98%|#########7| 489/500 [03:40<00:04,  2.22it/s]
 98%|#########8| 490/500 [03:40<00:04,  2.22it/s]
 98%|#########8| 491/500 [03:41<00:04,  2.22it/s]
 98%|#########8| 492/500 [03:41<00:03,  2.22it/s]
 99%|#########8| 493/500 [03:42<00:03,  2.22it/s]
 99%|#########8| 494/500 [03:42<00:02,  2.22it/s]
 99%|#########9| 495/500 [03:43<00:02,  2.22it/s]
 99%|#########9| 496/500 [03:43<00:01,  2.22it/s]
 99%|#########9| 497/500 [03:43<00:01,  2.22it/s]
100%|#########9| 498/500 [03:44<00:00,  2.22it/s]
100%|#########9| 499/500 [03:44<00:00,  2.22it/s]
100%|##########| 500/500 [03:45<00:00,  2.22it/s]

{'loss': 1.0049, 'grad_norm': 10.322616577148438, 'learning_rate': 5e-05, 'epoch': 0.18}

100%|##########| 500/500 [03:45<00:00,  2.22it/s]

{'train_runtime': 234.7101, 'train_samples_per_second': 68.169, 'train_steps_per_second': 2.13, 'train_loss': 1.2055460739135742, 'epoch': 0.18}

100%|##########| 500/500 [03:54<00:00,  2.22it/s]
100%|##########| 500/500 [03:54<00:00,  2.13it/s]
tensor([[ 0.0000, -0.0175,  0.0000,  0.0106,  0.0000,  0.0312,  0.0000, -0.0742],
        [ 0.0407, -0.0000, -0.0000,  0.0499,  0.0354, -0.0000, -0.0000, -0.0178],
        [-0.0335, -0.0328,  0.0000,  0.0000, -0.0000, -0.0000,  0.0572, -0.0189],
        [ 0.0000, -0.0000, -0.0499,  0.0291,  0.0966,  0.0000, -0.0000, -0.0766],
        [ 0.0089, -0.0000,  0.0000, -0.0400, -0.0000,  0.0188,  0.0658, -0.0000],
        [-0.0000, -0.0313,  0.0000,  0.0316,  0.0400, -0.0178,  0.0000, -0.0000],
        [-0.1075, -0.0287,  0.0000, -0.0000,  0.0424,  0.0655, -0.0000, -0.0000],
        [-0.0000,  0.0300, -0.0652, -0.0000,  0.0311, -0.0000,  0.0000, -0.0595]],
       device='cuda:0', grad_fn=<SliceBackward0>)
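
The weight slice printed above illustrates the 2:4 pattern: along each row, every contiguous group of four elements contains exactly two zeros. A minimal check of this invariant, using the first printed row (values copied from the output above), might look like:

```python
import torch

# First row of the pruned weight slice printed above
row = torch.tensor([0.0000, -0.0175, 0.0000, 0.0106,
                    0.0000, 0.0312, 0.0000, -0.0742])

# Split the row into contiguous groups of 4 and count zeros per group;
# 2:4 semi-structured sparsity requires exactly 2 zeros in each group of 4.
groups = row.reshape(-1, 4)
zeros_per_group = (groups == 0).sum(dim=-1)
assert torch.all(zeros_per_group == 2)
```

This per-group invariant is what allows the GPU to store only the two surviving values (plus a small index) per group and skip the pruned ones entirely.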

<Axes: title={'center': 'Loss vs. # steps'}, xlabel='step', ylabel='loss'>

Accelerating 2:4 sparse models for inference

Now that we have a model in this format, we can accelerate it for inference just like in the QuickStart guide.

model = model.cuda().half()
# accelerate for sparsity
for fqn, module in model.named_modules():
    if isinstance(module, nn.Linear) and "layer" in fqn:
        module.weight = nn.Parameter(to_sparse_semi_structured(module.weight))

with torch.no_grad():
    predictions = trainer.predict(tokenized_squad_dataset["validation"])
start_logits, end_logits = predictions.predictions
metrics_sparse = compute_metrics(
    start_logits,
    end_logits,
    tokenized_squad_dataset["validation"],
    squad_dataset["validation"],
)
print("sparse eval metrics: ", metrics_sparse)
sparse_perf = measure_execution_time(
    model,
    batch_sizes,
    tokenized_squad_dataset["validation"],
)
print("sparse perf metrics: ", sparse_perf)
100%|##########| 43/43 [00:55<00:00,  1.29s/it]
sparse eval metrics:  {'exact_match': 71.27719962157049, 'f1': 80.40876597032333}
sparse perf metrics:  {4: 16.22369025001262, '4_compile': 9.579042000041227, 16: 60.472320200005925, '16_compile': 34.016872900019735, 64: 238.72479699957694, '64_compile': 134.9668625000504, 256: 1182.9201580003428, '256_compile': 550.6339859998661}

Retraining our model after magnitude pruning has recovered nearly all of the F1 score that was lost when the model was pruned. At the same time, we have achieved a 1.28x speedup at bs=16. Note that not all shapes are amenable to performance improvements: when batch sizes are small and only limited compute time is spent in the sparse kernels, they may be slower than their dense counterparts.

Because semi-structured sparsity is implemented as a tensor subclass, it is compatible with torch.compile. When combined with to_sparse_semi_structured, this lets us achieve a total 2x speedup on BERT.
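
To sketch why the tensor-subclass design composes with torch.compile: a module whose weight is a subclass tensor is compiled with exactly the same API call as a dense one. The example below is a CPU-runnable illustration using a plain dense weight and the eager backend (to_sparse_semi_structured itself requires a CUDA GPU with compute capability 8.0+); on a supported GPU you would wrap the weight first and then call torch.compile the same way:

```python
import torch

linear = torch.nn.Linear(128, 64).eval()

# On an Ampere+ GPU you would first replace the weight, e.g.:
#   linear.weight = torch.nn.Parameter(to_sparse_semi_structured(linear.weight))
# torch.compile is invoked identically either way, because
# SparseSemiStructuredTensor is a tensor subclass.
compiled_linear = torch.compile(linear, backend="eager")

x = torch.rand(8, 128)
with torch.no_grad():
    out = compiled_linear(x)
    ref = linear(x)

# Compilation does not change the numerics
assert torch.allclose(out, ref, atol=1e-6)
```

The `backend="eager"` argument is used here only so the sketch runs without a C++ toolchain; in practice you would use the default inductor backend.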

| Metrics          | fp16   | 2:4 sparse | delta / speedup | compiled |
|------------------|--------|------------|-----------------|----------|
| Exact Match (%)  | 78.53  | 78.44      | -0.09           |          |
| F1 (%)           | 86.93  | 86.49      | -0.44           |          |
| Time (bs=4)      | 11.10  | 15.54      | 0.71x           | no       |
| Time (bs=16)     | 19.35  | 15.74      | 1.23x           | no       |
| Time (bs=64)     | 72.71  | 59.41      | 1.22x           | no       |
| Time (bs=256)    | 286.65 | 247.63     | 1.14x           | no       |
| Time (bs=4)      | 7.59   | 7.46       | 1.02x           | yes      |
| Time (bs=16)     | 11.47  | 9.68       | 1.18x           | yes      |
| Time (bs=64)     | 41.57  | 36.92      | 1.13x           | yes      |
| Time (bs=256)    | 159.22 | 142.23     | 1.12x           | yes      |

Conclusion

In this tutorial, we have shown how to prune BERT to be 2:4 sparse and how to accelerate a 2:4 sparse model for inference. By taking advantage of our SparseSemiStructuredTensor subclass, we were able to achieve a 1.3x speedup over the fp16 baseline, and up to a 2x speedup with torch.compile. We also demonstrated the benefits of 2:4 sparsity by fine-tuning BERT to recover nearly all of the lost F1 (86.92 dense vs. 86.48 sparse).

Total running time of the script: (14 minutes 24.477 seconds)

Gallery generated by Sphinx-Gallery
