
Engine Caching (BERT)

A small engine-caching example on BERT.

import numpy as np
import torch
import torch_tensorrt
from engine_caching_example import remove_timing_cache
from transformers import BertModel

np.random.seed(0)
torch.manual_seed(0)

model = BertModel.from_pretrained("bert-base-uncased", return_dict=False).cuda().eval()
# Dummy input IDs and attention mask for a sequence of length 14
inputs = [
    torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
    torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
]


def compile_bert(iterations=3):
    times = []
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # The 1st iteration is to measure the compilation time without engine caching
    # The 2nd and 3rd iterations are to measure the compilation time with engine caching.
    # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration.
    # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine.
    for i in range(iterations):
        # Remove the timing cache and reset dynamo before each engine caching measurement
        remove_timing_cache()
        torch._dynamo.reset()

        if i == 0:
            cache_built_engines = False
            reuse_cached_engines = False
        else:
            cache_built_engines = True
            reuse_cached_engines = True

        start.record()
        compilation_kwargs = {
            "use_python_runtime": False,
            "enabled_precisions": {torch.float},
            "truncate_double": True,
            "debug": False,
            "min_block_size": 1,
            "make_refittable": True,
            "cache_built_engines": cache_built_engines,
            "reuse_cached_engines": reuse_cached_engines,
            "engine_cache_dir": "/tmp/torch_trt_bert_engine_cache",
            "engine_cache_size": 1 << 30,  # 1GB
        }
        optimized_model = torch.compile(
            model,
            backend="torch_tensorrt",
            options=compilation_kwargs,
        )
        # The first call triggers TensorRT compilation (and engine caching when enabled)
        optimized_model(*inputs)
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))

    print("-----compile bert-----> compilation time:\n", times, "milliseconds")


if __name__ == "__main__":
    compile_bert()
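
The helper remove_timing_cache is imported from the companion engine_caching_example script. If you run this snippet on its own, a minimal stand-in might look like the sketch below, assuming the timing cache lives in a single file under the system temp directory; that path is an assumption for illustration, not necessarily what the companion script uses. A hypothetical clear_engine_cache helper for wiping the engine cache directory between experiments is included as well.

import os
import shutil
import tempfile


# Hypothetical stand-in for engine_caching_example.remove_timing_cache:
# delete the TensorRT timing cache file so every iteration starts cold.
# The path below is an assumed default for illustration.
def remove_timing_cache(path=os.path.join(tempfile.gettempdir(), "timing_cache.bin")):
    if os.path.exists(path):
        os.remove(path)


# Optional helper (not part of the original example): clear the engine cache
# directory so the next run measures a cold compilation again. Uses the same
# engine_cache_dir as the compilation_kwargs above.
def clear_engine_cache(cache_dir="/tmp/torch_trt_bert_engine_cache"):
    shutil.rmtree(cache_dir, ignore_errors=True)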

Total running time of the script: (0 minutes 0.000 seconds)

Gallery generated by Sphinx-Gallery
