PyTorch 通过易于使用的前端、分布式训练以及丰富的工具和库生态系统,实现了快速、灵活的实验和高效的生产部署。
import torch
class MyModule(torch.nn.Module):
def __init__(self, N, M):
super(MyModule, self).__init__()
self.weight = torch.nn.Parameter(torch.rand(N, M))
def forward(self, input):
if input.sum() > 0:
output = self.weight.mv(input)
else:
output = self.weight + input
return output
# Compile the model code to a static representation
my_script_module = torch.jit.script(MyModule(3, 4))
# Save the compiled code and model data so it can be loaded elsewhere
my_script_module.save("my_script_module.pt")
生产就绪
借助 TorchScript,PyTorch 既能在 Eager 模式下提供易用性和灵活性,又能无缝切换至 Graph 模式,从而在 C++ 运行时环境中获得速度、优化和功能上的优势。
TorchServe
TorchServe 是一款易于使用的工具,用于大规模部署 PyTorch 模型。它与云平台和环境无关,支持多模型服务、日志记录、指标监控以及创建用于应用程序集成的 RESTful 端点等功能。
## Convert the model from PyTorch to TorchServe format
torch-model-archiver --model-name densenet161 \
--version 1.0 --model-file serve/examples/image_classifier/densenet_161/model.py \
--serialized-file densenet161-8d451a50.pth \
--extra-files serve/examples/image_classifier/index_to_name.json \
--handler image_classifier
## Host your PyTorch model
torchserve --start --model-store model_store --models densenet161=densenet161.mar
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
dist.init_process_group(backend='gloo')
model = DistributedDataParallel(model)
分布式训练
利用 Python 和 C++ 中可用的集合操作异步执行和点对点通信的原生支持,优化研究和生产环境下的性能。
移动端(实验性功能)
PyTorch 支持从 Python 开发到部署至 iOS 和 Android 的端到端工作流。它扩展了 PyTorch API,涵盖了在移动应用中集成机器学习所需的常见预处理和集成任务。
## Save your model
torch.jit.script(model).save("my_mobile_model.pt")
## iOS prebuilt binary
pod ‘LibTorch’
## Android prebuilt binary
implementation 'org.pytorch:pytorch_android:1.3.0'
## Run your model (Android example)
Tensor input = Tensor.fromBlob(data, new long[]{1, data.length});
IValue output = module.forward(IValue.tensor(input));
float[] scores = output.getTensor().getDataAsFloatArray();
import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
densenet = models.densenet161(pretrained=True)
inception = models.inception_v3(pretrained=True)
强大的生态系统
一个由研究人员和开发人员组成的活跃社区,构建了丰富的工具和库生态系统,用于扩展 PyTorch 并支持从计算机视觉到强化学习等领域的开发工作。
原生 ONNX 支持
以标准的 ONNX(开放神经网络交换)格式导出模型,以便直接访问兼容 ONNX 的平台、运行时、可视化工具等。
import torch.onnx
import torchvision
dummy_input = torch.randn(1, 3, 224, 224)
model = torchvision.models.alexnet(pretrained=True)
torch.onnx.export(model, dummy_input, "alexnet.onnx")
#include <torch/torch.h>
torch::nn::Linear model(num_features, 1);
torch::optim::SGD optimizer(model->parameters());
auto data_loader = torch::data::data_loader(dataset);
for (size_t epoch = 0; epoch < 10; ++epoch) { for (auto batch : data_loader) { auto prediction = model->forward(batch.data);
auto loss = loss_function(prediction, batch.target);
loss.backward();
optimizer.step();
}
}
C++ 前端
C++ 前端是 PyTorch 的纯 C++ 接口,遵循已有的 Python 前端设计和架构。它旨在支持高性能、低延迟和裸机 C++ 应用程序的研究与开发。
云支持
PyTorch 在各大主流云平台上均得到良好支持,通过预构建镜像、GPU 大规模训练、在生产级环境中运行模型的能力等,提供无缝的开发体验和便捷的扩展能力。
export IMAGE_FAMILY="pytorch-latest-cpu"
export ZONE="us-west1-b"
export INSTANCE_NAME="my-instance"
gcloud compute instances create $INSTANCE_NAME \
--zone=$ZONE \
--image-family=$IMAGE_FAMILY \
--image-project=deeplearning-platform-release