Batched data loading with tensorclasses
In this tutorial we demonstrate how tensorclasses and memory-mapped tensors can be used together to efficiently and transparently load data from disk inside a model training pipeline.
The basic idea is that we pre-load the entire dataset into a memory-mapped tensor, applying any non-random transformations before it is saved to disk. That way we not only avoid repeating computation every time we iterate over the data, we are also able to load from the memory-mapped tensor efficiently in batches, rather than sequentially from the original image files.
By combining pre-processing, loading from contiguous physical storage, and batched transformations on device, we obtain data loading roughly 10x faster than the regular torch + torchvision pipeline.
We will use the same subset of ImageNet as in this transfer learning tutorial, though we also report results from running the same code on the full ImageNet dataset.
Note
Download the data from here and unzip it. In this tutorial we assume the unzipped data is saved under the subdirectory data/.
import os
import time
from pathlib import Path
import torch
import torch.nn as nn
import tqdm
from tensordict import MemoryMappedTensor, tensorclass
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
if __name__ == "__main__":
NUM_WORKERS = int(os.environ.get("NUM_WORKERS", "4"))
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
##############################################################################
# Transforms
# ----------
# First we define train and val transforms that will be applied to train and
# val examples respectively. Note that there are random components in the
# train transform to prevent overfitting to training data over multiple
# epochs.
train_transform = transforms.Compose(
[
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
val_transform = transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
##############################################################################
# We use ``torchvision.datasets.ImageFolder`` to conveniently load and
# transform the data from disk.
data_dir = Path("data") / "hymenoptera_data/"
train_data = datasets.ImageFolder(
root=data_dir / "train", transform=train_transform
)
val_data = datasets.ImageFolder(root=data_dir / "val", transform=val_transform)
##############################################################################
# We'll also create a dataset of the raw training data that simply resizes
# the images to a common size and converts them to tensors. We'll use this
# to load the data into memory-mapped tensors. The random transformations
# need to be different each time we iterate through the data, so they
# cannot be pre-computed. We also don't scale the data yet, so that we can
# set the ``dtype`` of the memory-mapped array to ``uint8`` and save space.
train_data_raw = datasets.ImageFolder(
root=data_dir / "train",
transform=transforms.Compose(
[transforms.Resize((256, 256)), transforms.PILToTensor()]
),
)
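##############################################################################
# To see what this buys us, a back-of-the-envelope estimate for this subset
# (244 training images at 3 x 256 x 256 resolution):
#
#     uint8:   244 * 3 * 256 * 256 * 1 byte  ~  48 MB on disk
#     float32: 244 * 3 * 256 * 256 * 4 bytes ~ 192 MB on disk
#
# i.e. storing raw ``uint8`` pixels makes the memory-mapped file 4x smaller.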
##############################################################################
# Since we'll be loading our data in batches, we write a few custom transformations
# that take advantage of this, and apply the transformations in a vectorized way.
#
# First a transformation that can be used for normalization.
class InvAffine(nn.Module):
    """A custom normalization layer that computes ``(x - loc) / scale``."""

    def __init__(self, loc, scale):
        super().__init__()
        self.loc = loc
        self.scale = scale

    def forward(self, x):
        return (x - self.loc) / self.scale
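##############################################################################
# For instance, with ``loc=0`` and ``scale=255`` this layer would simply
# rescale a ``uint8``-valued batch into ``[0, 1]`` (an illustrative usage,
# not part of the pipeline)::
#
#     rescale = InvAffine(loc=torch.tensor(0.0), scale=torch.tensor(255.0))
#     rescale(255 * torch.ones(2, 3, 224, 224))  # a batch of all-ones images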
##############################################################################
# Next two transformations that can be used to randomly crop and flip the images.
class RandomHFlip(nn.Module):
    def forward(self, x: torch.Tensor):
        # draw one Bernoulli(0.5) sample per image and broadcast it over the
        # channel, height and width dimensions
        idx = (
            torch.zeros(*x.shape[:-3], 1, 1, 1, device=x.device, dtype=torch.bool)
            .bernoulli_()
            .expand_as(x)
        )
        # keep the original image where the mask is False, and the
        # horizontally flipped image where it is True
        return x.masked_fill(idx, 0.0) + x.masked_fill(~idx, 0.0).flip(-1)
class RandomCrop(nn.Module):
    def __init__(self, w, h):
        super().__init__()
        self.w = w
        self.h = h

    def forward(self, x):
        batch = x.shape[:-3]
        # sample a random top row for each image, then build the row indices
        # of the crop window, shaped to broadcast over channels and columns
        index0 = torch.randint(x.shape[-2] - self.h, (*batch, 1), device=x.device)
        index0 = index0 + torch.arange(self.h, device=x.device)
        index0 = (
            index0.unsqueeze(1)
            .unsqueeze(-1)
            .expand((*batch, 3, self.h, x.shape[-1]))
        )
        # likewise sample a random left column for each image
        index1 = torch.randint(x.shape[-1] - self.w, (*batch, 1), device=x.device)
        index1 = index1 + torch.arange(self.w, device=x.device)
        index1 = (
            index1.unsqueeze(1).unsqueeze(-2).expand((*batch, 3, self.h, self.w))
        )
        # gather the selected rows, then the selected columns, to extract a
        # (h, w) crop from every image in the batch
        return x.gather(-2, index0).gather(-1, index1)
##############################################################################
# When each batch is loaded, we will scale it, then randomly crop and flip. The random
# transformations cannot be pre-applied as they must differ each time we iterate over
# the data. The scaling could be pre-applied in principle, but by waiting until we load
# the data into RAM, we are able to set the dtype of the memory-mapped array to
# ``uint8``, a significant space saving over ``float32``.
collate_transform = nn.Sequential(
InvAffine(
loc=torch.tensor([0.485, 0.456, 0.406], device=device).view(3, 1, 1) * 255,
scale=torch.tensor([0.229, 0.224, 0.225], device=device).view(3, 1, 1)
* 255,
),
RandomCrop(224, 224),
RandomHFlip(),
)
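##############################################################################
# On a dummy ``uint8`` batch this pipeline yields normalized ``float32``
# crops of shape ``(batch, 3, 224, 224)`` (a sketch, assuming the input
# images are larger than the crop size)::
#
#     dummy = torch.randint(0, 256, (8, 3, 256, 256), dtype=torch.uint8, device=device)
#     collate_transform(dummy).shape  # torch.Size([8, 3, 224, 224])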
##############################################################################
# Representing data with a TensorClass
# ------------------------------------
# Tensorclasses are a good choice when the structure of your data is known
# a priori. They are dataclasses that expose dedicated tensor methods over
# their contents, much like a ``TensorDict``.
#
# As well as specifying the contents (in this case ``images`` and
# ``targets``) we can also encapsulate related logic as custom methods
# when defining the class. Here we add a classmethod that takes a dataset
# and creates a tensorclass containing the data by iterating over the
# dataset. We create memory-mapped tensors to hold the data so that they
# can be efficiently loaded in batches later.
@tensorclass
class ImageNetData:
images: torch.Tensor
targets: torch.Tensor
    @classmethod
    def from_dataset(cls, dataset):
        # infer the storage dtype from the first sample: ``uint8`` for the raw
        # training images, ``float32`` for the fully transformed val images
        first = dataset[0][0].squeeze()
        data = cls(
            images=MemoryMappedTensor.empty(
                (len(dataset), *first.shape),
                dtype=first.dtype,
            ),
            targets=MemoryMappedTensor.empty((len(dataset),), dtype=torch.int64),
            batch_size=[len(dataset)],
        )
# locks the tensorclass and ensures that is_memmap will return True.
data.memmap_()
batch = 64
dl = DataLoader(dataset, batch_size=batch, num_workers=NUM_WORKERS)
i = 0
pbar = tqdm.tqdm(total=len(dataset))
for image, target in dl:
_batch = image.shape[0]
pbar.update(_batch)
data[i : i + _batch] = cls(
images=image, targets=target, batch_size=[_batch]
)
i += _batch
return data
##############################################################################
# We create two tensorclasses, one for the training data and one for the
# validation data. Note that while this step can be slightly expensive, it
# allows us to save repeated computation later during training.
train_data_tc = ImageNetData.from_dataset(train_data_raw)
val_data_tc = ImageNetData.from_dataset(val_data)
##############################################################################
# DataLoaders
# -----------
#
# We can create dataloaders both from the ``torchvision``-provided
# Datasets, as well as from our memory-mapped tensorclasses.
#
# Since tensorclasses implement ``__len__`` and ``__getitem__`` (and also
# ``__getitems__``) we can use them like a map-style Dataset and create a
# ``DataLoader`` directly from them.
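#
# For instance, slicing the training tensorclass gives back a batched
# sub-tensorclass (an illustrative check)::
#
#     sub = train_data_tc[:4]
#     sub.images.shape   # torch.Size([4, 3, 256, 256])
#     sub.targets.shape  # torch.Size([4])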
#
# Since the TensorClass data will be loaded in batches, we need to specify how these
# batches should be collated. For this we write the following helper class
class Collate(nn.Module):
    def __init__(self, transform=None, device=None):
        super().__init__()
        self.transform = transform
        # the original unconditional ``torch.device(device)`` fails for the
        # default ``device=None``
        self.device = torch.device(device) if device is not None else None

    def __call__(self, x: ImageNetData):
        if self.device is not None and self.device.type == "cuda":
            # pin the batch in page-locked memory for faster host-to-device copies
            out = x.pin_memory()
        else:
            out = x
        if self.device is not None:
            # move the batch to the target device
            out = out.to(self.device)
        if self.transform:
            # apply the batched transforms (on the GPU when one is available)
            out.images = self.transform(out.images)
        return out
##############################################################################
# ``DataLoader`` has support for multiple workers loading data in parallel. The
# tensorclass dataloader will use just one worker, but load data in batches.
#
# Note that under this approach our ``collate_fn`` is essentially just an ``nn.Module``,
# making it transparent and easy to implement. But this approach also offers
# flexibility, for example, if needed we could move the collation step into the training
# loop by considering the ``Collate`` module as part of the model.
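##############################################################################
# For example, with an identity ``collate_fn`` the same module could be called
# inside the training loop instead (a sketch; ``model`` stands in for any
# network)::
#
#     collate = Collate(collate_transform, device)
#     for batch in DataLoader(train_data_tc, batch_size, collate_fn=lambda x: x):
#         logits = model(collate(batch).images)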
batch_size = 8
train_dataloader = DataLoader(
train_data,
batch_size=batch_size,
num_workers=NUM_WORKERS,
)
val_dataloader = DataLoader(
val_data,
batch_size=batch_size,
num_workers=NUM_WORKERS,
)
train_dataloader_tc = DataLoader( # noqa: TOR401
train_data_tc,
batch_size=batch_size,
collate_fn=Collate(collate_transform, device),
)
val_dataloader_tc = DataLoader( # noqa: TOR401
val_data_tc,
batch_size=batch_size,
collate_fn=Collate(device=device),
)
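##############################################################################
# Each batch yielded by these loaders is itself an ``ImageNetData`` instance
# that has already been moved to ``device`` (an illustrative check)::
#
#     batch = next(iter(val_dataloader_tc))
#     batch.images.shape   # torch.Size([8, 3, 224, 224])
#     batch.targets.dtype  # torch.int64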
##############################################################################
# We can now compare how long it takes to iterate once over the data in
# each case. The regular dataloader loads images one by one from disk,
# applies the transform sequentially and then stacks the results
# (note: we start measuring time only after the first few iterations, as
# starting up the dataloader can take some time).
total = 0
for i, (image, target) in enumerate(train_dataloader):
if i == 3:
t0 = time.time()
if i >= 3:
total += image.shape[0]
image, target = image.to(device), target.to(device)
t = time.time() - t0
print(
f"One iteration over dataloader done! Rate: {total / t:4.4f} fps, time: {t: 4.4f}s"
)
##############################################################################
# Our tensorclass-based dataloader instead loads data from the
# memory-mapped tensor in batches. We then apply the batched random
# transformations to the batched images.
total = 0
for i, batch in enumerate(train_dataloader_tc):
if i == 3:
t0 = time.time()
if i >= 3:
total += batch.numel()
image, target = batch.images, batch.targets
t = time.time() - t0
print(
f"One iteration over tensorclass dataloader done! Rate: {total / t:4.4f} fps, time: {t: 4.4f}s"
)
##############################################################################
# In the case of the validation set, we see an even bigger performance
# improvement, because there are no random transformations, so we can save
# the fully transformed data in the memory-mapped tensor, eliminating the
# need for additional transformations as we load from disk.
total = 0
for i, (image, target) in enumerate(val_dataloader):
if i == 3:
t0 = time.time()
if i >= 3:
total += image.shape[0]
image, target = image.to(device), target.to(device)
t = time.time() - t0
print(
f"One iteration over val data done! Rate: {total / t:4.4f} fps, time: {t: 4.4f}s"
)
total = 0
for i, batch in enumerate(val_dataloader_tc):
if i == 3:
t0 = time.time()
if i >= 3:
total += batch.shape[0]
image, target = batch.images.contiguous().to(
device
), batch.targets.contiguous().to(device)
t = time.time() - t0
print(
f"One iteration over tensorclass val data done! Rate: {total / t:4.4f} fps, time: {t: 4.4f}s"
)
##############################################################################
# Results from ImageNet
# ---------------------
#
# We repeated the above on full-size ImageNet data, running on an AWS EC2 instance with
# 32 cores and 1 A100 GPU. We compare against the regular ``DataLoader`` with different
# numbers of workers. We found that our single-threaded TensorClass approach
# outperformed the ``DataLoader`` even when it used a large number of workers.
#
# .. image:: /reference/generated/tutorials/media/imagenet-benchmark-time.png
# :alt: Bar chart showing runtimes of dataloaders compared with TensorClass
#
# .. image:: /reference/generated/tutorials/media/imagenet-benchmark-speed.png
# :alt: Bar chart showing collection rate of dataloaders compared with TensorClass
##############################################################################
# This shows that much of the overhead comes from I/O operations rather than
# from the transforms, and explains how the memory-mapped array helps us load
# data more efficiently. Check out the `distributed example <https://github.com/pytorch/tensordict/tree/main/benchmarks/distributed/dataloading.py>`__
# for more context about the other results from these charts.
#
# We can get even better performance with the TensorClass approach by using multiple
# workers to load batches from the memory-mapped array, though this comes with some
# added complexity. See `this example in our benchmarks
# <https://github.com/pytorch/tensordict/blob/main/benchmarks/distributed/dataloading.py>`__
# for a demonstration of how this could work.
Using device: cpu
100%|██████████| 244/244 [00:00<00:00, 752.07it/s]
100%|██████████| 153/153 [00:00<00:00, 362.46it/s]
One iteration over dataloader done! Rate: 1016.7976 fps, time: 0.2164s
One iteration over tensorclass dataloader done! Rate: 1646.2043 fps, time: 0.1336s
One iteration over val data done! Rate: 668.7678 fps, time: 0.1929s
One iteration over tensorclass val data done! Rate: 20452.2856 fps, time: 0.0063s
Total running time of the script: (0 minutes 1.558 seconds)