• 文档 >
  • 使用 Torch-TensorRT torch.compile 前端编译 GPT2
快捷方式

使用 Torch-TensorRT torch.compile 前端编译 GPT2

本示例展示了如何使用 Torch-TensorRT 的 torch.compile 前端优化最先进的 GPT2 模型。在编译前安装以下依赖项

pip install -r requirements.txt

GPT2 是一个因果(单向)Transformer 模型,使用大量文本语料库进行语言建模预训练。在本示例中,我们使用 HuggingFace 上提供的 GPT2 模型,并对其应用 torch.compile 以获取图模块的图表示。Torch-TensorRT 将此图转换为优化的 TensorRT 引擎。

导入所需库

import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM, AutoTokenizer

定义所需参数

Torch-TensorRT 需要 GPU 才能成功编译模型。MAX_LENGTH 是生成 token 的最大长度。这等于输入提示的长度加上生成的新 token 数量。

MAX_LENGTH = 32
DEVICE = torch.device("cuda:0")

模型定义

我们使用 AutoModelForCausalLM 类从 hugging face 加载预训练的 GPT2 模型。Torch-TRT 当前不支持 kv_cache,因此设置为 use_cache=False

with torch.no_grad():
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = (
        AutoModelForCausalLM.from_pretrained(
            "gpt2",
            pad_token_id=tokenizer.eos_token_id,
            use_cache=False,
            attn_implementation="eager",
        )
        .eval()
        .cuda()
    )

PyTorch 推理

对示例输入提示进行 tokenization 并获取 PyTorch 模型输出。

prompt = "I enjoy walking with my cute dog"
model_inputs = tokenizer(prompt, return_tensors="pt")
input_ids = model_inputs["input_ids"].cuda()

AutoModelForCausalLM 类的 generate() API 用于使用贪婪解码进行自回归生成。

pyt_gen_tokens = model.generate(
    input_ids,
    max_length=MAX_LENGTH,
    use_cache=False,
    pad_token_id=tokenizer.eos_token_id,
)

Torch-TensorRT 编译和推理

输入序列长度是动态的,因此我们使用 torch._dynamo.mark_dynamic API 对其进行标记。我们为该值提供一个 (min, max) 范围,以便 TensorRT 提前知道要优化的值范围。通常,这将是模型的上下文长度。由于 0/1 特殊化,我们从 min=2 开始。

torch._dynamo.mark_dynamic(input_ids, 1, min=2, max=1023)
model.forward = torch.compile(
    model.forward,
    backend="tensorrt",
    dynamic=None,
    options={
        "enabled_precisions": {torch.float32},
        "disable_tf32": True,
        "min_block_size": 1,
    },
)

使用 TensorRT 模型进行贪婪解码的自回归生成循环。生成第一个 token 时,模型会使用 TensorRT 进行编译;而生成第二个 token 时会遇到重新编译(这是目前一个问题,将来会解决)。

trt_gen_tokens = model.generate(
    inputs=input_ids,
    max_length=MAX_LENGTH,
    use_cache=False,
    pad_token_id=tokenizer.eos_token_id,
)

解码 PyTorch 和 TensorRT 的输出句子

print(
    "Pytorch model generated text: ",
    tokenizer.decode(pyt_gen_tokens[0], skip_special_tokens=True),
)
print("=============================")
print(
    "TensorRT model generated text: ",
    tokenizer.decode(trt_gen_tokens[0], skip_special_tokens=True),
)

输出句子应如下所示

"""
Pytorch model generated text:  I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll
=============================
TensorRT model generated text:  I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll
"""

脚本总运行时间: ( 0 分 0.000 秒)

画廊由 Sphinx-Gallery 生成


© 版权所有 2024, NVIDIA Corporation.

使用 Sphinx 构建,主题由 Read the Docs 提供。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源