• 教程 >
  • 编写自定义数据集、DataLoader 和 Transform
快捷方式

编写自定义数据集、DataLoader 和 Transform

作者: Sasank Chilamkurthy

解决任何机器学习问题都需要投入大量精力来准备数据。PyTorch 提供了许多工具来简化数据加载,并希望使您的代码更具可读性。在本教程中,我们将了解如何从非平凡的数据集中加载和预处理/增强数据。

要运行本教程,请确保已安装以下软件包

  • scikit-image:用于图像 I/O 和变换

  • pandas:用于更轻松的 csv 解析

import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode
<contextlib.ExitStack object at 0x7f41d05c3130>

我们将要处理的数据集是人脸姿态数据集。这意味着人脸会被这样注释

../_images/landmarked_face2.png

总共为每张人脸注释了 68 个不同的地标点。

注意

这里 下载数据集,以便图像位于名为“data/faces/”的目录中。此数据集实际上是通过对标记为“人脸”的 ImageNet 中的一些图像应用优秀的 dlib 的姿态估计 生成的。

数据集附带一个 .csv 文件,其中包含如下所示的注释

image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y
0805personali01.jpg,27,83,27,98, ... 84,134
1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312

让我们以CSV中的一张图像名称及其标注为例,例如第65行的person-7.jpg。读取它,将图像名称存储在img_name中,并将它的标注存储在一个(L, 2)数组landmarks中,其中L是该行中地标的数量。

landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv')

n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:]
landmarks = np.asarray(landmarks, dtype=float).reshape(-1, 2)

print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))
Image name: person-7.jpg
Landmarks shape: (68, 2)
First 4 Landmarks: [[32. 65.]
 [33. 76.]
 [34. 86.]
 [34. 97.]]

让我们编写一个简单的辅助函数来显示图像及其地标,并用它来显示一个示例。

def show_landmarks(image, landmarks):
    """Show image with landmarks"""
    plt.imshow(image)
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
    plt.pause(0.001)  # pause a bit so that plots are updated

plt.figure()
show_landmarks(io.imread(os.path.join('data/faces/', img_name)),
               landmarks)
plt.show()
data loading tutorial

数据集类

torch.utils.data.Dataset是一个表示数据集的抽象类。您的自定义数据集应该继承Dataset并覆盖以下方法

  • __len__,以便len(dataset)返回数据集的大小。

  • __getitem__,以支持索引,以便可以使用dataset[i]获取第\(i\)个样本。

让我们为我们的面部地标数据集创建一个数据集类。我们将在__init__中读取csv,但将图像读取留给__getitem__。这是内存高效的,因为并非所有图像都同时存储在内存中,而是根据需要读取。

我们的数据集样本将是一个字典{'image': image, 'landmarks': landmarks}。我们的数据集将接受一个可选参数transform,以便可以在样本上应用任何所需的处理。我们将在下一节中看到transform的有用性。

class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Arguments:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        landmarks = self.landmarks_frame.iloc[idx, 1:]
        landmarks = np.array([landmarks], dtype=float).reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample

让我们实例化这个类并遍历数据样本。我们将打印前4个样本的大小并显示它们的地标。

face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                    root_dir='data/faces/')

fig = plt.figure()

for i, sample in enumerate(face_dataset):
    print(i, sample['image'].shape, sample['landmarks'].shape)

    ax = plt.subplot(1, 4, i + 1)
    plt.tight_layout()
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    show_landmarks(**sample)

    if i == 3:
        plt.show()
        break
Sample #0, Sample #1, Sample #2, Sample #3
0 (324, 215, 3) (68, 2)
1 (500, 333, 3) (68, 2)
2 (250, 258, 3) (68, 2)
3 (434, 290, 3) (68, 2)

变换

从上面我们可以看到一个问题,那就是样本的大小不一致。大多数神经网络期望图像具有固定的大小。因此,我们需要编写一些预处理代码。让我们创建三个变换

  • Rescale:缩放图像

  • RandomCrop:随机裁剪图像。这是数据增强。

  • ToTensor:将numpy图像转换为torch图像(我们需要交换轴)。

我们将把它们写成可调用类而不是简单的函数,这样每次调用变换时就不需要传递变换的参数。为此,我们只需要实现__call__方法,如果需要,还需要实现__init__方法。然后我们可以像这样使用变换

tsfm = Transform(params)
transformed_sample = tsfm(sample)

请注意,这些变换必须同时应用于图像和地标。

class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w))

        # h and w are swapped for landmarks because for images,
        # x and y axes are axis 1 and 0 respectively
        landmarks = landmarks * [new_w / w, new_h / h]

        return {'image': img, 'landmarks': landmarks}


class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image = image[top: top + new_h,
                      left: left + new_w]

        landmarks = landmarks - [left, top]

        return {'image': image, 'landmarks': landmarks}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C x H x W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}

注意

在上面的例子中,RandomCrop使用了外部库的随机数生成器(在本例中为Numpy的np.random.int)。这可能导致DataLoader出现意外行为(参见此处)。在实践中,最好坚持使用PyTorch的随机数生成器,例如使用torch.randint代替。

组合变换

现在,我们将变换应用于样本。

假设我们希望将图像的较短边缩放至256,然后从中随机裁剪一个大小为224的正方形。即,我们希望组合RescaleRandomCrop变换。torchvision.transforms.Compose是一个简单的可调用类,它允许我们做到这一点。

scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256),
                               RandomCrop(224)])

# Apply each of the above transforms on sample.
fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):
    transformed_sample = tsfrm(sample)

    ax = plt.subplot(1, 3, i + 1)
    plt.tight_layout()
    ax.set_title(type(tsfrm).__name__)
    show_landmarks(**transformed_sample)

plt.show()
Rescale, RandomCrop, Compose

遍历数据集

让我们把这些都放在一起,创建一个具有组合变换的数据集。概括地说,每次采样此数据集时

  • 从文件中动态读取图像

  • 将变换应用于读取的图像

  • 由于其中一个变换是随机的,因此在采样时会增强数据

我们可以像以前一样,使用for i in range循环遍历创建的数据集。

transformed_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                           root_dir='data/faces/',
                                           transform=transforms.Compose([
                                               Rescale(256),
                                               RandomCrop(224),
                                               ToTensor()
                                           ]))

for i, sample in enumerate(transformed_dataset):
    print(i, sample['image'].size(), sample['landmarks'].size())

    if i == 3:
        break
0 torch.Size([3, 224, 224]) torch.Size([68, 2])
1 torch.Size([3, 224, 224]) torch.Size([68, 2])
2 torch.Size([3, 224, 224]) torch.Size([68, 2])
3 torch.Size([3, 224, 224]) torch.Size([68, 2])

但是,通过使用简单的for循环遍历数据,我们损失了许多功能。特别是,我们错过了

  • 批量处理数据

  • 随机打乱数据

  • 使用multiprocessing工作程序并行加载数据。

torch.utils.data.DataLoader是一个迭代器,它提供了所有这些功能。下面使用的参数应该很清楚。一个值得注意的参数是collate_fn。您可以使用collate_fn指定样本如何精确地进行批处理。但是,对于大多数用例,默认的collate应该可以正常工作。

dataloader = DataLoader(transformed_dataset, batch_size=4,
                        shuffle=True, num_workers=0)


# Helper function to show a batch
def show_landmarks_batch(sample_batched):
    """Show image with landmarks for a batch of samples."""
    images_batch, landmarks_batch = \
            sample_batched['image'], sample_batched['landmarks']
    batch_size = len(images_batch)
    im_size = images_batch.size(2)
    grid_border_size = 2

    grid = utils.make_grid(images_batch)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))

    for i in range(batch_size):
        plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size + (i + 1) * grid_border_size,
                    landmarks_batch[i, :, 1].numpy() + grid_border_size,
                    s=10, marker='.', c='r')

        plt.title('Batch from dataloader')

# if you are using Windows, uncomment the next line and indent the for loop.
# you might need to go back and change ``num_workers`` to 0.

# if __name__ == '__main__':
for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched['image'].size(),
          sample_batched['landmarks'].size())

    # observe 4th batch and stop.
    if i_batch == 3:
        plt.figure()
        show_landmarks_batch(sample_batched)
        plt.axis('off')
        plt.ioff()
        plt.show()
        break
Batch from dataloader
0 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
1 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
2 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
3 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])

后记:torchvision

在本教程中,我们了解了如何编写和使用数据集、变换和数据加载器。torchvision包提供了一些常见的数据集和变换。您甚至可能不必编写自定义类。torchvision中提供的一个更通用的数据集是ImageFolder。它假设图像按以下方式组织

root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png

其中“ants”、“bees”等是类标签。类似地,对PIL.Image进行操作的通用变换,如RandomHorizontalFlipScale,也可用。您可以使用它们来编写如下所示的数据加载器

import torch
from torchvision import transforms, datasets

data_transform = transforms.Compose([
        transforms.RandomSizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
                                           transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
                                             batch_size=4, shuffle=True,
                                             num_workers=4)

有关包含训练代码的示例,请参见计算机视觉迁移学习教程

脚本的总运行时间:(0分钟2.656秒)

由Sphinx-Gallery生成的图库

文档

访问PyTorch的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源