跳转到主要内容
博客

将 TorchVision 的转换扩展到目标检测、分割和视频任务

注意:此帖子的一个旧版本于2022年11月发布。鉴于即将于2023年3月与PyTorch 2.0同步发布的torchvision 0.15版本,我们已更新此帖子,提供最新信息。

TorchVision 正在扩展其 Transforms API!新功能如下:

  • 您可以将它们不仅用于图像分类,还用于目标检测、实例和语义分割以及视频分类。
  • 您可以使用新的函数式转换来转换视频、边界框和分割掩码。

API与之前的版本完全向后兼容,并保持不变以协助迁移和采用。我们现在将此新API作为Beta版本发布在`torchvision.transforms.v2`命名空间中,我们非常希望获得您的早期反馈以改进其功能。如果您有任何问题或建议,请联系我们

当前Transforms的局限性

TorchVision现有的Transforms API(即V1)仅支持单个图像。因此,它只能用于分类任务。

from torchvision import transforms
trans = transforms.Compose([
   transforms.ColorJitter(contrast=0.5),
   transforms.RandomRotation(30),
   transforms.CenterCrop(480),
])
imgs = trans(imgs)

上述方法不支持目标检测和分割。这一限制使得任何非分类的计算机视觉任务成为次要任务,因为无法使用Transforms API执行必要的增强。历史上,这使得使用TorchVision的原语训练高精度模型变得困难,因此我们的模型库与SoTA相比落后了几分。

为了规避这一限制,TorchVision在其参考脚本中提供了自定义实现,展示了如何在每个任务中执行增强。尽管这种做法使我们能够训练高精度的分类目标检测和分割模型,但它是一种不完善的方法,使得这些变换无法从TorchVision二进制文件中导入。

新的Transforms API

Transforms V2 API支持视频、边界框和分割掩码,这意味着它为许多计算机视觉任务提供了原生支持。新解决方案是即插即用的替代方案。

import torchvision.transforms.v2 as transforms

# Exactly the same interface as V1:
trans = transforms.Compose([
    transforms.ColorJitter(contrast=0.5),
    transforms.RandomRotation(30),
    transforms.CenterCrop(480),
])
imgs, bboxes, labels = trans(imgs, bboxes, labels)

新的Transform类可以接收任意数量的输入,而不强制特定顺序或结构。

# Already supported:
trans(imgs)  # Image Classification
trans(videos)  # Video Tasks
trans(imgs, bboxes, labels)  # Object Detection
trans(imgs, bboxes, masks, labels)  # Instance Segmentation
trans(imgs, masks)  # Semantic Segmentation
trans({"image": imgs, "box": bboxes, "tag": labels})  # Arbitrary Structure

# Future support:
trans(imgs, bboxes, labels, keypoints)  # Keypoint Detection
trans(stereo_images, disparities, masks)  # Depth Perception
trans(image1, image2, optical_flows, masks)  # Optical Flow
trans(imgs_or_videos, labels)  # MixUp/CutMix-style Transforms

Transform类确保对所有输入应用相同的随机变换,以确保一致的结果。

功能API已更新,以支持所有输入所需的所有信号处理核(大小调整、裁剪、仿射变换、填充等)。

from torchvision.transforms.v2 import functional as F


# High-level dispatcher, accepts any supported input type, fully BC
F.resize(inpt, size=[224, 224])
# Image tensor kernel
F.resize_image_tensor(img_tensor, size=[224, 224], antialias=True) 
# PIL image kernel
F.resize_image_pil(img_pil, size=[224, 224], interpolation=BILINEAR)
# Video kernel
F.resize_video(video, size=[224, 224], antialias=True) 
# Mask kernel
F.resize_mask(mask, size=[224, 224])
# Bounding box kernel
F.resize_bounding_box(bbox, size=[224, 224], spatial_size=[256, 256])

在底层,API使用Tensor子类来封装输入,附加有用的元数据并调度到正确的内核。为了使您的数据与这些新的变换兼容,您可以使用提供的dataset包装器,它应该适用于大多数torchvision内置数据集,或者您可以手动将数据包装到Datapoints中。

from torchvision.datasets import wrap_dataset_for_transforms_v2
ds = CocoDetection(..., transforms=v2_transforms)
ds = wrap_dataset_for_transforms_v2(ds) # data is now compatible with transforms v2!

# Or wrap your data manually using the lower-level Datapoint classes:
from torchvision import datapoints

imgs = datapoints.Image(images)
vids = datapoints.Video(videos)
masks = datapoints.Mask(target["masks“])
bboxes = datapoints.BoundingBox(target["boxes“], format=”XYXY”, spatial_size=imgs.shape)

除了新的API之外,我们现在还为SoTA研究中使用的几种数据增强提供了可导入的实现,例如大规模抖动(Large Scale Jitter)自动增强(AutoAugmentation)方法和多种新的几何、颜色和类型转换变换。

该API继续支持PIL和Tensor后端用于图像、单输入或批量输入,并在功能和类API上保持JIT-scriptability。新API已验证,可实现与先前实现相同的精度。

一个端到端的例子

这是一个使用以下图像的新API示例。它适用于PIL图像和Tensors。有关更多示例和教程,请查看我们的画廊!

from torchvision import io, utils
from torchvision import datapoints
from torchvision.transforms import v2 as T
from torchvision.transforms.v2 import functional as F

# Defining and wrapping input to appropriate Tensor Subclasses
path = "COCO_val2014_000000418825.jpg"
img = datapoints.Image(io.read_image(path))
# img = PIL.Image.open(path)
bboxes = datapoints.BoundingBox(
    [[2, 0, 206, 253], [396, 92, 479, 241], [328, 253, 417, 332],
     [148, 68, 256, 182], [93, 158, 170, 260], [432, 0, 438, 26],
     [422, 0, 480, 25], [419, 39, 424, 52], [448, 37, 456, 62],
     [435, 43, 437, 50], [461, 36, 469, 63], [461, 75, 469, 94],
     [469, 36, 480, 64], [440, 37, 446, 56], [398, 233, 480, 304],
     [452, 39, 463, 63], [424, 38, 429, 50]],
    format=datapoints.BoundingBoxFormat.XYXY,
    spatial_size=F.get_spatial_size(img),
)
labels = [59, 58, 50, 64, 76, 74, 74, 74, 74, 74, 74, 74, 74, 74, 50, 74, 74]
# Defining and applying Transforms V2
trans = T.Compose(
    [
        T.ColorJitter(contrast=0.5),
        T.RandomRotation(30),
        T.CenterCrop(480),
    ]
)
img, bboxes, labels = trans(img, bboxes, labels)
# Visualizing results
viz = utils.draw_bounding_boxes(F.to_image_tensor(img), boxes=bboxes)
F.to_pil_image(viz).show()

开发里程碑和未来工作

这是我们目前在开发中的进展:

  • 设计API
  • 为转换视频、边界框、掩码和标签编写内核
  • 在新API上重写所有现有的Transform类(稳定版+参考版)
    • 图像分类
    • 视频分类
    • 目标检测
    • 实例分割
    • 语义分割
  • 验证新API在所有支持的任务和后端上的准确性
  • 速度基准测试和性能优化(进行中 – 计划于12月进行)
  • 从原型毕业(计划于第一季度进行)
  • 增加对深度感知、关键点检测、光流等功能的支持(未来)
  • 增加对MixUp和CutMix等批量变换的平滑支持

我们非常希望获得您的反馈以改进其功能。如果您有任何问题或建议,请联系我们。