• 教程 >
  • (测试版) 在 LSTM 词语言模型上进行动态量化
快捷方式

(Beta) 在 LSTM 词语言模型上进行动态量化

创建于: 2019 年 10 月 07 日 | 最后更新于: 2024 年 8 月 27 日 | 最后验证于: 2024 年 11 月 05 日

作者: James Reed

编辑: Seth Weidman

引言

量化涉及将模型的权重和激活从浮点数转换为整数,这可以减小模型大小并加快推理速度,同时只对精度造成少量影响。

在本教程中,我们将把最简单的量化形式——动态量化——应用于一个基于 LSTM 的下一个词预测模型,整体流程紧密参照 PyTorch 示例中的词语言模型。

# imports
import os
from io import open
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

1. 定义模型

这里我们定义 LSTM 模型架构,遵循词语言模型示例中的模型

class LSTMModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder.

    The pipeline is: token embedding -> dropout -> stacked LSTM -> dropout ->
    linear projection back onto the vocabulary.

    Args:
        ntoken: Vocabulary size (rows of the embedding, outputs of the decoder).
        ninp: Embedding dimension fed into the LSTM.
        nhid: Number of hidden units in each LSTM layer.
        nlayers: Number of stacked LSTM layers.
        dropout: Dropout probability, used on the embeddings, between LSTM
            layers, and on the LSTM output.
    """

    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super().__init__()
        # The attribute names below become the state-dict key prefixes, so
        # they must stay stable for saved weights to keep loading.
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

        # Set before ``init_weights`` so the instance is fully described by
        # the time any initialisation hook runs.
        self.nhid = nhid
        self.nlayers = nlayers

        self.init_weights()

    def init_weights(self):
        """Re-initialise the embedding and decoder parameters in place.

        Embedding and decoder weights are drawn from U(-0.1, 0.1); the decoder
        bias is zeroed. The LSTM keeps PyTorch's default initialisation.
        """
        initrange = 0.1
        # ``nn.init`` runs under ``no_grad`` and avoids the discouraged
        # ``.data`` attribute; the call order (and thus the RNG stream) is
        # the same as before.
        nn.init.uniform_(self.encoder.weight, -initrange, initrange)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def forward(self, input, hidden):
        """Run one chunk of token ids through the model.

        Args:
            input: Integer tensor of token ids. The LSTM is built with the
                default ``batch_first=False``, so the layout is
                ``(seq_len, batch)``.
            hidden: ``(h, c)`` tuple as produced by :meth:`init_hidden` or by
                a previous call.

        Returns:
            ``(decoded, hidden)`` where ``decoded`` holds unnormalised scores
            over the vocabulary for every position and ``hidden`` is the
            updated ``(h, c)`` tuple.
        """
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(output)
        return decoded, hidden

    def init_hidden(self, bsz):
        """Return a zeroed ``(h, c)`` state for a batch of size ``bsz``.

        Both tensors have shape ``(nlayers, bsz, nhid)`` and inherit dtype and
        device from the model's first parameter.
        """
        weight = next(self.parameters())
        return (weight.new_zeros(self.nlayers, bsz, self.nhid),
                weight.new_zeros(self.nlayers, bsz, self.nhid))

2. 加载文本数据

接下来,我们将 Wikitext-2 数据集加载到一个 Corpus 中,同样遵循词语言模型示例中的预处理

class Dictionary(object):
    """Two-way vocabulary mapping words to integer ids and back.

    Ids are handed out consecutively from 0 in the order words are first seen.
    """

    def __init__(self):
        self.idx2word = []  # position -> word
        self.word2idx = {}  # word -> position in ``idx2word``

    def add_word(self, word):
        """Register ``word`` if it is new and return its integer id."""
        idx = self.word2idx.get(word)
        if idx is None:
            # First occurrence: the next free id is the current list length.
            idx = len(self.idx2word)
            self.idx2word.append(word)
            self.word2idx[word] = idx
        return idx

    def __len__(self):
        """Return the number of distinct words registered so far."""
        return len(self.idx2word)


class Corpus(object):
    """Tokenized train/valid/test splits that share one vocabulary.

    Args:
        path: Directory containing ``train.txt``, ``valid.txt`` and
            ``test.txt``.

    Attributes set by the constructor: ``dictionary`` (the shared vocabulary)
    and ``train`` / ``valid`` / ``test`` (1-D ``int64`` tensors of token ids).
    """

    def __init__(self, path):
        self.dictionary = Dictionary()
        # Order matters: ids are assigned on first sight, so train.txt
        # populates the low ids, then valid.txt, then test.txt.
        self.train = self.tokenize(os.path.join(path, 'train.txt'))
        self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
        self.test = self.tokenize(os.path.join(path, 'test.txt'))

    def tokenize(self, path):
        """Tokenize a text file into a flat tensor of word ids.

        Each line is split on whitespace and terminated with an ``'<eos>'``
        token. Unseen words are added to ``self.dictionary`` as they occur.

        Args:
            path: Path of a UTF-8 encoded text file.

        Returns:
            A 1-D ``torch.int64`` tensor with one id per token, in file order.
            An empty file yields an empty tensor.

        Raises:
            FileNotFoundError: If ``path`` does not exist (raised by ``open``;
                this does not depend on ``assert`` being enabled).
        """
        ids = []
        # Single pass: ``add_word`` returns the id for both new and known
        # words, so the vocabulary is built and the tokens encoded together.
        # Ids come out identical to a build-then-encode double read.
        with open(path, 'r', encoding="utf8") as f:
            for line in f:
                ids.extend(
                    self.dictionary.add_word(word)
                    for word in line.split() + ['<eos>']
                )

        # An explicit dtype keeps the result int64 even when ``ids`` is empty.
        return torch.tensor(ids, dtype=torch.int64)

# Root directory for the dataset, the pretrained checkpoint and generated
# output. The trailing slash is required: paths are built by concatenation.
model_data_filepath = 'data/'

# Tokenize the Wikitext-2 train/valid/test splits and build the vocabulary.
corpus = Corpus(f'{model_data_filepath}wikitext-2')

3. 加载预训练模型

这是一个关于动态量化的教程,动态量化是一种在模型训练后应用的量化技术。因此,我们将简单地将一些预训练权重加载到这个模型架构中;这些权重是通过使用词语言模型示例中的默认设置训练五个 epoch 获得的。

# Vocabulary size determines the embedding and decoder dimensions.
ntokens = len(corpus.dictionary)

# These hyperparameters presumably mirror the ones the checkpoint was trained
# with; ``load_state_dict`` rejects tensors whose shapes do not match.
model = LSTMModel(ntoken=ntokens, ninp=512, nhid=256, nlayers=5)

model.load_state_dict(
    torch.load(
        f'{model_data_filepath}word_language_model_quantize.pth',
        map_location=torch.device('cpu'),  # load onto CPU regardless of origin
        weights_only=True,  # restrict unpickling to tensors/plain containers
    )
)

# Inference mode: disables dropout.
model.eval()
print(model)
LSTMModel(
  (drop): Dropout(p=0.5, inplace=False)
  (encoder): Embedding(33278, 512)
  (rnn): LSTM(512, 256, num_layers=5, dropout=0.5)
  (decoder): Linear(in_features=256, out_features=33278, bias=True)
)

现在让我们生成一些文本,以确保预训练模型正常工作 - 与之前类似,我们遵循此处的步骤

# Seed generation with one random token id, shaped (seq_len=1, batch=1).
input_ = torch.randint(ntokens, (1, 1), dtype=torch.long)
hidden = model.init_hidden(1)
temperature = 1.0  # values above 1 flatten the sampling distribution
num_words = 1000

with open(model_data_filepath + 'out.txt', 'w') as outf, torch.no_grad():
    # ``no_grad``: no autograd history is needed while sampling.
    for i in range(num_words):
        output, hidden = model(input_, hidden)
        # exp(score / temperature) gives unnormalised sampling weights;
        # ``torch.multinomial`` does not need them to sum to one.
        word_weights = output.squeeze().div(temperature).exp().cpu()
        word_idx = torch.multinomial(word_weights, 1)[0]
        # Feed the sampled token back in as the next input.
        input_.fill_(word_idx)

        word = corpus.dictionary.idx2word[word_idx]

        # NOTE: this writes the repr of the UTF-8 bytes (e.g. b'and'), kept
        # as-is to match the sample output. Twenty words per line.
        outf.write(str(word.encode('utf-8')) + ('\n' if i % 20 == 19 else ' '))

        if i % 100 == 0:
            # Report against ``num_words`` rather than a hard-coded total so
            # the message stays correct if the length is changed.
            print('| Generated {}/{} words'.format(i, num_words))

with open(model_data_filepath + 'out.txt', 'r') as outf:
    all_output = outf.read()
    print(all_output)
| Generated 0/1000 words
| Generated 100/1000 words
| Generated 200/1000 words
| Generated 300/1000 words
| Generated 400/1000 words
| Generated 500/1000 words
| Generated 600/1000 words
| Generated 700/1000 words
| Generated 800/1000 words
| Generated 900/1000 words
b'and' b'Dulce' b"'s" b'<unk>' b'Typically' b',' b'where' b'they' b'were' b'Montagne' b'and' b'plumes' b'star' b'the' b'same' b'.' b'He' b'came' b'off' b'to'
b'be' b'chosen' b'to' b'arrive' b'former' b'racquets' b',' b'citing' b'just' b'advantage' b':' b'"' b'ESA' b"'ll" b'<unk>' b'"' b'.' b'The' b'game' b'developed'
b'among' b'favored' b'a' b'low' b'party' b'when' b'this' b'is' b'forced' b'to' b'change' b'up' b'when' b'it' b'was' b'a' b'batsman' b'man' b'.' b'Robbie'
b'swears' b'this' b'explains' b'how' b'it' b'might' b'be' b'heard' b'as' b'this' b'peerage' b'shorter' b'go' b',' b'to' b'be' b'a' b'rubber' b'planet' b'"'
b'low' b'"' b'or' b'committed' b'fleets' b'on' b'the' b'other' b'down' b'of' b'Australia' b'.' b'inoffensive' b'components' b'have' b'them' b'work' b',' b'while' b'the'
b'gameplay' b'bodied' b'(' b'not' b'satisfy' b',' b'Loved' b'Beaufort' b'<unk>' b')' b'.' b'The' b'true' b'differences' b'with' b'a' b'positive' b'coupled' b'with' b'beef'
b'motherboard' b',' b'using' b'exhibiting' b'a' b'real' b'decolonisation' b'above' b'all' b'@-@' b'linear' b'years' b'her' b'body' b'underground' b'.' b'Your' b'conversations' b'also' b'observation'
b'links' b'from' b'fate' b'a' b'fugitive' b'that' b'uses' b'it' b'.' b'"' b'powerful' b'three' b'distinct' b'or' b'doubts' b'have' b'happened' b'in' b'so' b'@-@'
b'adult' b'volume' b',' b'surgery' b'in' b'best' b'his' b'arms' b',' b'and' b'your' b'head' b'evoked' b'relatively' b'trapped' b'.' b'"' b'Norse' b'also' b'habitat'
b'finding' b',' b'and' b'thus' b'too' b'seldom' b'considered' b'they' b'know' b'will' b'be' b'like' b'.' b'In' b'this' b'same' b'year' b',' b'it' b'was'
b'not' b'understood' b'to' b'speak' b'down' b'second' b'manuscripts' b'and' b'the' b'low' b'@-@' b'sins' b'Newly' b'video' b'pavement' b',' b'mentioned' b'by' b'<unk>' b'and'
b'<unk>' b'.' b'While' b'abrupt' b',' b'monospaced' b',' b'Nicolas' b'@-@' b"'Connor" b',' b'and' b'are' b'Smile' b',' b'suggested' b'there' b'are' b'existed' b'at'
b'3' b'<unk>' b',' b'but' b'it' b'is' b'unclear' b'the' b'advertising' b'forces' b'of' b'the' b'1500' b'@-@' b'planet' b'was' b'recast' b'.' b'Males' b'suggests'
b'the' b'claim' b'for' b'a' b'courtship' b'@-@' b'shaped' b'colony' b'(' b'this' b'method' b'as' b'Pennsylvania' b'is' b'F\xc3\xb6rster' b')' b'.' b'By' b'18' b'\xc2\xb0'
b'nickel' b'(' b'1101' b')' b',' b'or' b'COs' b',' b'the' b'player' b'were' b'potentially' b'supposed' b'to' b'be' b'she' b'could' b'be' b'said' b'for'
b'things' b'and' b'that' b'the' b'bright' b'evidence' b'takes' b'place' b'by' b'Maryang' b'Punk' b'.' b'<unk>' b'XVI' b'Woodfull' b'of' b'Motion' b',' b'a' b'intended'
b'phenomenon' b'of' b'small' b'acts' b'of' b'clue' b',' b'came' b'for' b'theatrical' b',' b'becomes' b'about' b'to' b'meet' b'energy' b'@-@' b'hand' b'.' b'He'
b'can' b'have' b'upper' b'Astraeus' b'affirmed' b'copulation' b',' b'and' b'fall' b'not' b'alone' b'from' b'condoms' b'into' b'the' b'humans' b'.' b'Increased' b'starlings' b'were'
b'the' b'most' b'recent' b'sangatya' b'pressing' b'.' b'saprotrophic' b'calls' b'may' b'be' b'<unk>' b',' b'meaning' b'using' b'"' b'Patrick' b'"' b'dust' b',' b'<unk>'
b',' b'and' b'their' b'virus' b'.' b'Other' b'women' b'in' b'poor' b'reduction' b'can' b'be' b'burned' b'by' b'radiation' b'inside' b'notes' b'.' b'In' b'the'
b'common' b'fields' b'the' b'human' b'reading' b'is' b'extinct' b'.' b'<eos>' b'Some' b'of' b'the' b'same' b'batteries' b'have' b'disorder' b',' b'as' b'too' b'remains'
b'strum' b'is' b'due' b'to' b'leaf' b',' b'and' b'recorded' b'ocean' b',' b'after' b'it' b'is' b'more' b'difficult' b'.' b'In' b'a' b'large' b'manner'
b'greatly' b',' b'according' b'to' b'other' b'dissent' b',' b'they' b'may' b'be' b'misused' b'against' b'word' b'destruction' b'.' b'registered' b'it' b',' b'the' b'<unk>'
b'nucleosynthesis' b'standards' b'or' b'corporations' b'.' b'showing' b'some' b'other' b',' b'at' b'least' b'five' b'species' b'the' b'power' b'it' b'might' b'be' b'centric' b'with'
b'choosing' b'leaves' b'.' b'Subsequent' b'people' b'have' b'syncopated' b'their' b'anterior' b'broadside' b',' b'places' b'above' b'over' b'every' b'share' b',' b'and' b'their' b'shoulder'
b'can' b'lose' b',' b'foot' b'on' b'the' b'way' b'room' b'.' b'Even' b'males' b',' b'instructs' b'grain' b'.' b'<eos>' b'yellow' b'insects' b'are' b'taken'
b'.' b'Indeed' b',' b'thus' b'predators' b'are' b'more' b'far' b',' b'rather' b'than' b'their' b'subject' b'to' b'their' b'experiments' b'.' b'Pupils' b'are' b'well'
b'totally' b',' b'with' b'adults' b'which' b'have' b'alternatively' b'conditions' b'or' b',' b'their' b'food' b'return' b'advantage' b'broods' b'.' b'At' b'its' b'peak' b'that'
b'their' b'jumps' b'were' b'obtained' b'and' b'reflected' b'other' b'eggs' b'.' b'It' b'also' b'feed' b'to' b'topological' b'acids' b'of' b'275' b'\xe2\x80\x93' b'syphilis' b'.'
b'<eos>' b'Common' b'starlings' b'are' b'fat' b'throughout' b'wake' b'.' b'They' b'lineage' b'found' b'goats' b'with' b'often' b'are' b'able' b'to' b'forage' b'easy' b'.'
b'They' b'have' b'limited' b'Neversoft' b'in' b'Milburn' b'that' b'could' b'be' b'eastwards' b'.' b'Eurogamer' b'and' b'<unk>' b'hyaline' b'(' b'M.' b'Ponderosa' b'methods' b')'
b':' b'an' b'diet' b'to' b'technique' b'.' b'Charles' b'<unk>' b'3' b'marry' b'Seward' b'as' b'a' b'kakapo' b'when' b'allows' b'back' b'food' b'name' b'of'
b'spectators' b'.' b'As' b'this' b',' b'it' b'can' b'be' b'prolonged' b'by' b'other' b'Kaplan' b'rRNA' b',' b'any' b'generation' b'of' b'central' b'commonly' b'altars'
b'or' b'<unk>' b',' b'while' b'is' b'pot' b'in' b'the' b'Augustan' b'Republic' b'tests' b'.' b'A' b'beryllium' b',' b'priorities' b'based' b'with' b'Arne' b'bernissartensis'
b'in' b'Lena' b',' b'the' b'<unk>' b'limb' b'into' b'South' b'Africa' b',' b'and' b'is' b'that' b'describe' b'on' b'4' b'July' b'when' b'41st' b','
b'though' b'he' b'overrun' b'"' b'whose' b'clients' b'that' b'their' b'hold' b'on' b'a' b'outer' b'eye' b'"' b'.' b'The' b'male' b'of' b'ghat' b'leaving'
b'her' b'endings' b';' b'but' b',' b'in' b'1851' b',' b'faces' b'with' b'all' b'female' b'Landspace' b'needed' b'to' b'be' b'reliable' b'by' b'clusters' b'when'
b'they' b'cannot' b'be' b'accidental' b'.' b'This' b'picture' b'may' b'be' b'well' b'<unk>' b'or' b'is' b'observed' b',' b'with' b'so' b'leave' b',' b'even'
b'on' b'either' b'they' b'Banai' b'"' b'try' b'to' b'sustain' b'their' b'larvae' b'.' b'"' b'<eos>' b'<eos>' b'=' b'=' b'inquiries' b'=' b'=' b'<eos>'
b'<eos>' b'The' b'largest' b'starling' b'appears' b'to' b'be' b'large' b'.' b'Like' b'Sil' b'changed' b'(' b'Gerald' b'Randall' b')' b'do' b'not' b'call' b'by'
b'begin' b'nuclei' b',' b'the' b'other' b'kakapo' b'seem' b'to' b'guarantee' b'symptoms' b'.' b'By' b'naturally' b'when' b'22' b'extraction' b'appear' b'to' b'be' b'invalid'
b',' b'it' b'may' b'be' b'more' b'divided' b'in' b'quite' b'flock' b'too' b'.' b'<unk>' b'tends' b'to' b'penetrate' b',' b'if' b'in' b'keeping' b','
b'they' b'should' b'often' b'feed' b'as' b'it' b'has' b'their' b'brood' b'.' b'If' b'immediate' b'so' b'they' b'develop' b'old' b',' b'they' b'show' b'some'
b'<unk>' b'hitherto' b'or' b'without' b'cream' b'or' b'proving' b'degraded' b'respect' b'due' b'to' b'areas' b'that' b'they' b'came' b'into' b'any' b'common' b',' b'foliage'
b'or' b'consecrate' b'better' b'trusting' b'\xe2\x80\x94' b'showing' b'the' b'nest' b',' b'so' b'they' b'have' b'supported' b'them' b',' b'neither' b'apart' b'to' b'hair' b'back'
b'them' b',' b'although' b'they' b',' b'<unk>' b'<unk>' b'or' b'rDNA' b',' b'<unk>' b',' b'enters' b'Love' b'only' b'purely' b'and' b'government' b'or' b'white'
b'.' b'The' b'amount' b'of' b'Gaelic' b',' b'frequency' b',' b'at' b'a' b'matter' b'of' b'small' b'feeding' b'bird' b',' b'can' b'be' b'introduced' b'to'
b'distinguish' b'common' b'females' b'throughout' b'the' b'mountains' b'before' b'they' b'are' b',' b'when' b'the' b'breeding' b'Magdalen' b'looks' b'in' b'van' b'Boom' b'Blood' b'clear'
b'because' b'if' b'they' b'must' b'be' b'connected' b'making' b'them' b'to' b'some' b'other' b'Dubliners' b'that' b'are' b'attested' b'by' b'From' b'human' b'infections' b'.'
b'As' b'the' b'blessed' b'Wadsworth' b'@-@' b'cede' b'colour' b',' b'it' b'was' b'two' b'smaller' b'spots' b',' b'particularly' b'gains' b'during' b'them' b'or' b'vulnerable'

它不是 GPT-2,但看起来模型已经开始学习语言的结构了!

我们几乎准备好演示动态量化了。我们只需要再定义几个辅助函数

# Maximum number of time steps per evaluation chunk; ``get_batch`` caps the
# slice length at this value and ``evaluate`` strides by it.
bptt = 25
# Batch dimension used when batchifying the test set and sizing the hidden state.
eval_batch_size = 1
# Cross-entropy between decoder scores and target ids (library default: mean
# reduction).
criterion = nn.CrossEntropyLoss()

# create test data set
def batchify(data, bsz):
    """Arrange a flat token stream into ``bsz`` parallel columns.

    Trailing elements that do not fill a complete row across all columns are
    dropped. The result has shape ``(len(data) // bsz, bsz)``; each column is
    a contiguous run of the original sequence.
    """
    # Largest prefix length that divides evenly into ``bsz`` parts.
    usable = (data.size(0) // bsz) * bsz
    trimmed = data[:usable]
    # Rows of the (bsz, -1) view are contiguous runs; transposing turns them
    # into columns so that dimension 0 indexes time.
    columns = trimmed.view(bsz, -1).t()
    return columns.contiguous()

# Lay the test split out as (num_steps, eval_batch_size) for ``evaluate``.
test_data = batchify(corpus.test, eval_batch_size)

# Evaluation functions
def get_batch(source, i):
    """Slice one evaluation chunk starting at row ``i``.

    Returns ``(data, target)``: ``data`` is up to ``bptt`` rows of ``source``
    starting at ``i``, and ``target`` is the same window shifted one row
    forward and flattened to 1-D. The window is shortened near the end so
    that every input row still has a following target row.
    """
    rows_left = len(source) - 1 - i  # last row can only serve as a target
    span = min(bptt, rows_left)
    stop = i + span
    inputs = source[i:stop]
    labels = source[i + 1:stop + 1].reshape(-1)
    return inputs, labels

def repackage_hidden(h):
    """Detach hidden state from its autograd history.

    Accepts a tensor or an arbitrarily nested iterable of tensors; iterables
    come back as tuples with every tensor detached.
    """
    if not isinstance(h, torch.Tensor):
        # Recurse into containers such as the LSTM's (h, c) pair.
        return tuple(repackage_hidden(item) for item in h)
    return h.detach()

def evaluate(model_, data_source):
    """Compute the average loss of ``model_`` over ``data_source``.

    The data is processed in consecutive chunks of at most ``bptt`` rows, with
    the hidden state carried from chunk to chunk. Each chunk's loss is
    weighted by its number of rows, and the sum is divided by
    ``len(data_source) - 1`` (the number of rows that have a target).
    """
    # Evaluation mode disables dropout.
    model_.eval()
    hidden = model_.init_hidden(eval_batch_size)
    n_target_rows = len(data_source) - 1
    loss_sum = 0.
    with torch.no_grad():
        for start in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, start)
            output, hidden = model_(data, hidden)
            # Carry the state forward without its graph.
            hidden = repackage_hidden(hidden)
            # Flatten to (rows * batch, vocab) to line up with the 1-D targets.
            chunk_loss = criterion(output.view(-1, ntokens), targets)
            loss_sum += len(data) * chunk_loss.item()
    return loss_sum / n_target_rows

4. 测试动态量化

最后,我们可以在模型上调用 torch.quantization.quantize_dynamic!具体来说,

  • 我们指定希望对模型中的 nn.LSTM 和 nn.Linear 模块进行量化

  • 我们指定希望将权重转换为 int8

import torch.quantization

# Replace every nn.LSTM and nn.Linear submodule with its dynamically quantized
# counterpart using int8 weights. Module types not listed here (such as the
# embedding) are left untouched.
quantized_model = torch.quantization.quantize_dynamic(
    model,
    qconfig_spec={nn.LSTM, nn.Linear},
    dtype=torch.qint8,
)
print(quantized_model)
LSTMModel(
  (drop): Dropout(p=0.5, inplace=False)
  (encoder): Embedding(33278, 512)
  (rnn): DynamicQuantizedLSTM(512, 256, num_layers=5, dropout=0.5)
  (decoder): DynamicQuantizedLinear(in_features=256, out_features=33278, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
)

模型看起来一样;这给我们带来了什么好处?首先,我们看到模型大小显著减小

def print_size_of_model(model):
    """Print the serialized size of ``model``'s state dict in MB (10**6 bytes).

    The state dict is saved to a file inside a private temporary directory,
    measured, and removed again. Nothing is written to the current working
    directory, so an existing ``temp.p`` there is never overwritten or
    deleted, and the scratch file is cleaned up even if saving fails.

    Args:
        model: Any object exposing ``state_dict()`` (e.g. an ``nn.Module``).
    """
    import tempfile  # local import keeps this helper self-contained

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Keep the original file name; only its location changes.
        tmp_path = os.path.join(tmp_dir, "temp.p")
        torch.save(model.state_dict(), tmp_path)
        size_bytes = os.path.getsize(tmp_path)
    print('Size (MB):', size_bytes / 1e6)

# Report serialized state-dict size: original model first, then the
# dynamically quantized copy.
print_size_of_model(model)
print_size_of_model(quantized_model)
Size (MB): 113.944455
Size (MB): 79.738939

其次,我们看到推理时间更快,而评估损失几乎没有差异

注意:我们设置线程数为一,用于单线程比较,因为量化模型是单线程运行的。

# Limit PyTorch to one intra-op thread so both timings below are taken under
# the same single-threaded conditions.
torch.set_num_threads(1)

def time_model_evaluation(model, test_data):
    """Evaluate ``model`` on ``test_data`` and print its loss and elapsed time.

    Timing uses ``time.perf_counter()``, a monotonic high-resolution clock
    meant for measuring intervals. Unlike ``time.time()``, it is not affected
    by system clock adjustments during the run.

    Args:
        model: Model to pass to ``evaluate``.
        test_data: Batchified evaluation data to pass to ``evaluate``.
    """
    start = time.perf_counter()
    loss = evaluate(model, test_data)
    elapsed = time.perf_counter() - start
    print(f'loss: {loss:.3f}\nelapsed time (seconds): {elapsed:.1f}')

# Time the original model first, then the quantized one, on the same test data.
time_model_evaluation(model, test_data)
time_model_evaluation(quantized_model, test_data)
loss: 5.167
elapsed time (seconds): 201.7
loss: 5.168
elapsed time (seconds): 118.0

在 MacBook Pro 上本地运行此代码,不进行量化时,推理大约需要 200 秒,进行量化后仅需约 100 秒。

结论

动态量化可以是一种简单有效的方法,在减小模型大小的同时,只对精度产生有限的影响。

感谢阅读!一如既往,我们欢迎任何反馈,如果您有任何意见,请在此处创建 issue。

脚本总运行时间: ( 5 分 28.935 秒)

图库由 Sphinx-Gallery 生成

文档

查阅 PyTorch 的完整开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得解答

查看资源