快捷方式

Transforms v2:端到端目标检测/分割示例

注意

Colab 上尝试或 转到结尾 下载完整的示例代码。

目标检测和分割任务得到原生支持:torchvision.transforms.v2 允许联合变换图像、视频、边界框和蒙版。

此示例展示了使用 torchvision.datasetstorchvision.modelstorchvision.transforms.v2 中的 Torchvision 实用程序进行的端到端实例分割训练案例。此处介绍的所有内容都可以在目标检测或语义分割任务中以类似的方式应用。

import pathlib

import torch
import torch.utils.data

from torchvision import models, datasets, tv_tensors
from torchvision.transforms import v2

torch.manual_seed(0)

# This loads fake data for illustration purposes of this example. In practice, you'll have
# to replace this with the proper data.
# If you're trying to run that on Colab, you can download the assets and the
# helpers from https://github.com/pytorch/vision/tree/main/gallery/
ROOT = pathlib.Path("../assets") / "coco"
IMAGES_PATH = str(ROOT / "images")
ANNOTATIONS_PATH = str(ROOT / "instances.json")
from helpers import plot

数据集准备

我们首先加载 CocoDetection 数据集,看看它当前返回的内容。

dataset = datasets.CocoDetection(IMAGES_PATH, ANNOTATIONS_PATH)

sample = dataset[0]
img, target = sample
print(f"{type(img) = }\n{type(target) = }\n{type(target[0]) = }\n{target[0].keys() = }")
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
type(img) = <class 'PIL.Image.Image'>
type(target) = <class 'list'>
type(target[0]) = <class 'dict'>
target[0].keys() = dict_keys(['segmentation', 'iscrowd', 'image_id', 'bbox', 'category_id', 'id'])

Torchvision 数据集保留了数据集作者的预期数据结构和类型。因此,默认情况下,输出结构可能并不总是与模型或变换兼容。

为了克服这个问题,我们可以使用 wrap_dataset_for_transforms_v2() 函数。对于 CocoDetection,这会将目标结构更改为列表的单个字典

dataset = datasets.wrap_dataset_for_transforms_v2(dataset, target_keys=("boxes", "labels", "masks"))

sample = dataset[0]
img, target = sample
print(f"{type(img) = }\n{type(target) = }\n{target.keys() = }")
print(f"{type(target['boxes']) = }\n{type(target['labels']) = }\n{type(target['masks']) = }")
type(img) = <class 'PIL.Image.Image'>
type(target) = <class 'dict'>
target.keys() = dict_keys(['boxes', 'masks', 'labels'])
type(target['boxes']) = <class 'torchvision.tv_tensors._bounding_boxes.BoundingBoxes'>
type(target['labels']) = <class 'torch.Tensor'>
type(target['masks']) = <class 'torchvision.tv_tensors._mask.Mask'>

我们使用 target_keys 参数指定我们感兴趣的输出类型。我们的数据集现在返回一个目标,该目标是一个字典,其值为 TVTensors(都是 torch.Tensor 子类)。我们从之前的输出中删除了所有不必要的键,但是如果您需要任何原始键(例如“image_id”,您仍然可以请求它)。

注意

如果您只想进行检测,则不需要也不应该在 target_keys 中传递“masks”:如果样本中存在蒙版,它们将被变换,从而不必要地减慢您的变换速度。

作为基准,让我们看一下没有变换的样本

plot([dataset[0], dataset[1]])
plot transforms e2e

变换

现在让我们定义我们的预处理变换。所有变换都知道如何处理图像、边界框和蒙版(如果相关)。

变换通常作为数据集的 transforms 参数传递,以便它们可以利用 torch.utils.data.DataLoader 的多处理功能。

transforms = v2.Compose(
    [
        v2.ToImage(),
        v2.RandomPhotometricDistort(p=1),
        v2.RandomZoomOut(fill={tv_tensors.Image: (123, 117, 104), "others": 0}),
        v2.RandomIoUCrop(),
        v2.RandomHorizontalFlip(p=1),
        v2.SanitizeBoundingBoxes(),
        v2.ToDtype(torch.float32, scale=True),
    ]
)

dataset = datasets.CocoDetection(IMAGES_PATH, ANNOTATIONS_PATH, transforms=transforms)
dataset = datasets.wrap_dataset_for_transforms_v2(dataset, target_keys=["boxes", "labels", "masks"])
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!

这里有一些值得注意的地方

  • 我们将 PIL 图像转换为 Image 对象。这不是严格必要的,但依赖于张量(这里:张量子类)通常会 更快

  • 我们调用 SanitizeBoundingBoxes 来确保我们删除退化的边界框,以及它们对应的标签和蒙版。 SanitizeBoundingBoxes 应该至少在检测管道末尾放置一次;如果使用了 RandomIoUCrop,这一点尤其重要。

让我们看看带有增强管道的样本是什么样的

plot([dataset[0], dataset[1]])
plot transforms e2e

我们可以看到图像的颜色发生了扭曲、放大或缩小,并且翻转了。边界框和蒙版也相应地进行了变换。我们现在就可以开始训练了。

数据加载和训练循环

下面我们使用 Mask-RCNN,它是一个实例分割模型,但是我们在本教程中介绍的所有内容也适用于目标检测和语义分割任务。

data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    # We need a custom collation function here, since the object detection
    # models expect a sequence of images and target dictionaries. The default
    # collation function tries to torch.stack() the individual elements,
    # which fails in general for object detection, because the number of bounding
    # boxes varies between the images of the same batch.
    collate_fn=lambda batch: tuple(zip(*batch)),
)

model = models.get_model("maskrcnn_resnet50_fpn_v2", weights=None, weights_backbone=None).train()

for imgs, targets in data_loader:
    loss_dict = model(imgs, targets)
    # Put your training logic here

    print(f"{[img.shape for img in imgs] = }")
    print(f"{[type(target) for target in targets] = }")
    for name, loss_val in loss_dict.items():
        print(f"{name:<20}{loss_val:.3f}")
[img.shape for img in imgs] = [torch.Size([3, 512, 512]), torch.Size([3, 409, 493])]
[type(target) for target in targets] = [<class 'dict'>, <class 'dict'>]
loss_classifier     4.722
loss_box_reg        0.006
loss_mask           0.734
loss_objectness     0.691
loss_rpn_box_reg    0.036

训练参考

从那里,您可以查看 torchvision 参考,在那里您将找到我们用于训练模型的实际训练脚本。

免责声明 我们参考中的代码比您在自己的用例中需要的代码更复杂:这是因为我们支持不同的后端(PIL、张量、TVTensors)和不同的变换命名空间(v1 和 v2)。所以不要害怕简化,只保留您需要的部分。

脚本的总运行时间: (0 分钟 4.882 秒)

由 Sphinx-Gallery 生成的画廊

文档

访问 PyTorch 的综合开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源