快捷方式

Transforms v2 入门

注意

colab 上试用或 转到结尾 下载完整示例代码。

此示例说明了开始使用新的 torchvision.transforms.v2 API 所需了解的一切。我们将涵盖图像分类等简单任务,以及对象检测/分割等更高级的任务。

首先,进行一些设置

from pathlib import Path
import torch
import matplotlib.pyplot as plt
plt.rcParams["savefig.bbox"] = 'tight'

from torchvision.transforms import v2
from torchvision.io import read_image

torch.manual_seed(1)

# If you're trying to run that on collab, you can download the assets and the
# helpers from https://github.com/pytorch/vision/tree/main/gallery/
from helpers import plot
img = read_image(str(Path('../assets') / 'astronaut.jpg'))
print(f"{type(img) = }, {img.dtype = }, {img.shape = }")
type(img) = <class 'torch.Tensor'>, img.dtype = torch.uint8, img.shape = torch.Size([3, 512, 512])

基础知识

Torchvision 转换的行为类似于常规的 torch.nn.Module(实际上,大多数转换都是):实例化一个转换,传递一个输入,获得一个转换后的输出

transform = v2.RandomCrop(size=(224, 224))
out = transform(img)

plot([img, out])
plot transforms getting started

我只想进行图像分类

如果您只关心图像分类,事情就非常简单。一个基本的分类管道可能如下所示

transforms = v2.Compose([
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
out = transforms(img)

plot([img, out])
plot transforms getting started

这种转换管道通常作为 transform 参数传递给 数据集,例如 ImageNet(..., transform=transforms)

这基本上就是所有内容。从这里开始,请阅读我们的 主要文档 了解推荐的做法和约定,或探索更多 示例,例如如何使用增强转换,例如 CutMix 和 MixUp

注意

如果您已经依赖于 torchvision.transforms v1 API,我们建议您 切换到新的 v2 转换。这很容易:v2 转换与 v1 API 完全兼容,因此您只需要更改导入!

检测、分割、视频

新的 Torchvision 转换位于 torchvision.transforms.v2 命名空间中,支持超越图像分类的任务:它们还可以转换边界框、分割/检测掩码或视频。

让我们简要地看一下一个使用边界框的检测示例。

from torchvision import tv_tensors  # we'll describe this a bit later, bare with us

boxes = tv_tensors.BoundingBoxes(
    [
        [15, 10, 370, 510],
        [275, 340, 510, 510],
        [130, 345, 210, 425]
    ],
    format="XYXY", canvas_size=img.shape[-2:])

transforms = v2.Compose([
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomPhotometricDistort(p=1),
    v2.RandomHorizontalFlip(p=1),
])
out_img, out_boxes = transforms(img, boxes)
print(type(boxes), type(out_boxes))

plot([(img, boxes), (out_img, out_boxes)])
plot transforms getting started
<class 'torchvision.tv_tensors._bounding_boxes.BoundingBoxes'> <class 'torchvision.tv_tensors._bounding_boxes.BoundingBoxes'>

上面的示例重点关注对象检测。但如果我们有用于对象分割或语义分割的掩码 (torchvision.tv_tensors.Mask),或视频 (torchvision.tv_tensors.Video),我们可以以完全相同的方式将它们传递给转换。

到目前为止,您可能已经有一些问题:这些 TVTensors 是什么,我们如何使用它们,以及这些转换的预期输入/输出是什么?我们将在接下来的部分中回答这些问题。

什么是 TVTensors?

TVTensors 是 torch.Tensor 子类。可用的 TVTensors 是 ImageBoundingBoxesMaskVideo

TVTensors 看起来像普通张量,并且感觉就像普通张量一样 - 它们确实是张量。所有对普通 torch.Tensor 支持的操作,例如 .sum() 或任何 torch.* 运算符,对 TVTensor 也适用

img_dp = tv_tensors.Image(torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8))

print(f"{isinstance(img_dp, torch.Tensor) = }")
print(f"{img_dp.dtype = }, {img_dp.shape = }, {img_dp.sum() = }")
isinstance(img_dp, torch.Tensor) = True
img_dp.dtype = torch.uint8, img_dp.shape = torch.Size([3, 256, 256]), img_dp.sum() = tensor(25087958)

这些 TVTensor 类是转换的核心:为了转换给定的输入,转换首先查看对象的 **类**,然后相应地分派到适当的实现。

此时,您不需要了解有关 TVTensors 的更多信息,但想要了解更深入内容的进阶用户可以参考 TVTensors 常见问题解答

我应该传递什么作为输入?

在上面,我们看到了两个示例:一个示例中我们传递了一个单独的图像作为输入,即 out = transforms(img),另一个示例中我们传递了一个图像和边界框,即 out_img, out_boxes = transforms(img, boxes)

实际上,转换支持 **任意输入结构**。输入可以是单个图像、元组、任意嵌套的字典……几乎任何内容。相同结构将作为输出返回。在下面,我们使用相同的检测转换,但传递一个元组 (图像,目标字典) 作为输入,我们获得与输出相同的结构

target = {
    "boxes": boxes,
    "labels": torch.arange(boxes.shape[0]),
    "this_is_ignored": ("arbitrary", {"structure": "!"})
}

# Re-using the transforms and definitions from above.
out_img, out_target = transforms(img, target)

plot([(img, target["boxes"]), (out_img, out_target["boxes"])])
print(f"{out_target['this_is_ignored']}")
plot transforms getting started
('arbitrary', {'structure': '!'})

我们传递了一个元组,因此我们得到一个元组作为返回值,第二个元素是转换后的目标字典。转换实际上并不关心输入的结构;如上所述,它们只关心对象的 **类型** 并相应地转换它们。

像字符串或整数这样的 *外部* 对象只是被直接传递。这在调试时很有用,例如,您想将每个样本的路径与之关联!

注意

**免责声明** 此说明稍微高级,在第一次阅读时可以安全跳过。

纯粹的 torch.Tensor 对象通常被视为图像(对于特定于视频的转换,则视为视频)。事实上,您可能已经注意到,在上面的代码中,我们根本没有使用 Image 类,但我们的图像仍然得到了正确的转换。转换遵循以下逻辑来确定纯张量应该被视为图像(或视频),还是被忽略

  • 如果输入中存在 ImageVideoPIL.Image.Image 实例,则所有其他纯张量都将被直接传递。

  • 如果不存在 ImageVideo 实例,则只有第一个纯粹的 torch.Tensor 将被转换为图像或视频,而所有其他纯粹的张量将被直接传递。这里“第一个”指的是“深度优先遍历中的第一个”。

这就是上面检测示例中发生的情况:第一个纯粹的张量是图像,因此它被正确地转换了,所有其他纯粹的张量实例,如 labels,都被直接传递了(尽管标签仍然可以通过某些转换进行转换,例如 SanitizeBoundingBoxes!)

Transforms 和 Datasets 的互操作性

粗略地说,数据集的输出必须与转换的输入相对应。如何做到这一点取决于您是使用 torchvision 内置数据集 还是您自己的自定义数据集。

使用内置数据集

如果您只是进行图像分类,您无需执行任何操作。只需使用数据集的 transform 参数,例如 ImageNet(..., transform=transforms),您就可以开始了。

Torchvision 还支持用于目标检测或分割的数据集,例如 torchvision.datasets.CocoDetection。这些数据集早于 torchvision.transforms.v2 模块和 TVTensor 的存在,因此它们不会开箱即用地返回 TVTensor。

强制这些数据集返回 TVTensor 并使其与 v2 转换兼容的一种简单方法是使用 torchvision.datasets.wrap_dataset_for_transforms_v2() 函数

from torchvision.datasets import CocoDetection, wrap_dataset_for_transforms_v2

dataset = CocoDetection(..., transforms=my_transforms)
dataset = wrap_dataset_for_transforms_v2(dataset)
# Now the dataset returns TVTensors!

使用您自己的数据集

如果您有自定义数据集,则需要将您的对象转换为相应的 TVTensor 类。创建 TVTensor 实例非常容易,请参考 如何构建 TVTensor? 了解更多详细信息。

您可以在两个主要位置实现该转换逻辑

  • 在数据集的 __getitem__ 方法结束时,在返回样本之前(或通过对数据集进行子类化)。

  • 作为您的转换管道的第一步

无论哪种方式,逻辑都将取决于您的特定数据集。

脚本的总运行时间:(0 分钟 0.860 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取适用于初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源