如何使用 CutMix 和 MixUp¶
CutMix
和 MixUp
是流行的增强策略,可以提高分类准确性。
这些转换与 Torchvision 中的其他转换略有不同,因为它们期望将批次的样本作为输入,而不是单个图像。在本例中,我们将解释如何使用它们:在 DataLoader
之后,或作为整理函数的一部分。
import torch
from torchvision.datasets import FakeData
from torchvision.transforms import v2
NUM_CLASSES = 100
预处理管道¶
我们将使用一个简单但典型的图像分类管道
preproc = v2.Compose([
v2.PILToTensor(),
v2.RandomResizedCrop(size=(224, 224), antialias=True),
v2.RandomHorizontalFlip(p=0.5),
v2.ToDtype(torch.float32, scale=True), # to float32 in [0, 1]
v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), # typically from ImageNet
])
dataset = FakeData(size=1000, num_classes=NUM_CLASSES, transform=preproc)
img, label = dataset[0]
print(f"{type(img) = }, {img.dtype = }, {img.shape = }, {label = }")
type(img) = <class 'torch.Tensor'>, img.dtype = torch.float32, img.shape = torch.Size([3, 224, 224]), label = 67
需要注意的一点是,CutMix 和 MixUp 都不属于此预处理管道。我们稍后在定义 DataLoader 时添加它们。作为复习,如果我们不使用 CutMix 或 MixUp,DataLoader 和训练循环将如下所示
from torch.utils.data import DataLoader
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
for images, labels in dataloader:
print(f"{images.shape = }, {labels.shape = }")
print(labels.dtype)
# <rest of the training loop here>
break
images.shape = torch.Size([4, 3, 224, 224]), labels.shape = torch.Size([4])
torch.int64
在何处使用 MixUp 和 CutMix¶
在 DataLoader 之后¶
现在让我们添加 CutMix 和 MixUp。在 DataLoader 之后执行此操作的最简单方法是:DataLoader 已经为我们批处理了图像和标签,这正是这些转换期望作为输入的内容
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
cutmix = v2.CutMix(num_classes=NUM_CLASSES)
mixup = v2.MixUp(num_classes=NUM_CLASSES)
cutmix_or_mixup = v2.RandomChoice([cutmix, mixup])
for images, labels in dataloader:
print(f"Before CutMix/MixUp: {images.shape = }, {labels.shape = }")
images, labels = cutmix_or_mixup(images, labels)
print(f"After CutMix/MixUp: {images.shape = }, {labels.shape = }")
# <rest of the training loop here>
break
Before CutMix/MixUp: images.shape = torch.Size([4, 3, 224, 224]), labels.shape = torch.Size([4])
After CutMix/MixUp: images.shape = torch.Size([4, 3, 224, 224]), labels.shape = torch.Size([4, 100])
请注意标签是如何转换的:我们从形状为 (batch_size,) 的批处理标签变为形状为 (batch_size, num_classes) 的张量。转换后的标签仍可原样传递给损失函数,如 torch.nn.functional.cross_entropy()
。
作为整理函数的一部分¶
在 DataLoader 之后传递转换是使用 CutMix 和 MixUp 的最简单方法,但缺点是它没有利用 DataLoader 多处理。为此,我们可以将这些转换作为整理函数的一部分传递(请参阅 PyTorch 文档 以了解有关整理的更多信息)。
from torch.utils.data import default_collate
def collate_fn(batch):
return cutmix_or_mixup(*default_collate(batch))
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2, collate_fn=collate_fn)
for images, labels in dataloader:
print(f"{images.shape = }, {labels.shape = }")
# No need to call cutmix_or_mixup, it's already been called as part of the DataLoader!
# <rest of the training loop here>
break
images.shape = torch.Size([4, 3, 224, 224]), labels.shape = torch.Size([4, 100])
非标准输入格式¶
到目前为止,我们一直在使用典型的样本结构,其中我们将 (images, labels)
作为输入。MixUp 和 CutMix 将默认情况下与大多数常见的样本结构神奇地一起使用:元组,其中第二个参数是张量标签,或带有“label[s]”键的字典。有关更多详细信息,请查看 labels_getter
参数的文档。
如果你的样本具有不同的结构,你仍然可以通过将可调用对象传递给 labels_getter
参数来使用 CutMix 和 MixUp。例如
batch = {
"imgs": torch.rand(4, 3, 224, 224),
"target": {
"classes": torch.randint(0, NUM_CLASSES, size=(4,)),
"some_other_key": "this is going to be passed-through"
}
}
def labels_getter(batch):
return batch["target"]["classes"]
out = v2.CutMix(num_classes=NUM_CLASSES, labels_getter=labels_getter)(batch)
print(f"{out['imgs'].shape = }, {out['target']['classes'].shape = }")
out['imgs'].shape = torch.Size([4, 3, 224, 224]), out['target']['classes'].shape = torch.Size([4, 100])
脚本的总运行时间:(0 分钟 0.185 秒)