模型描述

来自 Transformer 的双向编码器表示(即 BERT)是一种革命性的自监督预训练技术,通过学习预测文本中有意隐藏(掩码)的部分来训练。关键在于,BERT 学习到的表示已被证明可以很好地泛化到下游任务,并且当 BERT 于 2018 年首次发布时,它在许多 NLP 基准数据集上取得了当时最先进的成果。

RoBERTa 构建于 BERT 的语言掩码策略之上,并修改了 BERT 的关键超参数,包括移除 BERT 的下一句预训练目标,以及使用更大的 mini-batch 和学习率进行训练。RoBERTa 的训练数据量也比 BERT 大一个数量级,训练时间更长。这使得 RoBERTa 的表示相比 BERT 能够更好地泛化到下游任务。

要求

预处理需要一些额外的 Python 依赖项

pip install regex requests hydra-core omegaconf

示例

加载 RoBERTa
import torch
roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
roberta.eval()  # disable dropout (or leave in train mode to finetune)
对输入文本应用字节对编码 (BPE)
tokens = roberta.encode('Hello world!')
assert tokens.tolist() == [0, 31414, 232, 328, 2]
assert roberta.decode(tokens) == 'Hello world!'
从 RoBERTa 中提取特征
# Extract the last layer's features
last_layer_features = roberta.extract_features(tokens)
assert last_layer_features.size() == torch.Size([1, 5, 1024])

# Extract all layer's features (layer 0 is the embedding layer)
all_layers = roberta.extract_features(tokens, return_all_hiddens=True)
assert len(all_layers) == 25
assert torch.all(all_layers[-1] == last_layer_features)
使用 RoBERTa 进行句子对分类任务
# Download RoBERTa already finetuned for MNLI
roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.mnli')
roberta.eval()  # disable dropout for evaluation

with torch.no_grad():
    # Encode a pair of sentences and make a prediction
    tokens = roberta.encode('Roberta is a heavily optimized version of BERT.', 'Roberta is not very optimized.')
    prediction = roberta.predict('mnli', tokens).argmax().item()
    assert prediction == 0  # contradiction

    # Encode another pair of sentences
    tokens = roberta.encode('Roberta is a heavily optimized version of BERT.', 'Roberta is based on BERT.')
    prediction = roberta.predict('mnli', tokens).argmax().item()
    assert prediction == 2  # entailment
注册一个新的(随机初始化的)分类头
roberta.register_classification_head('new_task', num_classes=3)
logprobs = roberta.predict('new_task', tokens)  # tensor([[-1.1050, -1.0672, -1.1245]], grad_fn=<LogSoftmaxBackward>)

参考文献