• 教程 >
  • 编写自定义数据集、数据加载器和转换
快捷方式

编写自定义数据集、数据加载器和转换

创建于:2017 年 6 月 10 日 | 最后更新:2024 年 1 月 19 日 | 最后验证:2024 年 11 月 05 日

作者Sasank Chilamkurthy

解决任何机器学习问题的大量工作都投入在准备数据上。PyTorch 提供了许多工具,使数据加载变得容易,并希望使您的代码更具可读性。在本教程中,我们将看到如何从一个重要的非平凡数据集中加载和预处理/增强数据。

要运行本教程,请确保已安装以下软件包

  • scikit-image:用于图像 io 和转换

  • pandas:用于更轻松的 csv 解析

import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode
<contextlib.ExitStack object at 0x7ff7f616b310>

我们将要处理的数据集是面部姿势。这意味着像这样注释面部

../_images/landmarked_face2.png

总共,为每张脸注释了 68 个不同的地标点。

注意

此处下载数据集,以便图像位于名为“data/faces/”的目录中。此数据集实际上是通过在来自 imagenet 的标记为“face”的几张图像上应用出色的 dlib 的姿势估计生成的。

数据集附带一个带有注释的 .csv 文件,如下所示

image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y
0805personali01.jpg,27,83,27,98, ... 84,134
1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312

让我们从 CSV 中提取一个图像名称及其注释,在本例中,以 person-7.jpg 的行索引号 65 为例。读取它,将图像名称存储在 img_name 中,并将其注释存储在 (L, 2) 数组 landmarks 中,其中 L 是该行中地标的数量。

landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv')

n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:]
landmarks = np.asarray(landmarks, dtype=float).reshape(-1, 2)

print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))
Image name: person-7.jpg
Landmarks shape: (68, 2)
First 4 Landmarks: [[32. 65.]
 [33. 76.]
 [34. 86.]
 [34. 97.]]

让我们编写一个简单的辅助函数来显示图像及其地标,并使用它来显示样本。

def show_landmarks(image, landmarks):
    """Show image with landmarks"""
    plt.imshow(image)
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
    plt.pause(0.001)  # pause a bit so that plots are updated

plt.figure()
show_landmarks(io.imread(os.path.join('data/faces/', img_name)),
               landmarks)
plt.show()
data loading tutorial

数据集类

torch.utils.data.Dataset 是表示数据集的抽象类。您的自定义数据集应继承 Dataset 并覆盖以下方法

  • __len__,以便 len(dataset) 返回数据集的大小。

  • __getitem__ 以支持索引,以便可以使用 dataset[i] 来获取第 \(i\) 个样本。

让我们为我们的面部地标数据集创建一个数据集类。我们将在 __init__ 中读取 csv,但将图像的读取留给 __getitem__。这是内存高效的,因为所有图像不会一次存储在内存中,而是根据需要读取。

我们数据集的样本将是一个 dict {'image': image, 'landmarks': landmarks}。我们的数据集将采用一个可选参数 transform,以便可以在样本上应用任何所需的处理。我们将在下一节中看到 transform 的用处。

class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Arguments:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        landmarks = self.landmarks_frame.iloc[idx, 1:]
        landmarks = np.array([landmarks], dtype=float).reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample

让我们实例化这个类并遍历数据样本。我们将打印前 4 个样本的大小并显示其地标。

face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                    root_dir='data/faces/')

fig = plt.figure()

for i, sample in enumerate(face_dataset):
    print(i, sample['image'].shape, sample['landmarks'].shape)

    ax = plt.subplot(1, 4, i + 1)
    plt.tight_layout()
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    show_landmarks(**sample)

    if i == 3:
        plt.show()
        break
Sample #0, Sample #1, Sample #2, Sample #3
0 (324, 215, 3) (68, 2)
1 (500, 333, 3) (68, 2)
2 (250, 258, 3) (68, 2)
3 (434, 290, 3) (68, 2)

转换

我们可以从上面看到的一个问题是样本的大小不相同。大多数神经网络都期望图像具有固定的大小。因此,我们将需要编写一些预处理代码。让我们创建三个转换

  • Rescale:缩放图像

  • RandomCrop:从图像中随机裁剪。这是数据增强。

  • ToTensor:将 numpy 图像转换为 torch 图像(我们需要交换轴)。

我们将把它们写成可调用类而不是简单函数,以便每次调用时都不需要传递转换的参数。为此,我们只需要实现 __call__ 方法,如果需要,还可以实现 __init__ 方法。然后我们可以像这样使用转换

tsfm = Transform(params)
transformed_sample = tsfm(sample)

观察下面这些转换是如何必须同时应用于图像和地标的。

class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w))

        # h and w are swapped for landmarks because for images,
        # x and y axes are axis 1 and 0 respectively
        landmarks = landmarks * [new_w / w, new_h / h]

        return {'image': img, 'landmarks': landmarks}


class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image = image[top: top + new_h,
                      left: left + new_w]

        landmarks = landmarks - [left, top]

        return {'image': image, 'landmarks': landmarks}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C x H x W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}

注意

在上面的示例中,RandomCrop 使用外部库的随机数生成器(在本例中为 Numpy 的 np.random.int)。这可能会导致 DataLoader 出现意外行为(请参阅此处)。在实践中,坚持使用 PyTorch 的随机数生成器更安全,例如使用 torch.randint 代替。

组合转换

现在,我们将转换应用于样本。

假设我们想要将图像的较短边缩放到 256,然后从中随机裁剪一个大小为 224 的正方形。即,我们想要组合 RescaleRandomCrop 转换。torchvision.transforms.Compose 是一个简单的可调用类,允许我们执行此操作。

scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256),
                               RandomCrop(224)])

# Apply each of the above transforms on sample.
fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):
    transformed_sample = tsfrm(sample)

    ax = plt.subplot(1, 3, i + 1)
    plt.tight_layout()
    ax.set_title(type(tsfrm).__name__)
    show_landmarks(**transformed_sample)

plt.show()
Rescale, RandomCrop, Compose

遍历数据集

让我们将所有这些放在一起,创建一个带有组合转换的数据集。总结一下,每次对这个数据集进行采样时

  • 从文件中即时读取图像

  • 转换应用于读取的图像

  • 由于其中一个转换是随机的,因此在采样时会增强数据

我们可以像以前一样使用 for i in range 循环遍历创建的数据集。

transformed_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                           root_dir='data/faces/',
                                           transform=transforms.Compose([
                                               Rescale(256),
                                               RandomCrop(224),
                                               ToTensor()
                                           ]))

for i, sample in enumerate(transformed_dataset):
    print(i, sample['image'].size(), sample['landmarks'].size())

    if i == 3:
        break
0 torch.Size([3, 224, 224]) torch.Size([68, 2])
1 torch.Size([3, 224, 224]) torch.Size([68, 2])
2 torch.Size([3, 224, 224]) torch.Size([68, 2])
3 torch.Size([3, 224, 224]) torch.Size([68, 2])

但是,我们通过使用简单的 for 循环来遍历数据,丢失了很多功能。特别是,我们错过了

  • 批量处理数据

  • 打乱数据

  • 使用 multiprocessing 工作进程并行加载数据。

torch.utils.data.DataLoader 是一个迭代器,它提供了所有这些功能。下面使用的参数应该很清楚。一个值得关注的参数是 collate_fn。您可以使用 collate_fn 指定需要如何批量处理样本。但是,默认的 collate 应该适用于大多数用例。

dataloader = DataLoader(transformed_dataset, batch_size=4,
                        shuffle=True, num_workers=0)


# Helper function to show a batch
def show_landmarks_batch(sample_batched):
    """Show image with landmarks for a batch of samples."""
    images_batch, landmarks_batch = \
            sample_batched['image'], sample_batched['landmarks']
    batch_size = len(images_batch)
    im_size = images_batch.size(2)
    grid_border_size = 2

    grid = utils.make_grid(images_batch)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))

    for i in range(batch_size):
        plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size + (i + 1) * grid_border_size,
                    landmarks_batch[i, :, 1].numpy() + grid_border_size,
                    s=10, marker='.', c='r')

        plt.title('Batch from dataloader')

# if you are using Windows, uncomment the next line and indent the for loop.
# you might need to go back and change ``num_workers`` to 0.

# if __name__ == '__main__':
for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched['image'].size(),
          sample_batched['landmarks'].size())

    # observe 4th batch and stop.
    if i_batch == 3:
        plt.figure()
        show_landmarks_batch(sample_batched)
        plt.axis('off')
        plt.ioff()
        plt.show()
        break
Batch from dataloader
0 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
1 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
2 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
3 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])

后记:torchvision

在本教程中,我们已经了解了如何编写和使用数据集、转换和数据加载器。torchvision 包提供了一些常用的数据集和转换。您甚至可能不必编写自定义类。torchvision 中提供的更通用的数据集之一是 ImageFolder。它假定图像按以下方式组织

root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png

其中“ants”、“bees”等是类标签。类似地,对 PIL.Image 进行操作的通用转换(如 RandomHorizontalFlipScale)也可用。您可以使用它们来编写像这样的数据加载器

import torch
from torchvision import transforms, datasets

data_transform = transforms.Compose([
        transforms.RandomSizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
                                           transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
                                             batch_size=4, shuffle=True,
                                             num_workers=4)

有关带有训练代码的示例,请参阅计算机视觉迁移学习教程

脚本的总运行时间: (0 分钟 2.385 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源