编写自己的量化张量¶
torchao 中的量化构建在张量子类的基础上。它们是 torchao 的主要扩展点,用于使用低精度计算提供灵活的推理和训练支持,同时与重要的 PyTorch 功能(如 torch.compile、autograd 和分布式原语)结合使用。
在本教程中,我们将强调利用张量子类相对于模块替换的优势,并通过一个简单的示例来演示如何使用这种方法表达量化。
什么是张量子类?¶
张量子类只是继承自 torch.Tensor 的类。它们允许用户在模型中现有操作之间插入自定义计算逻辑,以便像顶层 torch 命名空间中的 torch.add 等函数能够继续无缝工作。
张量子类方法的一个显而易见的替代方案是模块替换:例如,简单地将模型中的所有 nn.Linear 模块替换为您自定义的 Int8QuantizedLinear 模块。与这种方法相比,使用张量子类有几个重要的优势:
更细粒度的集成点。 模块替换在模块级别拦截计算,因此不适用于依赖 torch 函数或原生模块变体(例如,略微修改过的 nn.Linear 版本)的模型。相比之下,由于张量子类在函数/操作级别拦截计算,因此只要使用相同的函数/操作,我们就能够对模型进行量化。
更好的可组合性。 使用模块替换组合多个功能很笨拙。例如,组合两个现有的 Int8QuantizedLinear 和 DistributedLinear 模块需要用户创建另一个线性类来复制这些功能。张量子类通过简单地将一个子类包装在另一个子类中来绕过这个问题。如果外部张量(例如 DTensor)知道内部张量已量化,这也可以提供性能优势,从而可以使用更少的网络和内存带宽执行昂贵的全收集操作。
重用 PyTorch 组件。 使用张量子类来表达量化是很自然的,因为量化张量只是具有不同 dtype 的 torch.Tensors。模型结构不会改变(nn.Linears 仍然是 nn.Linears),因此后续的优化过程也可以保持与之前完全相同。
在本教程的其余部分,我们将通过一个示例来演示如何使用这两种方法表达量化。有关张量子类的更多阅读资料,请参阅:
使用模块替换进行量化¶
我们首先以一个简单的示例开始,演示如何使用模块替换实现 int8 对称权重唯一量化。所有代码都可以在此示例脚本中找到。我们将使用以下函数将 float32 张量量化为 int8 张量
from typing import Tuple
import torch
def int8_symmetric_quantize(
fp32_tensor: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Symmetrically quantize the torch.float32 tensor into torch.int8.
Return a 2-tuple of (quantized value, scale).
input: dimensions=[M, N], dtype=torch.float32
output: dimensions=[M, N], dtype=torch.int8
scale: dimensions=[M, 1], dtype=torch.float32
"""
quant_min = -128
quant_max = 127
min_val = torch.amin(fp32_tensor, dim=[1], keepdim=False)
max_val = torch.amax(fp32_tensor, dim=[1], keepdim=False)
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
max_val_pos = torch.max(-min_val_neg, max_val_pos)
scale = max_val_pos / (float(quant_max - quant_min) / 2)
scale = scale.view(fp32_tensor.shape[0], -1)
out = torch.round(fp32_tensor * (1.0 / scale))
out = torch.clamp(out, quant_min, quant_max).to(torch.int8)
return out, scale
接下来,我们将创建一个新的 QuantizedLinear 模块,该模块调用此函数来动态量化权重
class QuantizedLinear(torch.nn.Linear):
"""
Linear module that performs dynamic and symmetric weight-only
int8 quantization.
"""
def forward(self, x: torch.Tensor) -> torch.Tensor:
w_int8, scale = int8_symmetric_quantize(self.weight)
return torch.matmul(x, w_int8.t().to(x.dtype)) * scale.t()
@classmethod
def from_float(cls, mod: torch.nn.Linear):
new_linear = cls(mod.in_features, mod.out_features, mod.bias)
new_linear.weight = mod.weight
return new_linear
然后,剩下要做的就是将模型中的所有 nn.Linear 模块替换为我们新的 QuantizedLinear 模块。让我们使用以下玩具模型进行演示
import copy
class ToyModel(torch.nn.Module):
def __init__(self, m: int, n: int, k: int):
super().__init__()
self.linear1 = torch.nn.Linear(m, n, bias=False)
self.linear2 = torch.nn.Linear(n, k, bias=False)
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x
float_model = ToyModel(64, 128, 32).cuda()
quantized_model = copy.deepcopy(float_model)
# Swap torch.nn.Linear with QuantizedLinear
for name, child in quantized_model.named_children():
if type(child) == torch.nn.Linear:
new_linear = QuantizedLinear.from_float(child)
setattr(quantized_model, name, new_linear)
验证模型现在使用了我们的 QuantizedLinear 模块。这个模型现在可以使用了!
>>> print(float_model)
ToyModel(
(linear1): Linear(in_features=64, out_features=128, bias=False)
(linear2): Linear(in_features=128, out_features=32, bias=False)
)
>>> print(quantized_model)
ToyModel(
(linear1): QuantizedLinear(in_features=64, out_features=128, bias=False)
(linear2): QuantizedLinear(in_features=128, out_features=32, bias=False)
)
这种简单方法的一个重要缺点是灵活性。目前,这仅适用于原生的 PyTorch 模块,但如果模型有略微修改的线性模块(例如,支持分布式训练)怎么办?它也无法用于直接调用线性函数版本(torch.nn.functional.linear)的模型。
此外,假设我们想将此功能与分布式训练结合,分布式训练也是通过模块替换实现的。除了创建另一个结合这两个功能的模块之外,没有其他简洁的方法可以做到这一点。这些限制可以通过张量子类解决,张量子类是一种更优雅的方式,可以在模型中插入自定义计算,例如量化。
使用张量子类进行量化¶
在这里,我们将使用基于 __torch_dispatch__ 的张量子类重新实现上述量化技术。
张量子类(通常利用 __torch_dispatch__)是 PyTorch 中一个非常强大/灵活的扩展点。它们作为扩展点主要有两个目的:
张量子类允许您覆盖(几乎)每个 PyTorch API 的实现,并且在实现其他 PyTorch 产品时使用得很频繁
张量子类允许您将张量数据与附加元数据耦合。一些例子:
[量化] scale/zero_point 元数据(AffineQuantizedTensor)
[不规则性] 关于不规则结构的元数据(NestedTensor, 文档)
对张量子类感兴趣的读者可以参考其他一些资源:
__torch_dispatch__ 文档(链接)
什么是 __torch_dispatch__ (及其作用)(链接)
使用 __torch_dispatch__ 实现 FlopCounter 和 MemoryTracker 的 Google Colab(链接)
言归正传,让我们首先定义用于对称量化的基本张量子类
class Int8SymmetricTensor(torch.Tensor):
"""
Our subclass represents a tensor that has been quantized to int8
It will hold two inner tensors:
int_data: int8[M, N]
scale: fp32[M, 1]
"""
@staticmethod
@torch._dynamo.disable
def __new__(cls, int_data: torch.Tensor, scale: torch.Tensor):
return torch.Tensor._make_wrapper_subclass(
cls,
int_data.shape,
strides=int_data.stride(),
storage_offset=int_data.storage_offset(),
dtype=scale.dtype,
device=int_data.device,
)
@torch._dynamo.disable
def __init__(self, int_data: torch.Tensor, scale: torch.Tensor):
# inner data expected to be quantized already
assert int_data.dtype is torch.int8
# we could do more work to support ndim > 2!
assert int_data.ndim == 2
assert scale.ndim == 2
self.int_data = int_data
self.scale = scale
def __tensor_flatten__(self) -> Tuple[List[str], Any]:
"""
Returns a tuple of:
names of all inner tensor attributes (two in our case)
any other additional, non-tensor metadata.
Needed for PT2 support.
"""
return ["int_data", "scale"], None
@classmethod
def __tensor_unflatten__(cls, tensor_data_dict, extra_metadata, outer_size=None, outer_stride=None):
"""
__tensor_unflatten__ should effectively undo __tensor_flatten__.
inputs:
a dict mapping names of inner tensor attributes back to the tensors
the constant metadata from __tensor_flatten__
output:
a new instance of your subclass
Needed for PT2 support.
"""
assert extra_metadata is None
int_data = tensor_data_dict["int_data"]
scale = tensor_data_dict["scale"]
return Int8SymmetricTensor(int_data, scale)
def __repr__(self):
return f'Int8SymmetricTensor(int_data={repr(self.int_data)}, scale={repr(self.scale)})'
@staticmethod
def from_float(float_tensor):
"""
Actually performs the symmetric quantization.
In our simple inference example we will quantize weights "ahead-of-time",
although later in a training example we can quantize/dequantize
during model execution, inside of our __torch_dispatch__
input:
float32 torch.Tensor
output:
Int8SymmetricTensor
"""
int8_tensor, scale = int8_symmetric_quantize(float_tensor)
return Int8SymmetricTensor(int8_tensor, scale)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
"""
Called for each ATen operator that our subclass is passed as an input to.
We need to define our own implementation for every operator here.
"""
if kwargs is None:
kwargs = {}
if func not in op_implementations_dict:
raise AssertionError(f'Int8SymmetricTensor does not yet support op: {str(func)}')
return op_implementations_dict[func](func, *args, **kwargs)
# Convenience function for registering our own implementation
# to every ATen operator in PyTorch
op_implementations_dict = {}
def register_op(ops: List[torch._ops.OpOverload]):
def impl_decorator(op_impl):
global op_implementations_dict
for op in ops:
op_implementations_dict[op] = op_impl
return op_impl
return impl_decorator
在上面的代码中,我们做了几件事
定义了一个基本的“包装器”张量子类 - 它实际上是一个容器对象,用于保存一些内部数据(特别是对应于我们的 int8 数据和 scale 的两个张量)
定义了 __torch_dispatch__ 的实现,当模型对我们的任何子类输入调用任何 ATen 操作符时,都会调用此实现
(为了支持 PT2)定义了 __tensor_flatten__/__tensor_unflatten__ 方法。这是为了使我们的子类与 torch.compile 一起工作所需的几个最大要求之一(稍后会详细介绍)。它有效地告诉 torch.compile 如何将我们的子类“解糖”成其内部组件。
(为了支持 PT2)向两个构造方法(__new__ 和 __init__)添加了 torch._dynamo.disable 装饰器(稍后会详细介绍)。
我们应该实现哪些操作符?¶
PyTorch 有一个相当大的操作符集合。我们不打算让新的张量子类实现 100% 的覆盖率,而是只关注上面玩具模型所需的那些操作。
但是,我们的模型中调用了哪些操作符,以便我们知道应该先实现什么?最笨的方法是反复运行模型,查看子类中出现哪些操作符错误。一个更优雅的方法是记录模型在执行期间遇到的所有操作符。这可以通过另一个 LoggingTensor 子类实现,例如此示例。
让我们在下面实现必要的操作符
from torch.utils._python_dispatch import return_and_correct_aliasing
@register_op([torch.ops.aten.mm.default])
def int8_mm(func, x, weight):
assert isinstance(weight, Int8SymmetricTensor), "Int8SymmetricTensor: matmul currently only supports the weight in low precision, not the input!"
return torch.mm(x, weight.int_data.to(x.dtype)) * weight.scale
@register_op([
torch.ops.aten.detach.default,
torch.ops.aten.t.default,
])
def int8_view_ops(func, *args, **kwargs):
assert isinstance(args[0], Int8SymmetricTensor)
out_data = func(args[0].int_data, *args[1:], **kwargs)
out_scale = func(args[0].scale, *args[1:], **kwargs)
out = Int8SymmetricTensor(out_data, out_scale)
return return_and_correct_aliasing(func, args, kwargs, out)
您很快会注意到一件事:我们的模型本身包含几个线性层,但我们看到一些操作符(如 aten.t 和 aten.mm)命中了我们的子类。一些背景知识:
我们有许多存在于 C++ 中的操作符分解,它们运行在张量子类“之上”。linear 就是这样一种操作符(分解代码在此)
分解的好处在于它们减少了作为子类作者需要实现的 API 数量。但如果您宁愿覆盖“更高层级”的操作符而不是其分解中的底层操作,那么它们可能会很麻烦。
如果您希望在更高级别覆盖某些操作(如 Linear),可以使用 __torch_function__ (示例)。值得注意的是,如果您需要 autograd 支持,那么在 __torch_function__ 层进行的任何覆盖都需要以可微分的方式编写,而在 __torch_dispatch__ 中进行的任何覆盖将自动可微分。
我们的实现中有一些值得指出的细微之处
您会注意到,在我们的 mm 实现中,我们不再需要在内部对权重/scale 进行转置。这是因为在我们到达 aten.mm 操作之前,转置“已经发生”了。
我们的 aten.mm 实现不返回张量子类输出。从这个意义上说,我们的量化子类的“传播”在矩阵乘法处结束。这反映了我们的权重是低精度的,但我们需要在高精度下执行矩阵乘法本身。一般来说,子类作者可以自由选择他们的子类对哪些操作进行传播或不传播。如果您希望模型中的每个函数(包括所有逐点操作和归约操作)都进行量化,您可以编写子类实现,对每个操作的输出进行量化,并始终返回一个子类。
我们能够对 4 个视图操作重用相同的实现。一般来说,许多操作可以通过相当通用的实现来处理:解包装任何子类输入,在内部张量上运行底层操作符,然后将输出重新包装回子类中。
然而,您是否总能重用实现取决于您尝试做什么。例如,我们在子类上通过对内部数据和内部 scale 张量调用相同的转置来实现 transpose(dim0, dim1)。如果我们的 scale 和数据张量具有不同的维度数,这将不起作用,因此在这种情况下,转置将需要自定义实现。
比较输出¶
话不多说,让我们用这两种量化版本运行我们的模型,并确认它们给出相同的输出!
float_model = ToyModel(64, 128, 32).cuda()
quantized_model_module_swap = copy.deepcopy(float_model)
quantized_model_subclass = copy.deepcopy(float_model)
# Swap torch.nn.Linear with QuantizedLinear
for name, child in quantized_model_module_swap.named_children():
if type(child) == torch.nn.Linear:
new_linear = QuantizedLinear.from_float(child)
setattr(quantized_model_module_swap, name, new_linear)
# Swap torch.nn.Linear weights with Int8SymmetricTensor subclasses
for name, child in quantized_model_subclass.named_children():
if type(child) == torch.nn.Linear:
subclass_param = Int8SymmetricTensor.from_float(child.weight)
child.weight = torch.nn.Parameter(subclass_param, requires_grad=True)
with torch.no_grad():
x = torch.randn(64, 64, 64, device='cuda')
out_module_swap = quantized_model_module_swap(x)
out = quantized_model_subclass(x)
print(torch.allclose(out, out_module_swap)) # prints True
# We can also use torch.compile to fuse some of our quantized logic
out_compiled = torch.compile(quantized_model_subclass)(x)
print(torch.allclose(out, out_compiled)) # prints True
后续步骤¶
在本教程中,我们演示了如何构建一个简单的量化张量子类。这是本系列两个教程中的第一部分。下一篇文章将讨论如何向张量子类添加更高级的功能,例如使其可训练、与 DTensor 组合以及添加张量并行性支持。有关 torchao 中如何使用张量子类构建 AffineQuantizedTensor 的更详细示例,请参阅此示例。
如果您在实现子类时有任何疑问,请随时在此提交问题。