• 文档 >
  • 训练后量化 (PTQ)
快捷方式

训练后量化 (PTQ)

训练后量化 (PTQ) 是一种通过将传统的 FP32 激活空间映射到缩减的 INT8 空间来减少推理所需的计算资源,同时仍保留模型精度的技术。TensorRT 使用校准步骤,该步骤使用来自目标域的样本数据执行模型,并跟踪 FP32 中的激活,以校准到 INT8 的映射,从而最大程度地减少 FP32 推理和 INT8 推理之间的信息损失。

编写 TensorRT 应用程序的用户需要设置一个校准器类,该类将向 TensorRT 校准器提供样本数据。借助 Torch-TensorRT,我们希望利用 PyTorch 中现有的基础架构来简化校准器的实现。

LibTorch 提供了一个 DataLoaderDataset API,它们简化了输入数据的预处理和批处理。这些 API 通过 C++ 和 Python 接口公开,这使得最终用户更容易使用。对于 C++ 接口,我们使用 torch::Datasettorch::data::make_data_loader 对象来构建数据集并对其执行预处理。Python 接口中的等效功能使用 torch.utils.data.Datasettorch.utils.data.DataLoader。PyTorch 文档的这一部分提供了更多信息:https://pytorch.ac.cn/tutorials/advanced/cpp_frontend.html#loading-datahttps://pytorch.ac.cn/tutorials/recipes/recipes/loading_data_recipe.html。Torch-TensorRT 使用 Dataloaders 作为通用校准器实现的基础。因此,您将能够为您的目标域重用或快速实现 torch::Dataset,将其放置在 DataLoader 中,并创建一个 INT8 校准器,您可以将其提供给 Torch-TensorRT 以在编译模块期间运行 INT8 校准。

如何在 C++ 中创建自己的 PTQ 应用程序

以下是 CIFAR10 的 torch::Dataset 类的示例接口

 1//cpp/ptq/datasets/cifar10.h
 2#pragma once
 3
 4#include "torch/data/datasets/base.h"
 5#include "torch/data/example.h"
 6#include "torch/types.h"
 7
 8#include <cstddef>
 9#include <string>
10
11namespace datasets {
12// The CIFAR10 Dataset
13class CIFAR10 : public torch::data::datasets::Dataset<CIFAR10> {
14public:
15    // The mode in which the dataset is loaded
16    enum class Mode { kTrain, kTest };
17
18    // Loads CIFAR10 from un-tarred file
19    // Dataset can be found https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz
20    // Root path should be the directory that contains the content of tarball
21    explicit CIFAR10(const std::string& root, Mode mode = Mode::kTrain);
22
23    // Returns the pair at index in the dataset
24    torch::data::Example<> get(size_t index) override;
25
26    // The size of the dataset
27    c10::optional<size_t> size() const override;
28
29    // The mode the dataset is in
30    bool is_train() const noexcept;
31
32    // Returns all images stacked into a single tensor
33    const torch::Tensor& images() const;
34
35    // Returns all targets stacked into a single tensor
36    const torch::Tensor& targets() const;
37
38    // Trims the dataset to the first n pairs
39    CIFAR10&& use_subset(int64_t new_size);
40
41
42private:
43    Mode mode_;
44    torch::Tensor images_, targets_;
45};
46} // namespace datasets

此类的实现从 CIFAR10 数据集的二进制发行版中读取数据,并构建两个包含图像和标签的张量。

我们使用数据集的子集进行校准,因为我们不需要完整的数据集来进行有效的校准,而且校准确实需要一些时间,然后定义要应用于数据集中图像的预处理,并从将对数据进行批处理的数据集创建一个 DataLoader

auto calibration_dataset = datasets::CIFAR10(data_dir, datasets::CIFAR10::Mode::kTest)
                                    .use_subset(320)
                                    .map(torch::data::transforms::Normalize<>({0.4914, 0.4822, 0.4465},
                                                                            {0.2023, 0.1994, 0.2010}))
                                    .map(torch::data::transforms::Stack<>());
auto calibration_dataloader = torch::data::make_data_loader(std::move(calibration_dataset),
                                                            torch::data::DataLoaderOptions().batch_size(32)
                                                                                            .workers(2));

接下来,我们使用校准器工厂(位于 torch_tensorrt/ptq.h 中)从 calibration_dataloader 创建一个校准器

#include "torch_tensorrt/ptq.h"
...

auto calibrator = torch_tensorrt::ptq::make_int8_calibrator(std::move(calibration_dataloader), calibration_cache_file, true);

在这里,我们还定义了一个位置,用于写入校准缓存文件,以便我们可以使用该文件来重用校准数据,而无需数据集,以及如果缓存文件存在,是否应该使用它。还有一个 torch_tensorrt::ptq::make_int8_cache_calibrator 工厂,它创建一个仅将缓存用于在存储空间有限的机器(即没有空间存储完整数据集)上进行引擎构建或拥有更简单的部署应用程序的情况的校准器。

校准器工厂创建了一个校准器,该校准器继承自 nvinfer1::IInt8Calibrator 虚拟类(默认情况下为 nvinfer1::IInt8EntropyCalibrator2),该类定义了在校准时使用的校准算法。您可以像这样显式选择校准算法

// MinMax Calibrator is geared more towards NLP tasks
auto calibrator = torch_tensorrt::ptq::make_int8_calibrator<nvinfer1::IInt8MinMaxCalibrator>(std::move(calibration_dataloader), calibration_cache_file, true);

然后,为 INT8 校准设置模块所需做的就是设置 torch_tensorrt::CompileSpec 结构中的以下编译设置并编译模块

std::vector<std::vector<int64_t>> input_shape = {{32, 3, 32, 32}};
/// Configure settings for compilation
auto compile_spec = torch_tensorrt::CompileSpec({input_shape});
/// Set operating precision to INT8
compile_spec.enabled_precisions.insert(torch::kF16);
compile_spec.enabled_precisions.insert(torch::kI8);
/// Use the TensorRT Entropy Calibrator
compile_spec.ptq_calibrator = calibrator;

auto trt_mod = torch_tensorrt::CompileGraph(mod, compile_spec);

如果您有一个现有的 TensorRT 校准器实现,您可以直接使用指向您的校准器的指针设置 ptq_calibrator 字段,它也可以工作。从这里开始,在执行方式方面没有太大变化。您仍然可以使用 LibTorch 作为推理的唯一接口。当数据传递到 trt_mod.forward 时,数据应保持 FP32 精度。Torch-TensorRT 演示中有一个应用程序示例,它将指导您完成在 CIFAR10 上训练 VGG16 网络到使用 Torch-TensorRT 部署 INT8 的整个过程,地址为:https://github.com/pytorch/TensorRT/tree/master/cpp/ptq

如何在 Python 中创建自己的 PTQ 应用程序

Torch-TensorRT Python API 提供了一种简单方便的方法,可以使用带有 TensorRT 校准器的 PyTorch 数据加载器。DataLoaderCalibrator 类可用于通过提供所需的配置来创建 TensorRT 校准器。以下代码演示了如何使用它的示例

testing_dataset = torchvision.datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    ),
)

testing_dataloader = torch.utils.data.DataLoader(
    testing_dataset, batch_size=1, shuffle=False, num_workers=1
)
calibrator = torch_tensorrt.ptq.DataLoaderCalibrator(
    testing_dataloader,
    cache_file="./calibration.cache",
    use_cache=False,
    algo_type=torch_tensorrt.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2,
    device=torch.device("cuda:0"),
)

trt_mod = torch_tensorrt.compile(model, inputs=[torch_tensorrt.Input((1, 3, 32, 32))],
                                    enabled_precisions={torch.float, torch.half, torch.int8},
                                    calibrator=calibrator,
                                    device={
                                         "device_type": torch_tensorrt.DeviceType.GPU,
                                         "gpu_id": 0,
                                         "dla_core": 0,
                                         "allow_gpu_fallback": False,
                                         "disable_tf32": False
                                     })

如果存在用户想要使用的预先存在的校准缓存文件,则可以使用 CacheCalibrator,而无需任何数据加载器。以下示例演示了如何在 INT8 模式下使用 CacheCalibrator

calibrator = torch_tensorrt.ptq.CacheCalibrator("./calibration.cache")

trt_mod = torch_tensorrt.compile(model, inputs=[torch_tensorrt.Input([1, 3, 32, 32])],
                                      enabled_precisions={torch.float, torch.half, torch.int8},
                                      calibrator=calibrator)

如果您已经有了一个现有的校准器类(直接使用 TensorRT API 实现),您可以直接将校准器字段设置为您的类,这非常方便。有关如何使用 Torch-TensorRT API 在 VGG 网络上执行 PTQ 的演示,您可以参考 https://github.com/pytorch/TensorRT/blob/master/tests/py/test_ptq_dataloader_calibrator.pyhttps://github.com/pytorch/TensorRT/blob/master/tests/py/test_ptq_trt_calibrator.py

引用

Krizhevsky, A., & Hinton, G. (2009). 从微小图像中学习多层特征。

Simonyan, K., & Zisserman, A. (2014). 用于大规模图像识别的超深度卷积网络。arXiv 预印本 arXiv:1409.1556。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源