注意
点击 这里 下载完整的示例代码
训练器示例¶
这是一个使用 PyTorch Lightning 训练模型的 TorchX 应用程序示例。
此应用程序仅使用标准 OSS 库,没有运行时 torchx 依赖项。它使用 fsspec 保存和加载数据和模型,这使得应用程序独立于其运行的环境。
用法¶
要在本地将训练器作为 ddp 应用程序运行,使用 1 个节点、每个节点 2 个工作进程(world size = 2):
$ torchx run -s local_cwd dist.ddp
-j 1x2
--script ./lightning/train.py
--
--epochs=1
--output_path=/tmp/torchx/train
--log_path=/tmp/torchx/logs
--skip_export
注意
--
用于区分组件 (dist.ddp
) 和应用程序参数。
使用 --help
选项查看应用程序选项的完整列表
$ torchx run -s local_cwd dist.ddp -j 1x1 --script ./lightning/train.py -- --help
它与 ./train.py --help
相当。要在远程调度器上运行,请使用 -s
选项指定调度器。根据远程调度器的类型,您可能需要使用 -cfg
选项传递其他调度器配置。有关更多详细信息,请参见 远程调度器。
import argparse
import os
import sys
import tempfile
from typing import List, Optional
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torchx.examples.apps.lightning.data import (
create_random_data,
download_data,
TinyImageNetDataModule,
)
from torchx.examples.apps.lightning.model import (
export_inference_model,
TinyImageNetModel,
)
from torchx.examples.apps.lightning.profiler import SimpleLoggingProfiler
# ensure data and module are on the path
sys.path.append(".")
def parse_args(argv: List[str]) -> argparse.Namespace:
    """Parse the trainer app's command-line arguments.

    Args:
        argv: raw argument list, e.g. ``sys.argv[1:]``.

    Returns:
        The populated :class:`argparse.Namespace`.
    """
    p = argparse.ArgumentParser(description="pytorch lightning TorchX example app")

    # Training hyperparameters.
    p.add_argument(
        "--epochs", help="number of epochs to train", type=int, default=3
    )
    p.add_argument("--lr", help="learning rate", type=float)
    p.add_argument(
        "--batch_size", help="batch size to use for training", type=int, default=32
    )
    p.add_argument("--num_samples", help="num_samples", type=int, default=10)

    # Data / checkpoint / output locations.
    p.add_argument(
        "--data_path",
        type=str,
        help="path to load the training data from, if not provided, random data will be generated",
    )
    p.add_argument("--skip_export", action="store_true")
    p.add_argument("--load_path", help="checkpoint path to load from", type=str)
    p.add_argument(
        "--output_path",
        type=str,
        help="path to place checkpoints and model outputs, if not specified, checkpoints are not saved",
    )
    p.add_argument(
        "--log_path",
        help="path to place the tensorboard logs",
        type=str,
        default="/tmp",
    )

    # Architecture search knob: one int per hidden layer.
    p.add_argument(
        "--layers",
        help="the MLP hidden layers and sizes, used for neural architecture search",
        nargs="+",
        type=int,
    )
    return p.parse_args(argv)
def get_model_checkpoint(args: argparse.Namespace) -> Optional[ModelCheckpoint]:
    """Build a ``ModelCheckpoint`` callback if an output path was given.

    Returns ``None`` when ``args.output_path`` is unset/empty.

    Note: every rank must make the same decision here — either all ranks get
    a ModelCheckpoint or none do — otherwise distributed training deadlocks.
    """
    if args.output_path:
        return ModelCheckpoint(
            save_last=True,
            dirpath=args.output_path,
            monitor="train_loss",
        )
    return None
def main(argv: List[str]) -> None:
    """Entry point: train ``TinyImageNetModel`` and optionally export it.

    Args:
        argv: command-line arguments, parsed by :func:`parse_args`.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        args = parse_args(argv)

        # Init our model
        model = TinyImageNetModel(args.layers)
        print(model)

        # Download and setup the data module. Without --data_path we train on
        # locally generated random data so the example runs anywhere.
        if not args.data_path:
            data_path = os.path.join(tmpdir, "data")
            os.makedirs(data_path)
            create_random_data(data_path)
        else:
            data_path = download_data(args.data_path, tmpdir)
        data = TinyImageNetDataModule(
            data_dir=data_path,
            batch_size=args.batch_size,
            num_samples=args.num_samples,
        )

        # Setup model checkpointing (all ranks must agree — see
        # get_model_checkpoint).
        checkpoint_callback = get_model_checkpoint(args)
        callbacks = []
        if checkpoint_callback:
            callbacks.append(checkpoint_callback)
        if args.load_path:
            print(f"loading checkpoint: {args.load_path}...")
            # BUGFIX: ``load_from_checkpoint`` is a classmethod that RETURNS
            # the restored model; the previous code discarded the return
            # value, so --load_path silently had no effect. Rebind ``model``
            # so the checkpoint weights are actually used.
            model = model.load_from_checkpoint(checkpoint_path=args.load_path)

        logger = TensorBoardLogger(
            save_dir=args.log_path, version=1, name="lightning_logs"
        )

        # Initialize a trainer. GROUP_WORLD_SIZE / LOCAL_WORLD_SIZE are
        # presumably injected by torchelastic/torchx when launched via
        # dist.ddp; they default to 1 for a plain single-process run.
        trainer = pl.Trainer(
            num_nodes=int(os.environ.get("GROUP_WORLD_SIZE", 1)),
            accelerator="gpu" if torch.cuda.is_available() else "cpu",
            devices=int(os.environ.get("LOCAL_WORLD_SIZE", 1)),
            strategy="ddp",
            logger=logger,
            max_epochs=args.epochs,
            callbacks=callbacks,
            profiler=SimpleLoggingProfiler(logger),
        )

        # Train the model ⚡
        trainer.fit(model, data)
        print(
            f"train acc: {model.train_acc.compute()}, val acc: {model.val_acc.compute()}"
        )

        # Export only on rank 0 so distributed workers don't race on the
        # output path; skipped entirely when --skip_export or no output path.
        rank = int(os.environ.get("RANK", 0))
        if rank == 0 and not args.skip_export and args.output_path:
            # Export the inference model
            export_inference_model(model, args.output_path, tmpdir)
# Run the app only when executed as a script. The ``NOTEBOOK`` global check
# presumably lets a notebook / sphinx-gallery build import-execute this file
# without launching training — confirm against the gallery tooling.
if __name__ == "__main__" and "NOTEBOOK" not in globals():
    main(sys.argv[1:])
# sphinx_gallery_thumbnail_path = '_static/img/gallery-app.png'
脚本的总运行时间:(0 分钟 0.000 秒)