Trainer Datasets Example
This is the dataset used for the training example. It uses the PyTorch Lightning library.
import os.path
import tarfile
from typing import Callable, Optional
import fsspec
import numpy
import pytorch_lightning as pl
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.datasets.folder import is_image_file
from tqdm import tqdm
This uses torchvision to define a dataset that we will later use in our PyTorch Lightning data module.
class ImageFolderSamplesDataset(datasets.ImageFolder):
"""
ImageFolderSamplesDataset is a wrapper around ImageFolder that allows you to
limit the number of samples.
"""
def __init__(
self,
root: str,
transform: Optional[Callable[..., object]] = None,
num_samples: Optional[int] = None,
**kwargs: object,
) -> None:
"""
Args:
num_samples: optional. limits the size of the dataset
"""
super().__init__(root, transform=transform)
self.num_samples = num_samples
def __len__(self) -> int:
if self.num_samples is not None:
return self.num_samples
return super().__len__()
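As a quick sanity check, the dataset can be instantiated directly on any ImageFolder-style directory. This is a minimal sketch; the path below is a placeholder and is assumed to already contain one subdirectory of images per class.

# Hypothetical usage sketch: "/tmp/tiny-imagenet/train" is a placeholder path
# assumed to follow the ImageFolder layout (one subfolder per class).
ds = ImageFolderSamplesDataset(
    root="/tmp/tiny-imagenet/train",
    transform=transforms.ToTensor(),
    num_samples=100,  # cap the dataset at 100 samples for quick tests
)
print(len(ds))      # reports 100 because of the num_samples cap
img, label = ds[0]  # a (C, H, W) float tensor and its class index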
For ease of use, we define a Lightning DataModule so that we can reuse it across our trainer and any other components that need to load data.
# pyre-fixme[13]: Attribute `test_ds` is never initialized.
# pyre-fixme[13]: Attribute `train_ds` is never initialized.
# pyre-fixme[13]: Attribute `val_ds` is never initialized.
class TinyImageNetDataModule(pl.LightningDataModule):
"""
TinyImageNetDataModule is a pytorch LightningDataModule for the tiny
imagenet dataset.
"""
train_ds: ImageFolderSamplesDataset
val_ds: ImageFolderSamplesDataset
test_ds: ImageFolderSamplesDataset
def __init__(
self, data_dir: str, batch_size: int = 16, num_samples: Optional[int] = None
) -> None:
super().__init__()
self.data_dir = data_dir
self.batch_size = batch_size
self.num_samples = num_samples
def setup(self, stage: Optional[str] = None) -> None:
# Setup data loader and transforms
img_transform = transforms.Compose(
[
transforms.ToTensor(),
]
)
self.train_ds = ImageFolderSamplesDataset(
root=os.path.join(self.data_dir, "train"),
transform=img_transform,
num_samples=self.num_samples,
)
self.val_ds = ImageFolderSamplesDataset(
root=os.path.join(self.data_dir, "val"),
transform=img_transform,
num_samples=self.num_samples,
)
self.test_ds = ImageFolderSamplesDataset(
root=os.path.join(self.data_dir, "test"),
transform=img_transform,
num_samples=self.num_samples,
)
def train_dataloader(self) -> DataLoader:
return DataLoader(self.train_ds, batch_size=self.batch_size)
def val_dataloader(self) -> DataLoader:
return DataLoader(self.val_ds, batch_size=self.batch_size)
def test_dataloader(self) -> DataLoader:
return DataLoader(self.test_ds, batch_size=self.batch_size)
def teardown(self, stage: Optional[str] = None) -> None:
pass
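The data module can also be driven by hand for a quick smoke test. This is a sketch under the assumption that the placeholder directory below contains train/, val/ and test/ subfolders of same-sized images.

# Hypothetical usage sketch: "/tmp/tiny-imagenet" is a placeholder directory.
dm = TinyImageNetDataModule(
    data_dir="/tmp/tiny-imagenet",
    batch_size=16,
    num_samples=64,  # keep the run small for smoke testing
)
dm.setup("fit")
images, labels = next(iter(dm.train_dataloader()))
print(images.shape, labels.shape)  # e.g. torch.Size([16, 3, 64, 64]) torch.Size([16])

# With a LightningModule named `model` (not defined here), the same data module
# plugs straight into a trainer:
# trainer = pl.Trainer(max_epochs=1)
# trainer.fit(model, datamodule=dm)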
To pass data between components, we use fsspec, which lets us read from and write to cloud or local file storage.
def download_data(remote_path: str, tmpdir: str) -> str:
"""
download_data downloads the training data from the specified remote path via
fsspec and places it in the tmpdir unextracted.
"""
if os.path.isdir(remote_path):
print("dataset path is a directory, using as is")
return remote_path
tar_path = os.path.join(tmpdir, "data.tar.gz")
print(f"downloading dataset from {remote_path} to {tar_path}...")
fs, _, rpaths = fsspec.get_fs_token_paths(remote_path)
assert len(rpaths) == 1, "must have single path"
fs.get(rpaths[0], tar_path)
data_path = os.path.join(tmpdir, "data")
print(f"extracting {tar_path} to {data_path}...")
with tarfile.open(tar_path, mode="r") as f:
f.extractall(data_path)
return data_path
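A sketch of how this helper might be called: both the remote URL and the bucket name below are placeholders, and any fsspec-supported scheme (s3://, gs://, plain local paths, ...) that points at a .tar.gz archive would work the same way, provided the matching fsspec backend is installed.

# Hypothetical usage sketch: the remote path is a placeholder archive location.
import tempfile

tmpdir = tempfile.mkdtemp()
data_path = download_data("s3://my-bucket/tiny-imagenet.tar.gz", tmpdir)
dm = TinyImageNetDataModule(data_dir=data_path, num_samples=100)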
def create_random_data(output_path: str, num_images: int = 250) -> None:
"""
Fills the given path with randomly generated 64x64 images.
This can be used for quick testing of the workflow of the model.
Does NOT pack the files into a tar, but does preprocess them.
"""
train_path = os.path.join(output_path, "train")
class1_train_path = os.path.join(train_path, "class1")
class2_train_path = os.path.join(train_path, "class2")
val_path = os.path.join(output_path, "val")
class1_val_path = os.path.join(val_path, "class1")
class2_val_path = os.path.join(val_path, "class2")
test_path = os.path.join(output_path, "test")
class1_test_path = os.path.join(test_path, "class1")
class2_test_path = os.path.join(test_path, "class2")
paths = [
class1_train_path,
class1_val_path,
class1_test_path,
class2_train_path,
class2_val_path,
class2_test_path,
]
for path in paths:
try:
os.makedirs(path)
except FileExistsError:
pass
for i in range(num_images):
pixels = numpy.random.rand(64, 64, 3) * 255
im = Image.fromarray(pixels.astype("uint8")).convert("RGB")
im.save(os.path.join(path, f"rand_image_{i}.jpeg"))
process_images(output_path)
def process_images(img_root: str) -> None:
print("transforming images...")
transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)),
transforms.ToPILImage(),
]
)
image_files = []
for root, _, fnames in os.walk(img_root):
for fname in fnames:
path = os.path.join(root, fname)
if not is_image_file(path):
continue
image_files.append(path)
for path in tqdm(image_files, miniters=int(len(image_files) / 2000)):
f = Image.open(path)
f = transform(f)
f.save(path)
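For an end-to-end smoke test that downloads nothing, the random-data generator can stand in for the real dataset. This is a minimal sketch; the output directory is a placeholder.

# Hypothetical end-to-end smoke test: generate a tiny random dataset and feed
# it through the data module. "/tmp/random-tiny-imagenet" is a placeholder.
create_random_data("/tmp/random-tiny-imagenet", num_images=8)
dm = TinyImageNetDataModule(data_dir="/tmp/random-tiny-imagenet", batch_size=4)
dm.setup()
batch, labels = next(iter(dm.train_dataloader()))
print(batch.shape)  # torch.Size([4, 3, 64, 64])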
# sphinx_gallery_thumbnail_path = '_static/img/gallery-lib.png'
Total running time of the script: (0 minutes 0.000 seconds)