transforms v2 入门¶
此示例说明了开始使用新的 torchvision.transforms.v2
API 所需了解的一切。我们将介绍简单的任务,如图像分类,以及更高级的任务,如对象检测/分割。
首先,进行一些设置
from pathlib import Path
import torch
import matplotlib.pyplot as plt
plt.rcParams["savefig.bbox"] = 'tight'
from torchvision.transforms import v2
from torchvision.io import decode_image
torch.manual_seed(1)
# If you're trying to run that on Colab, you can download the assets and the
# helpers from https://github.com/pytorch/vision/tree/main/gallery/
from helpers import plot
img = decode_image(str(Path('../assets') / 'astronaut.jpg'))
print(f"{type(img) = }, {img.dtype = }, {img.shape = }")
type(img) = <class 'torch.Tensor'>, img.dtype = torch.uint8, img.shape = torch.Size([3, 512, 512])
基础知识¶
Torchvision 转换的行为类似于常规的 torch.nn.Module
(实际上,它们大多数都是):实例化一个转换,传递一个输入,获得一个转换后的输出
transform = v2.RandomCrop(size=(224, 224))
out = transform(img)
plot([img, out])
data:image/s3,"s3://crabby-images/c8db8/c8db8e4c3c780421f1735627c03e158b8b9ded31" alt="plot transforms getting started"
我只想做图像分类¶
如果您只关心图像分类,事情就非常简单。一个基本的分类管道可能如下所示
transforms = v2.Compose([
v2.RandomResizedCrop(size=(224, 224), antialias=True),
v2.RandomHorizontalFlip(p=0.5),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
out = transforms(img)
plot([img, out])
data:image/s3,"s3://crabby-images/92f90/92f90f1efb0b4459e0d1575d83439e47444119f8" alt="plot transforms getting started"
这样的转换管道通常作为 transform
参数传递给 数据集,例如 ImageNet(..., transform=transforms)
。
基本上就是这样。从那里,通读我们的 主要文档 以了解有关推荐实践和约定的更多信息,或浏览更多 示例,例如如何使用增强转换,如 CutMix 和 MixUp。
注意
如果您已经依赖于 torchvision.transforms
v1 API,我们建议 切换到新的 v2 转换。这非常容易:v2 转换与 v1 API 完全兼容,因此您只需要更改导入即可!
检测、分割、视频¶
torchvision.transforms.v2
命名空间中的新 Torchvision 转换支持超出图像分类的任务:它们还可以转换边界框、分割/检测掩码或视频。
让我们简要地看一下带有边界框的检测示例。
from torchvision import tv_tensors # we'll describe this a bit later, bare with us
boxes = tv_tensors.BoundingBoxes(
[
[15, 10, 370, 510],
[275, 340, 510, 510],
[130, 345, 210, 425]
],
format="XYXY", canvas_size=img.shape[-2:])
transforms = v2.Compose([
v2.RandomResizedCrop(size=(224, 224), antialias=True),
v2.RandomPhotometricDistort(p=1),
v2.RandomHorizontalFlip(p=1),
])
out_img, out_boxes = transforms(img, boxes)
print(type(boxes), type(out_boxes))
plot([(img, boxes), (out_img, out_boxes)])
data:image/s3,"s3://crabby-images/cef06/cef06f6efd95f87d13a20e5dcc3d1f4ea3dbdc1b" alt="plot transforms getting started"
<class 'torchvision.tv_tensors._bounding_boxes.BoundingBoxes'> <class 'torchvision.tv_tensors._bounding_boxes.BoundingBoxes'>
上面的示例侧重于对象检测。但是,如果我们有用于对象分割或语义分割的掩码(torchvision.tv_tensors.Mask
)或视频(torchvision.tv_tensors.Video
),我们可以以完全相同的方式将它们传递给转换。
到现在,您可能有一些问题:什么是 TVTensor?我们如何使用它们?这些转换的预期输入/输出是什么?我们将在下一节中回答这些问题。
什么是 TVTensor?¶
TVTensor 是 torch.Tensor
子类。可用的 TVTensor 有 Image
、BoundingBoxes
、Mask
和 Video
。
TVTensor 的外观和感觉就像常规张量 - 它们就是张量。普通 torch.Tensor
上支持的一切,如 .sum()
或任何 torch.*
运算符,也适用于 TVTensor
img_dp = tv_tensors.Image(torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8))
print(f"{isinstance(img_dp, torch.Tensor) = }")
print(f"{img_dp.dtype = }, {img_dp.shape = }, {img_dp.sum() = }")
isinstance(img_dp, torch.Tensor) = True
img_dp.dtype = torch.uint8, img_dp.shape = torch.Size([3, 256, 256]), img_dp.sum() = tensor(25087958)
这些 TVTensor 类是转换的核心:为了转换给定的输入,转换首先查看对象的类,并相应地调度到适当的实现。
您现在无需了解有关 TVTensor 的更多信息,但想要了解更多信息的高级用户可以参考 TVTensor 常见问题解答。
我应该传递什么作为输入?¶
上面,我们看到了两个示例:一个是我们传递单个图像作为输入,即 out = transforms(img)
,另一个是我们同时传递图像和边界框,即 out_img, out_boxes = transforms(img, boxes)
。
实际上,转换支持任意输入结构。输入可以是单个图像、元组、任意嵌套的字典……几乎任何东西。相同的结构将作为输出返回。下面,我们使用相同的检测转换,但传递一个元组 (图像, target_dict) 作为输入,我们得到与输出相同的结构
target = {
"boxes": boxes,
"labels": torch.arange(boxes.shape[0]),
"this_is_ignored": ("arbitrary", {"structure": "!"})
}
# Re-using the transforms and definitions from above.
out_img, out_target = transforms(img, target)
plot([(img, target["boxes"]), (out_img, out_target["boxes"])])
print(f"{out_target['this_is_ignored']}")
data:image/s3,"s3://crabby-images/bd918/bd9182e5ba111000edf64ba73785597953178808" alt="plot transforms getting started"
('arbitrary', {'structure': '!'})
我们传递了一个元组,所以我们得到一个元组作为返回,第二个元素是转换后的目标字典。转换实际上并不关心输入的结构;如上所述,它们只关心对象的类型并相应地转换它们。
外部 对象(如字符串或整数)只是被传递。例如,如果您想在调试时将路径与每个样本关联,这将非常有用!
注意
免责声明 此注释略微高级,首次阅读可以安全跳过。
纯 torch.Tensor
对象通常被视为图像(或视频特定转换的视频)。实际上,您可能已经注意到,在上面的代码中,我们根本没有使用 Image
类,但我们的图像得到了正确的转换。转换遵循以下逻辑来确定是否应将纯张量视为图像(或视频),还是仅忽略
如果输入中存在
Image
、Video
或PIL.Image.Image
实例,则所有其他纯张量都将被传递。如果没有
Image
或Video
实例,则只有第一个纯torch.Tensor
将被转换为图像或视频,而所有其他纯张量都将被传递。“第一个”是指“深度优先遍历中的第一个”。
这就是上面检测示例中发生的情况:第一个纯张量是图像,因此它得到了正确的转换,而所有其他纯张量实例(如 labels
)都被传递了(尽管标签仍然可以被某些转换转换,如 SanitizeBoundingBoxes
!)。
转换和数据集的互操作性¶
粗略地说,数据集的输出必须对应于转换的输入。如何做到这一点取决于您使用的是 torchvision 内置数据集 还是您自己的自定义数据集。
使用内置数据集¶
如果您只是进行图像分类,则无需执行任何操作。只需使用数据集的 transform
参数,例如 ImageNet(..., transform=transforms)
,就可以了。
Torchvision 还支持用于对象检测或分割的数据集,如 torchvision.datasets.CocoDetection
。这些数据集早于 torchvision.transforms.v2
模块和 TVTensor 的存在,因此它们不会开箱即用地返回 TVTensor。
强制这些数据集返回 TVTensor 并使其与 v2 转换兼容的一种简单方法是使用 torchvision.datasets.wrap_dataset_for_transforms_v2()
函数
from torchvision.datasets import CocoDetection, wrap_dataset_for_transforms_v2
dataset = CocoDetection(..., transforms=my_transforms)
dataset = wrap_dataset_for_transforms_v2(dataset)
# Now the dataset returns TVTensors!
使用您自己的数据集¶
如果您有自定义数据集,则需要将您的对象转换为适当的 TVTensor 类。创建 TVTensor 实例非常容易,请参阅 如何创建 TVTensor? 了解更多详细信息。
您可以在两个主要位置实现该转换逻辑
在数据集的
__getitem__
方法的末尾,在返回样本之前(或通过子类化数据集)。作为转换管道的第一步
无论哪种方式,逻辑都将取决于您的特定数据集。
脚本的总运行时间: (0 分 0.685 秒)