快捷方式

训练数据集示例

这是用于训练示例的数据集。它使用 PyTorch Lightning 库。

import os.path
import tarfile
from typing import Callable, Optional

import fsspec
import numpy
import pytorch_lightning as pl
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.datasets.folder import is_image_file
from tqdm import tqdm

这里使用 torchvision 定义一个数据集,稍后会在 PyTorch Lightning 数据模块中使用它。

class ImageFolderSamplesDataset(datasets.ImageFolder):
    """
    ImageFolderSamplesDataset is a wrapper around ImageFolder that allows you to
    limit the number of samples.

    Only the reported length is limited; the underlying ImageFolder is still
    built from everything found under ``root``.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable[..., object]] = None,
        num_samples: Optional[int] = None,
        **kwargs: object,
    ) -> None:
        """
        Args:
            root: directory handed to ImageFolder as the dataset root
            transform: optional callable handed to ImageFolder as ``transform``
            num_samples: optional. limits the size of the dataset
            **kwargs: accepted so callers may pass extra keywords, but they are
                not forwarded to ImageFolder
        """
        super().__init__(root, transform=transform)
        self.num_samples = num_samples

    def __len__(self) -> int:
        """Return the dataset size, capped at ``num_samples`` when it is set."""
        available = super().__len__()
        if self.num_samples is None:
            return available
        # Never report more samples than were actually found; otherwise a
        # loader could request indices that have no backing sample.
        return min(self.num_samples, available)

为了方便使用,我们定义了一个 PyTorch Lightning 数据模块(LightningDataModule),以便在训练器和其他需要加载数据的组件之间复用它。

# pyre-fixme[13]: Attribute `test_ds` is never initialized.
# pyre-fixme[13]: Attribute `train_ds` is never initialized.
# pyre-fixme[13]: Attribute `val_ds` is never initialized.
class TinyImageNetDataModule(pl.LightningDataModule):
    """
    LightningDataModule for the tiny imagenet dataset.

    Expects ``data_dir`` to contain ``train``, ``val`` and ``test``
    sub-directories; each one is loaded as an ImageFolderSamplesDataset when
    ``setup`` is called.
    """

    train_ds: ImageFolderSamplesDataset
    val_ds: ImageFolderSamplesDataset
    test_ds: ImageFolderSamplesDataset

    def __init__(
        self, data_dir: str, batch_size: int = 16, num_samples: Optional[int] = None
    ) -> None:
        """
        Args:
            data_dir: root directory holding the train/val/test folders
            batch_size: batch size used by every dataloader
            num_samples: optional size limit passed to each split's dataset
        """
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_samples = num_samples

    def setup(self, stage: Optional[str] = None) -> None:
        """Build the three split datasets; ``stage`` is not consulted."""
        # A single transform pipeline is shared by all three splits.
        to_tensor = transforms.Compose([transforms.ToTensor()])

        def load_split(split: str) -> ImageFolderSamplesDataset:
            return ImageFolderSamplesDataset(
                root=os.path.join(self.data_dir, split),
                transform=to_tensor,
                num_samples=self.num_samples,
            )

        self.train_ds = load_split("train")
        self.val_ds = load_split("val")
        self.test_ds = load_split("test")

    def _make_loader(self, dataset: ImageFolderSamplesDataset) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader using the configured batch size."""
        return DataLoader(dataset, batch_size=self.batch_size)

    def train_dataloader(self) -> DataLoader:
        """Return the dataloader over the training split."""
        return self._make_loader(self.train_ds)

    def val_dataloader(self) -> DataLoader:
        """Return the dataloader over the validation split."""
        return self._make_loader(self.val_ds)

    def test_dataloader(self) -> DataLoader:
        """Return the dataloader over the test split."""
        return self._make_loader(self.test_ds)

    def teardown(self, stage: Optional[str] = None) -> None:
        """No-op: nothing is released here."""
        return None

为了在不同组件之间传递数据,我们使用 fsspec,它可以读写云端或本地的文件存储。

def download_data(remote_path: str, tmpdir: str) -> str:
    """
    download_data fetches the training data archive from the specified remote
    path via fsspec, stores it in tmpdir as ``data.tar.gz`` and extracts it
    into ``<tmpdir>/data``.

    If remote_path is already a local directory nothing is downloaded and the
    path is returned unchanged.

    Args:
        remote_path: fsspec-compatible path of the archive, or a local directory
        tmpdir: existing directory that receives the archive and its contents

    Returns:
        The directory containing the dataset: either remote_path itself or
        ``<tmpdir>/data``.

    Raises:
        AssertionError: if remote_path does not resolve to exactly one path.
    """
    if os.path.isdir(remote_path):
        print("dataset path is a directory, using as is")
        return remote_path

    tar_path = os.path.join(tmpdir, "data.tar.gz")
    print(f"downloading dataset from {remote_path} to {tar_path}...")
    fs, _, rpaths = fsspec.get_fs_token_paths(remote_path)
    assert len(rpaths) == 1, "must have single path"
    fs.get(rpaths[0], tar_path)

    data_path = os.path.join(tmpdir, "data")
    print(f"extracting {tar_path} to {data_path}...")
    with tarfile.open(tar_path, mode="r") as archive:
        # The archive was downloaded, so treat it as untrusted: the "data"
        # filter rejects members that would land outside data_path. Older
        # Python releases lack extraction filters, so fall back to the plain
        # call there.
        if hasattr(tarfile, "data_filter"):
            archive.extractall(data_path, filter="data")
        else:
            archive.extractall(data_path)

    return data_path


def create_random_data(output_path: str, num_images: int = 250) -> None:
    """
    Fills the given path with randomly generated 64x64 images.
    This can be used for quick testing of the workflow of the model.
    Does NOT pack the files into a tar, but does preprocess them.

    Images are written to ``<output_path>/<split>/<class>/rand_image_<i>.jpeg``
    for the splits train/val/test and the classes class1/class2.

    Args:
        output_path: root directory to populate; missing directories are
            created and existing ones are reused
        num_images: number of images written into each split/class directory
    """
    # Class-major order (class1 train/val/test, then class2) is kept so the
    # sequence of random draws per directory matches the previous layout.
    for class_name in ("class1", "class2"):
        for split in ("train", "val", "test"):
            class_dir = os.path.join(output_path, split, class_name)
            os.makedirs(class_dir, exist_ok=True)

            for i in range(num_images):
                pixels = numpy.random.rand(64, 64, 3) * 255
                im = Image.fromarray(pixels.astype("uint8")).convert("RGB")
                im.save(os.path.join(class_dir, f"rand_image_{i}.jpeg"))

    process_images(output_path)


def process_images(img_root: str) -> None:
    """
    Rewrites every image file under img_root in place after passing it through
    a ToTensor -> Normalize((0.5,), (0.5,)) -> ToPILImage pipeline.

    Args:
        img_root: directory walked recursively; files that is_image_file does
            not accept are skipped
    """
    print("transforming images...")
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
            transforms.ToPILImage(),
        ]
    )

    image_files = [
        path
        for root, _, fnames in os.walk(img_root)
        for path in (os.path.join(root, fname) for fname in fnames)
        if is_image_file(path)
    ]
    for path in tqdm(image_files, miniters=len(image_files) // 2000):
        # Close the source handle before overwriting the same path.
        with Image.open(path) as source:
            transformed = transform(source)
        transformed.save(path)


# sphinx_gallery_thumbnail_path = '_static/img/gallery-lib.png'

脚本的总运行时间:(0 分钟 0.000 秒)

由 Sphinx-Gallery 生成的画廊

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源