快捷方式

数据预处理应用示例

这是一个简单的 TorchX 应用,它通过 HTTP 下载一些数据,使用 torchvision 对图像进行规范化,然后通过 fsspec 重新上传数据。

使用

注意

datapreproc 应用是一个单进程 Python 程序,因此对于本地运行,可以将其作为常规 Python 程序运行:python ./datapreproc.py。TorchX 允许您在远程集群上运行此应用。

要使用 TorchX 在本地启动(参见上面的说明),请运行

$ torchx run -s local_cwd utils.python       --script ./datapreproc/datapreproc.py       --       --input_path="http://cs231n.stanford.edu/tiny-imagenet-200.zip"       --output_path=/tmp/torchx/datapreproc

要将此应用启动到远程集群,只需在 -s 选项中指定不同的调度器即可。

$ torchx run -s kubernetes -cfg queue=foo,namespace=bar utils.python       --script ./datapreproc/datapreproc.py       --       --input_path="http://cs231n.stanford.edu/tiny-imagenet-200.zip"       --output_path=/tmp/torchx/datapreproc
import argparse
import os
import sys
import tarfile
import tempfile
import zipfile
from typing import List

import fsspec
from PIL import Image
from torchvision import transforms
from torchvision.datasets.folder import is_image_file
from tqdm import tqdm


def parse_args(argv: List[str]) -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="example data preprocessing",
    )
    parser.add_argument(
        "--input_path",
        type=str,
        help="dataset to download",
        default="http://cs231n.stanford.edu/tiny-imagenet-200.zip",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        help="remote path to save the .tar.gz data to",
        required=True,
    )
    parser.add_argument(
        "--limit",
        type=int,
        help="limit number of processed examples",
    )
    return parser.parse_args(argv)


def download_and_extract_zip_archive(url: str, path: str) -> None:
    with fsspec.open(url, "rb") as f:
        with zipfile.ZipFile(f, "r") as zip_ref:
            zip_ref.extractall(path)


def main(argv: List[str]) -> None:
    args = parse_args(argv)
    with tempfile.TemporaryDirectory() as tmpdir:
        print(f"downloading {args.input_path} to {tmpdir}...")
        download_and_extract_zip_archive(args.input_path, tmpdir)

        img_root = os.path.join(
            tmpdir,
            os.path.splitext(os.path.basename(args.input_path))[0],
        )
        print(f"img_root={img_root}")

        print("transforming images...")
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
                transforms.ToPILImage(),
            ]
        )

        image_files = []
        for root, _, fnames in os.walk(img_root):
            for fname in fnames:
                path = os.path.join(root, fname)
                if not is_image_file(path):
                    continue
                image_files.append(path)

                if args.limit and len(image_files) > args.limit:
                    break
        for path in tqdm(image_files, miniters=int(len(image_files) / 2000)):
            f = Image.open(path)
            f = transform(f)
            f.save(path)

        tar_path = os.path.join(tmpdir, "out.tar.gz")
        print(f"packing images into {tar_path}...")
        with tarfile.open(tar_path, mode="w:gz") as f:
            f.add(img_root, arcname="")

        print(f"uploading dataset to {args.output_path}...")
        fs, _, rpaths = fsspec.get_fs_token_paths(args.output_path)
        assert len(rpaths) == 1, "must have single output path"
        if fs.exists(rpaths[0]):
            fs.rm(rpaths[0])
        fs.put(tar_path, rpaths[0])


if __name__ == "__main__" and "NOTEBOOK" not in globals():
    main(sys.argv[1:])


# sphinx_gallery_thumbnail_path = '_static/img/gallery-app.png'

脚本的总运行时间:(0 分钟 0.000 秒)

Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源