Compiling BERT using the torch.compile backend¶
This interactive script is intended as a sample of the Torch-TensorRT workflow with torch.compile on a BERT model.
Imports and Model Definition¶
import torch
import torch_tensorrt
from transformers import BertModel
# Initialize model with float precision and sample inputs
model = BertModel.from_pretrained("bert-base-uncased").eval().to("cuda")
inputs = [
    torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
    torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
]
Optional Input Arguments to torch_tensorrt.compile¶
# Enabled precision for TensorRT optimization
enabled_precisions = {torch.float}
# Whether to print verbose logs
debug = True
# Workspace size for TensorRT
workspace_size = 20 << 30
# Minimum number of operators per TRT engine block
# (Lower value allows more graph segmentation)
min_block_size = 7
# Operations to Run in Torch, regardless of converter support
torch_executed_ops = {}
Compilation with torch.compile¶
# Define backend compilation keyword arguments
compilation_kwargs = {
    "enabled_precisions": enabled_precisions,
    "debug": debug,
    "workspace_size": workspace_size,
    "min_block_size": min_block_size,
    "torch_executed_ops": torch_executed_ops,
}
# Build and compile the model with torch.compile, using Torch-TensorRT backend
optimized_model = torch.compile(
    model,
    backend="torch_tensorrt",
    dynamic=False,
    options=compilation_kwargs,
)
# The first call triggers the actual TensorRT compilation
optimized_model(*inputs)
Equivalently, we could have run the above via the convenience frontend, as so: torch_tensorrt.compile(model, ir="torch_compile", inputs=inputs, **compilation_kwargs)
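Spelled out as runnable code, that equivalent call looks like the sketch below (it reuses the model, inputs, and compilation_kwargs defined above):

# Equivalent invocation via the Torch-TensorRT convenience frontend
optimized_model = torch_tensorrt.compile(
    model,
    ir="torch_compile",
    inputs=inputs,
    **compilation_kwargs,
)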
Inference¶
# Does not cause recompilation (same batch size as input)
new_inputs = [
    torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
    torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
]
new_outputs = optimized_model(*new_inputs)
# Does cause recompilation (new batch size)
new_inputs = [
    torch.randint(0, 2, (4, 14), dtype=torch.int32).to("cuda"),
    torch.randint(0, 2, (4, 14), dtype=torch.int32).to("cuda"),
]
new_outputs = optimized_model(*new_inputs)
Cleanup¶
# Finally, we use Torch utilities to clean up the workspace
torch._dynamo.reset()
Cuda Driver Error Note¶
Occasionally, upon exiting the Python runtime after Dynamo compilation with torch_tensorrt, one may encounter a Cuda Driver Error. This issue is related to https://github.com/NVIDIA/TensorRT/issues/2052 and can be resolved by wrapping the compilation/inference in a function and using a scoped call, as in:
if __name__ == '__main__':
    compile_engine_and_infer()
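A minimal sketch of such a scoped wrapper, assuming the steps shown earlier in this script (the body of compile_engine_and_infer here is illustrative, not part of the original example):

def compile_engine_and_infer():
    # Build the model, inputs, and compiled module inside the function so
    # that TensorRT/Dynamo state is released before the interpreter exits
    model = BertModel.from_pretrained("bert-base-uncased").eval().to("cuda")
    inputs = [
        torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
        torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
    ]
    optimized_model = torch.compile(
        model,
        backend="torch_tensorrt",
        dynamic=False,
        options=compilation_kwargs,
    )
    optimized_model(*inputs)
    torch._dynamo.reset()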