我们很高兴正式推出torchao,这是一个PyTorch原生库,通过利用低位数据类型、量化和稀疏性,使模型更快、更小。torchao是一个易于使用的技术工具包,其代码(大部分)以易于阅读的PyTorch代码编写,涵盖推理和训练。这篇博客将帮助您选择对您的工作负载重要的技术。
我们对LLama 3和Diffusion模型等热门GenAI模型的技术进行了基准测试,发现精度下降极小。除非另有说明,基准测试是在A100 80GB GPU上运行的bf16。
我们针对llama 3的顶级指标是:
- 使用带有int4纯权重量化和hqq的autoquant,Llama 3 8B推理速度提升97%
- 在128K上下文长度下,通过量化KV缓存,Llama 3.1 8B推理峰值VRAM减少73%
- 在H100上使用float8训练,Llama 3 70B预训练速度提升50%
- 使用4位量化优化器,Llama 3 8B峰值VRAM减少30%。
我们针对扩散模型推理的顶级指标是:
- 在H100上的flux1.dev上使用float8动态量化推理和float8逐行缩放,速度提升53%
- 使用int8动态量化,CogVideoX的模型VRAM减少50%
下面我们将介绍torchao中可用的一些技术,您可以将其应用于您的模型以进行推理和训练。
推理
我们的推理量化算法适用于包含nn.Linear层的任意PyTorch模型。通过我们的顶级quantize_
API,可以选择用于各种数据类型和稀疏布局的纯权重和动态激活量化。
from torchao.quantization import (
quantize_,
int4_weight_only,
)
quantize_(model, int4_weight_only())
有时量化一个层可能会因为开销而使其变慢,因此如果您希望我们为您选择如何量化模型中的每个层,那么您可以运行:
model = torchao.autoquant(torch.compile(model, mode='max-autotune'))
quantize_
API有几种不同的选项,具体取决于您的模型是计算密集型还是内存密集型。
from torchao.quantization import (
# Memory bound models
int4_weight_only,
int8_weight_only,
# Compute bound models
int8_dynamic_activation_int8_semi_sparse_weight,
int8_dynamic_activation_int8_weight,
# Device capability 8.9+
float8_weight_only,
float8_dynamic_activation_float8_weight,
)
我们还与HuggingFace diffusers团队合作,在diffusers-torchao中进行了广泛的扩散模型基准测试,其中我们展示了Flux.1-Dev上53.88%的速度提升和CogVideoX-5b上27.33%的速度提升。

我们的API是可组合的,例如,我们组合了稀疏性和量化,为ViT-H推理带来了5%的速度提升。
但也可以做一些事情,例如将权重量化为int4,将kv缓存量化为int8,以支持Llama 3.1 8B在不到18.9GB的VRAM中以128K完整上下文长度运行。
QAT
训练后量化,尤其是低于4位的量化,可能会导致严重的精度下降。通过使用量化感知训练(QAT),我们成功地在hellaswag上恢复了高达96%的精度下降。我们已将其作为一个端到端的方法集成到torchtune中,并提供了最小的教程。

训练
低精度计算和通信
torchao提供了易于使用的端到端工作流,用于降低训练计算和分布式通信的精度,从torch.nn.Linear
层的float8开始。下面是一行代码,可以将您的训练运行中的计算gemms转换为float8。
from torchao.float8 import convert_to_float8_training
convert_to_float8_training(model)
有关如何通过float8将LLaMa 3 70B预训练速度提高多达1.5倍的端到端示例,请参阅我们的README,以及torchtitan的博客和float8食谱。
LLaMa 3 70B的float8预训练性能和精度与bfloat16的比较
(来源:https://dev-discuss.pytorch.org/t/enabling-float8-all-gather-in-fsdp2/2359)
我们正在将训练工作流扩展到更多的数据类型和布局。
低位优化器
受Bits and Bytes的启发,我们还添加了对8位和4位优化器的原型支持,作为AdamW的直接替代。
from torchao.prototype.low_bit_optim import AdamW8bit, AdamW4bit
optim = AdamW8bit(model.parameters())

集成
我们一直积极致力于确保torchao在一些最重要的开源项目中良好运行。
- Huggingface transformers作为推理后端
- 在diffusers-torchao中作为加速扩散模型的参考实现
- 在HQQ中用于快速4位推理
- 在torchtune中用于PyTorch原生QLoRA和QAT食谱
- 在torchchat中用于训练后量化
- 在SGLang中用于int4和int8训练后量化
结论
如果您有兴趣让您的模型在训练或推理方面更快、更小,我们希望您会发现torchao有用且易于集成。
pip install torchao
我们对许多事情感到兴奋,包括低于4位的量化、用于高吞吐量推理的高性能内核、扩展到更多层、扩展类型或粒度、MX硬件支持以及支持更多硬件后端。如果以上任何一点让您感到兴奋,您可以关注我们的进展:https://github.com/pytorch/ao
如果您有兴趣参与torchao的工作,我们创建了一个贡献者指南,如果您有任何问题,我们会在discord.gg/gpumode的#torchao
频道与大家交流。
致谢
我们很幸运能站在巨人的肩膀上,并与开源领域的一些最优秀的人合作。谢谢大家!
- Bits and Bytes在低位优化器和QLoRA方面的开创性工作
- Answer.ai在使FSDP和QLoRA组合方面的工程工作
- Mobius Labs在量化算法和低位内核方面的愉快交流
- HuggingFace transformers在实战测试和集成我们工作方面的帮助
- HuggingFace diffusers在广泛基准测试和最佳实践方面的合作
- torch.compile使我们能够用纯PyTorch编写算法
- GPU MODE为我们早期的许多贡献者提供了支持