摘要
PyTorch最初使用eager模式,其中构成模型的每个PyTorch操作都会在被执行时立即独立运行. PyTorch 2.0引入了torch.compile以在默认eager模式之上加速PyTorch代码. 与eager模式相反,torch.compile会将整个模型预编译成一个针对给定硬件平台进行优化的单个图. AWS针对AWS Graviton3处理器优化了PyTorch的torch.compile功能. 这项优化使得Hugging Face模型推理性能提升高达2倍(基于33个模型的性能提升几何平均数),以及TorchBench模型推理性能提升高达1.35倍(基于45个模型的性能提升几何平均数),与AWS Graviton3上的Amazon EC2实例上多种自然语言处理(NLP)、计算机视觉(CV)和推荐模型的默认eager模式推理相比. 从PyTorch 2.3.1开始,这些优化可在torch Python wheels和AWS Graviton PyTorch 深度学习容器(DLC)中获取.
在这篇博客文章中,我们将展示如何在AWS Graviton3上的EC2实例上优化torch.compile性能,如何使用这些优化来提高推理性能,以及由此带来的加速效果.
为什么选择torch.compile,以及目标是什么?
在eager模式下,模型中的算子在遇到时会立即运行. 它更易于使用,更适合机器学习(ML)研究人员,因此是默认模式. 然而,eager模式会产生运行时开销,因为存在冗余的内核启动和内存读取开销. 而在torch compile模式下,算子首先被合成为一个图,其中一个算子与另一个算子合并以减少和局部化内存读取和总体的内核启动开销.
AWS Graviton团队的目标是优化Graviton3处理器的torch.compile后端. PyTorch eager模式已通过使用oneDNN(也称为MKLDNN)利用Arm Compute Library (ACL) 内核针对Graviton3处理器进行了优化. 因此,问题是如何在torch.compile模式下重用这些内核,以同时获得图编译和优化内核性能的最佳效果?
结果
AWS Graviton团队扩展了torch inductor和一个DNN原语,重用了ACL内核,并优化了Graviton3处理器上的compile模式性能. 从PyTorch 2.3.1开始,这些优化可在torch Python wheels和AWS Graviton DLC中获取. 请参阅随后的“运行推理”部分,了解安装、运行时配置以及如何运行测试的说明.
为了展示性能提升,我们使用了来自TorchBench的NLP、CV和推荐模型,以及来自Hugging Face的最受欢迎的NLP模型,涵盖了问答、文本分类、Token分类、翻译、零样本分类、摘要、特征提取、文本生成、文本到文本生成、填空和句子相似度等任务,以覆盖广泛的客户用例.
我们首先测量了TorchBench模型推理的延迟(以毫秒(msec)为单位),在下图中用红色虚线标记为1.0. 然后我们比较了torch.compile对相同模型推理带来的改进,归一化结果绘制在图中. 你可以看到,对于我们进行基准测试的45个模型,延迟提升了1.35倍(45个模型的几何平均数).
图1:使用TorchBench框架在AWS Graviton3上的c7g实例上使用torch.compile进行PyTorch模型推理性能提升. 参考eager模式性能标记为1.0. (越高越好)
与前述的TorchBench推理性能图类似,我们首先测量了Hugging Face NLP模型推理的延迟(以毫秒(msec)为单位),在下图中用红色虚线标记为1.0. 然后我们比较了torch.compile对相同模型推理带来的改进,归一化结果绘制在图中. 你可以看到,对于我们进行基准测试的33个模型,性能提升了大约2倍(33个模型的几何平均数).
图2:使用Hugging Face示例脚本在AWS Graviton3上的c7g实例上使用torch.compile进行Hugging Face NLP模型推理性能提升. 参考eager模式性能标记为1.0. (越高越好)
运行推理
从PyTorch 2.3.1开始,这些优化可在torch Python wheel和AWS Graviton PyTorch DLC中获取. 本节介绍如何使用torch Python wheels和来自Hugging Face和TorchBench仓库的基准测试脚本在eager模式和torch.compile模式下运行推理.
为了成功运行脚本并重现本文提到的加速数字,你需要使用Graviton3系列硬件(c7g/r7g/m7g/hpc7g
)的实例. 对于本文,我们使用了c7g.4xl (16 vcpu) 实例. 实例、AMI详情和所需的torch库版本在以下代码片段中提到.
Instance: c7g.4xl instance
Region: us-west-2
AMI: ami-05cc25bfa725a144a (Ubuntu 22.04/Jammy with 6.5.0-1017-aws kernel)
# Install Python
sudo apt-get update
sudo apt-get install -y python3 python3-pip
# Upgrade pip3 to the latest version
python3 -m pip install --upgrade pip
# Install PyTorch and extensions
python3 -m pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1
为eager模式推理实现的通用运行时调优同样适用于torch.compile模式,因此,我们设置了以下环境变量,以进一步提高AWS Graviton3处理器上的torch.compile性能.
# Enable the fast math GEMM kernels, to accelerate fp32 inference with bfloat16 gemm
export DNNL_DEFAULT_FPMATH_MODE=BF16
# Enable Linux Transparent Huge Page (THP) allocations,
# to reduce the tensor memory allocation latency
export THP_MEM_ALLOC_ENABLE=1
# Set LRU Cache capacity to cache the primitives and avoid redundant
# memory allocations
export LRU_CACHE_CAPACITY=1024
TORCHBENCH 基准测试脚本
TorchBench是用于评估PyTorch性能的开源基准测试集合. 我们使用来自TorchBench仓库的脚本对45个模型进行了基准测试. 以下代码展示了如何在eager模式和使用inductor后端的compile模式下运行脚本.
# Set OMP_NUM_THREADS to number of vcpus, 16 for c7g.4xl instance
export OMP_NUM_THREADS=16
# Install the dependencies
sudo apt-get install -y libgl1-mesa-glx
sudo apt-get install -y libpangocairo-1.0-0
python3 -m pip install psutil numpy transformers pynvml numba onnx onnxruntime scikit-learn timm effdet gym doctr opencv-python h5py==3.10.0 python-doctr
# Clone pytorch benchmark repo
git clone https://github.com/pytorch/benchmark.git
cd benchmark
# PyTorch benchmark repo doesn't have any release tags. So,
# listing the commit we used for collecting the performance numbers
git checkout 9a5e4137299741e1b6fb7aa7f5a6a853e5dd2295
# Setup the models
python3 install.py
# Colect eager mode performance using the following command. The results will be
# stored at .userbenchmark/cpu/metric-<timestamp>.json.
python3 run_benchmark.py cpu --model BERT_pytorch,hf_Bert,hf_Bert_large,hf_GPT2,hf_Albert,hf_Bart,hf_BigBird,hf_DistilBert,hf_GPT2_large,dlrm,hf_T5,mnasnet1_0,mobilenet_v2,mobilenet_v3_large,squeezenet1_1,timm_efficientnet,shufflenet_v2_x1_0,timm_regnet,resnet50,soft_actor_critic,phlippe_densenet,resnet152,resnet18,resnext50_32x4d,densenet121,phlippe_resnet,doctr_det_predictor,timm_vovnet,alexnet,doctr_reco_predictor,vgg16,dcgan,yolov3,pytorch_stargan,hf_Longformer,timm_nfnet,timm_vision_transformer,timm_vision_transformer_large,nvidia_deeprecommender,demucs,tts_angular,hf_Reformer,pytorch_CycleGAN_and_pix2pix,functorch_dp_cifar10,pytorch_unet --test eval --metrics="latencies,cpu_peak_mem"
# Collect torch.compile mode performance with inductor backend
# and weights pre-packing enabled. The results will be stored at
# .userbenchmark/cpu/metric-<timestamp>.json
python3 run_benchmark.py cpu --model BERT_pytorch,hf_Bert,hf_Bert_large,hf_GPT2,hf_Albert,hf_Bart,hf_BigBird,hf_DistilBert,hf_GPT2_large,dlrm,hf_T5,mnasnet1_0,mobilenet_v2,mobilenet_v3_large,squeezenet1_1,timm_efficientnet,shufflenet_v2_x1_0,timm_regnet,resnet50,soft_actor_critic,phlippe_densenet,resnet152,resnet18,resnext50_32x4d,densenet121,phlippe_resnet,doctr_det_predictor,timm_vovnet,alexnet,doctr_reco_predictor,vgg16,dcgan,yolov3,pytorch_stargan,hf_Longformer,timm_nfnet,timm_vision_transformer,timm_vision_transformer_large,nvidia_deeprecommender,demucs,tts_angular,hf_Reformer,pytorch_CycleGAN_and_pix2pix,functorch_dp_cifar10,pytorch_unet --test eval --torchdynamo inductor --freeze_prepack_weights --metrics="latencies,cpu_peak_mem"
脚本成功完成后,结果将以JSON格式存储. 以下是样本输出
{
"name": "cpu"
"environ": {
"pytorch_git_version": "d44533f9d073df13895333e70b66f81c513c1889"
},
"metrics": {
"BERT_pytorch-eval_latency": 56.3769865,
"BERT_pytorch-eval_cmem": 0.4169921875
}
}
HUGGING FACE 基准测试脚本
Google T5 Small Text Translation模型是我们进行基准测试的大约30个Hugging Face模型之一. 我们用它作为示例模型来演示如何在eager和compile模式下运行推理. 在compile模式下运行所需的额外配置和API在粗体中突出显示. 将以下脚本保存为google_t5_small_text_translation.py
.
import argparse
from transformers import T5Tokenizer, T5Model
import torch
from torch.profiler import profile, record_function, ProfilerActivity
import torch._inductor.config as config
config.cpp.weight_prepack=True
config.freezing=True
def test_inference(mode, num_iter):
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5Model.from_pretrained("t5-small")
input_ids = tokenizer(
"Studies have been shown that owning a dog is good for you", return_tensors="pt"
).input_ids # Batch size 1
decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
if (mode == 'compile'):
model = torch.compile(model)
with torch.no_grad():
for _ in range(50):
outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
with profile(activities=[ProfilerActivity.CPU]) as prof:
with record_function("model_inference"):
for _ in range(num_iter):
outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
print(prof.key_averages().table(sort_by="self_cpu_time_total"))
def main() -> None:
global m, args
parser = argparse.ArgumentParser(__doc__)
parser.add_argument(
"-m",
"--mode",
choices=["eager", "compile"],
default="eager",
help="Which test to run.",
)
parser.add_argument(
"-n",
"--number",
type=int,
default=100,
help="how many iterations to run.",
)
args = parser.parse_args()
test_inference(args.mode, args.number)
if __name__ == "__main__":
main()
按以下步骤运行脚本
# Set OMP_NUM_THREADS to number of vcpus to 4 because
# the scripts are running inference in sequence, and
# they don't need large number of vcpus
export OMP_NUM_THREADS=4
# Install the dependencies
python3 -m pip install transformers
# Run the inference script in Eager mode
# using number of iterations as 1 just to show the torch profiler output
# but for the benchmarking, we used 1000 iterations.
python3 google_t5_small_text_translation.py -n 1 -m eager
# Run the inference script in torch compile mode
python3 google_t5_small_text_translation.py -n 1 -m compile
脚本成功完成后,将打印torch profiler的输出,其中包含torch算子的延迟细分. 以下是torch profiler的样本输出
# Torch profiler output for the eager mode run on c7g.xl (4vcpu)
------------------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
------------------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::mm 40.71% 12.502ms 40.71% 12.502ms 130.229us 96
model_inference 26.44% 8.118ms 100.00% 30.708ms 30.708ms 1
aten::bmm 6.85% 2.102ms 9.47% 2.908ms 80.778us 36
aten::matmul 3.73% 1.146ms 57.26% 17.583ms 133.205us 132
aten::select 1.88% 576.000us 1.90% 583.000us 0.998us 584
aten::transpose 1.51% 464.000us 1.83% 563.000us 3.027us 186
------------------------ ------------ ------------ ------------ ------------ ------------ -------------------
Self CPU time total: 30.708ms
# Torch profiler output for the compile mode run for the same model on the same instance
--------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
--------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
mkldnn::_linear_pointwise 37.98% 5.461ms 45.91% 6.602ms 68.771us 96
Torch-Compiled Region 29.56% 4.251ms 98.53% 14.168ms 14.168ms 1
aten::bmm 14.90% 2.143ms 21.73% 3.124ms 86.778us 36
aten::select 4.51% 648.000us 4.62% 665.000us 1.155us 576
aten::view 3.29% 473.000us 3.29% 473.000us 1.642us 288
aten::empty 2.53% 364.000us 2.53% 364.000us 3.165us 115
--------------------------------- ------------ ------------ ------------ ------------ ------------ --------------------
Self CPU time total: 14.379ms
技术深入探讨:挑战和优化细节
torch.compile的核心是新技术——TorchDynamo、AOTDispatcher和TorchInductor.
TorchDynamo使用Python帧评估钩子安全地捕获PyTorch程序
AOTDispatcher将PyTorch的自动求导引擎重载为跟踪自动求导,用于生成提前的反向跟踪.
TorchInductor是一个深度学习编译器,可为多种加速器和后端生成快速代码.
图3:PyTorch编译过程
调用torch.compile时,torch dynamo重写Python字节码以将PyTorch操作序列提取到FX 图中,然后使用inductor后端进行编译. 对于典型的推理场景(其中图被冻结且梯度计算被禁用),inductor会调用平台特定的优化,例如将图重写为性能更好的算子、算子融合和权重预打包.
然而,在Graviton3上,inductor无法执行任何这些优化,因为没有定义aarch64后端. 为了解决这个问题,我们扩展了inductor的FX阶段,使其选择oneDNN算子用于在具有ACL后端的Graviton3处理器上进行线性层编译. 以下是相关的代码片段
packed_weight_op = (
mkldnn._reorder_linear_weight
if (is_bf16_weight or mkldnn._is_mkldnn_acl_supported())
packed_linear_inputs: Tuple[Any, ...] = (input, packed_weight_node)
if is_bf16_weight or mkldnn._is_mkldnn_acl_supported():
packed_linear_inputs += (bias, "none", [], "")
packed_linear_op = mkldnn._linear_pointwise.default
完成此操作后,FX阶段成功将matmul
算子编译为linear_pointwise
. 以下代码片段突出了原始模型中的matmul
算子
%attention_scores : [num_users=1] = call_function[target=torch.matmul](args = (%query_layer, %transpose), kwargs = {})
%attention_scores_1 : [num_users=1] = call_function[target=operator.truediv](args = (%attention_scores, 8.0), kwargs = {})
%attention_scores_2 : [num_users=1] = call_function[target=operator.add](args = (%attention_scores_1, %extended_attention_mask_3), kwargs = {})
以下代码片段突出了编译后图中的linear_pointwise
算子
%_linear_pointwise_default_140 : [num_users=2] = call_function[target=torch.ops.mkldnn._linear_pointwise.default](args = (%add_7, %_frozen_param278, %_frozen_param16, none, [], ), kwargs = {})
%mul_5 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%_linear_pointwise_default_140, 0.5), kwargs = {})
%mul_6 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%_linear_pointwise_default_140, 0.7071067811865476), kwargs = {})
%erf : [num_users=1] = call_function[target=torch.ops.aten.erf.default](args = (%mul_6,), kwargs = {})
%add_8 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%erf, 1), kwargs = {})
这就完成了在AWS Graviton3处理器上将图编译为优化算子所需的torch inductor更改. 接下来是实际的推理过程,编译后的图将被调度运行. 我们在inductor编译期间选择了带有ACL的oneDNN作为后端,因此新算子如预期地被调度到oneDNN运行,例如mkldnn._linear_pointwise
. 然而,由于oneDNN ACL原语的不足,这些算子使用了C++参考内核而不是优化的ACL内核来运行. 因此,compile模式的性能仍然显著落后于eager模式的性能.
oneDNN ACL原语对torch.compile模式的支持主要存在三个方面的不足. 以下部分将详细介绍这些不足.
1. ACL原语不支持块状布局的权重
最初为eager模式设计的ACL原语仅支持标准channels last (NHWC)格式的权重,不进行预打包. 而将权重预打包成块状布局是inductor编译阶段的主要优化之一,其中权重会根据运行时平台进行重新排序分块. 这避免了运行通用矩阵乘法 (GEMM) 时冗余和即时的重新排序,否则这将成为推理性能的瓶颈. 但ACL原语不支持块状布局,因此算子使用了oneDNN C++参考内核来运行.
2. oneDNN不支持混合精度原语
AWS Graviton3处理器支持bfloat16 MMLA指令,可用于通过bfloat16 GEMM作为混合精度计算来加速fp32推理. ACL支持bfloat16混合精度GEMM内核,并已集成到oneDNN中作为现有fp32算子的快速数学计算选项. 然而,快速数学方法对compile模式不起作用,因为有权重预打包优化. compile模式需要在oneDNN中显式实现混合精度原语才能使用bfloat16加速.
3. ACL原语不支持某些激活函数的融合内核
在eager模式下,算子单独调度,因为模型在遇到时立即独立运行. 而在compile模式下,算子融合是另一个重要的优化,将算子融合以提高运行时效率. 例如,高斯误差线性单元 (GELU) 是基于transformers的神经网络架构中最广泛使用的激活函数之一. 因此,通常会有一个线性层(包含矩阵乘法),接着是GELU激活. 作为将模型编译为高效算子的一部分,torch inductor将matmul和GELU融合为一个单独的linearpointwise+gelu算子. 然而,oneDNN ACL原语不支持包含GELU的融合内核.
我们通过扩展oneDNN原语来处理额外的布局和新的原语定义,解决了这些不足. 以下部分将详细介绍这些优化.
优化1:扩展ACL原语以接受块状布局的权重张量
我们扩展了ACL原语,使其除了标准NHWC格式外,还接受块状布局. 以下是相关的代码片段
const bool is_weights_md_format_ok
= utils::one_of(weights_format_kind_received,
format_kind::any, format_kind::blocked);
const memory_desc_t weights_md_received = weights_md_;
acl_utils::reorder_to_weight_format(aip.wei_tensor_info,
weights_md_, expected_weight_format, inner_dim, o_dim,
remaining_dims, {});
ACL_CHECK_SUPPORT(
(weights_format_kind_received == format_kind::blocked)
&& !(dnnl_memory_desc_equal(
&weights_md_received, &weights_md_)),
"specified blocked format not supported by ACL, use "
"format_kind_t::any to find a supported blocked format for "
"your platform");
优化2:定义新的ACL原语以处理混合精度算子(权重为bfloat16,激活为fp32)
我们定义了混合精度原语定义,并更新了现有的oneDNN ACL fp32原语以处理bfloat16张量.
/* With graph compilation, we are able to reorder and pre-pack the weights during the model load
* and compilation phase itself so that redundant and on-the-fly reorders can be avoided.
* This primitive definition is to support gemm fastmath mode for the compile scenario where src is
* in fp32 and weights are in bf16
*/
{{forward, f32, bf16, f32}, {
CPU_INSTANCE_AARCH64_ACL(acl_inner_product_fwd_t)
nullptr,
}},
优化3:禁用torch inductor中的算子融合阶段
我们绕过了torch inductor中的算子融合阶段,这样编译后的图就不包含GELU融合算子. 这是一个临时解决方案,以便在torch.compile中启用ACL内核. 目前正在进行工作,以在未来的PyTorch版本中启用算子融合阶段. 通过这个权宜之计,我们成功地将线性层调度到ACL. 如下面的torch.profiler输出所示,原始模型中的aten::addmm
(矩阵乘法算子的一种变体)和aten::gelu
(如图4突出显示)被编译为mkldnn::_linear_pointwise
,而没有进行gelu算子融合(如图5突出显示).
--------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
--------------------------- ------------ ------------ ------------ ------------ ------------ ------------
aten::addmm 73.32% 46.543ms 74.49% 47.287ms 647.767us 73
model_inference 9.92% 6.296ms 100.00% 63.479ms 63.479ms 1
aten::bmm 4.37% 2.776ms 5.46% 3.467ms 144.458us 24
aten::copy_ 1.74% 1.102ms 1.74% 1.102ms 8.103us 136
aten::gelu 1.50% 950.000us 1.50% 950.000us 79.167us 12
图4:Hugging Face bert base模型在Eager模式下进行推理的torch.profiler输出,显示了addmm和gelu算子
----------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
----------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
mkldnn::_linear_pointwise 53.61% 15.529ms 57.53% 16.665ms 228.288us 73
Torch-Compiled Region 36.95% 10.705ms 99.31% 28.769ms 28.769ms 1
aten::_scaled_dot_product_flash_attention_for_cpu 3.67% 1.064ms 4.43% 1.284ms 107.000us 12
aten::view 1.97% 572.000us 1.97% 572.000us 2.509us 228
aten::empty 1.38% 399.000us 1.38% 399.000us 3.270us 122
图5:Hugging Face Bert base模型在torch.compile模式下进行推理的torch.profiler输出,显示了未进行gelu融合的linear_pointwise算子
最后,gelu
算子被编译成erf
(误差函数),并被调度到inductor自动向量化后端. 以下代码片段显示了编译后图中的erf
算子以及使用libm.so
运行它.
%_linear_pointwise_default_140 : [num_users=2] = call_function[target=torch.ops.mkldnn._linear_pointwise.default](args = (%add_7, %_frozen_param278, %_frozen_param16, none, [], ), kwargs = {})
%mul_5 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%_linear_pointwise_default_140, 0.5), kwargs = {})
%mul_6 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%_linear_pointwise_default_140, 0.7071067811865476), kwargs = {})
%erf : [num_users=1] = call_function[target=torch.ops.aten.erf.default](args = (%mul_6,), kwargs = {})
%add_8 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%erf, 1), kwargs = {})
图6:后梯度阶段后的代码片段,显示编译后图中的erf函数
0.82% 0.40% python3 libm.so.6 [.] erff32
0.05% 0.00% python3 libtorch_python.so [.] torch::autograd::THPVariable_erf
0.05% 0.00% python3 libtorch_cpu.so [.] at::_ops::erf::call
图7:Linux perf报告,显示erf被调度到libm.so
通过这项工作,我们利用inductor图编译以及oneDNN+ACL后端,优化了Graviton3处理器上的torch.compile
性能.
TorchBench增强功能
为了展示AWS Graviton3处理器上的torch.compile
性能提升,我们扩展了TorchBench框架,添加了一个新参数,用于启用图冻结和权重预打包,并在eval测试模式下禁用torch auto grad. 以下是相关的代码片段
parser.add_argument(
"—freeze_prepack_weights",
action='store_true',
help="set to freeze the graph and prepack weights",
)
if args.freeze_prepack_weights:
torch._inductor.config.freezing=True
torch._inductor.config.cpp.weight_prepack=True
图8:在TorchBench中为torchdynamo后端添加了freeze_prepack_weights
选项,以展示AWS Graviton3处理器上的torch.compile
性能提升
我们已将所有优化合并到上游,从PyTorch 2.3.1开始,这些优化在torch Python wheels和AWS Graviton PyTorch DLC中得到支持.
后续计划
接下来,我们将扩展torch inductor CPU后端支持以编译Llama模型,并添加对融合GEMM内核的支持,以在AWS Graviton3处理器上启用torch inductor算子融合优化.
结论
在本教程中,我们介绍了如何在AWS Graviton3上的EC2实例上优化torch.compile
性能,如何使用这些优化来提高PyTorch模型推理性能,并演示了由此带来的加速效果. 希望你能试一试! 如果你在Graviton上需要任何ML软件支持,请在AWS Graviton Technical Guide GitHub上提交issue.
致谢
我们要感谢PyTorch社区提供的基准torch.compile
框架以及他们为进一步优化它所做的不懈努力.
参考资料:https://pytorch.ac.cn/assets/pytorch2-2.pdf
作者
Sunita Nadampalli是AWS的软件开发经理兼AI/ML专家. 她负责AWS Graviton软件在AI/ML和HPC工作负载上的性能优化. 她热衷于开源软件开发,并致力于为基于Arm ISA的SoC提供高性能和可持续的软件解决方案.