快捷方式

如何编写自己的 v2 转换

注意

尝试在 colab 上或 转到结尾 下载完整的示例代码。

本指南介绍如何编写与 torchvision 转换 V2 API 兼容的转换。

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

只需创建一个 nn.Module 并覆盖 forward 方法

在大多数情况下,只要您已经了解转换将期望的输入结构,这将是您所需要的全部。例如,如果您只是进行图像分类,您的转换通常会接受单个图像作为输入,或者 (img, label) 输入。因此,您可以直接将 forward 方法硬编码为仅接受该输入,例如

class MyCustomTransform(torch.nn.Module):
    def forward(self, img, label):
        # Do some transformations
        return new_img, new_label

注意

这意味着,如果您有一个与 V1 转换(位于 torchvision.transforms 中)兼容的自定义转换,它将无需任何更改即可与 V2 转换一起使用!

我们将在下面用一个典型的检测案例更完整地说明这一点,在这个案例中,我们的样本只是图像、边界框和标签。

class MyCustomTransform(torch.nn.Module):
    def forward(self, img, bboxes, label):  # we assume inputs are always structured like this
        print(
            f"I'm transforming an image of shape {img.shape} "
            f"with bboxes = {bboxes}\n{label = }"
        )
        # Do some transformations. Here, we're just passing though the input
        return img, bboxes, label


transforms = v2.Compose([
    MyCustomTransform(),
    v2.RandomResizedCrop((224, 224), antialias=True),
    v2.RandomHorizontalFlip(p=1),
    v2.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
])

H, W = 256, 256
img = torch.rand(3, H, W)
bboxes = tv_tensors.BoundingBoxes(
    torch.tensor([[0, 10, 10, 20], [50, 50, 70, 70]]),
    format="XYXY",
    canvas_size=(H, W)
)
label = 3

out_img, out_bboxes, out_label = transforms(img, bboxes, label)
I'm transforming an image of shape torch.Size([3, 256, 256]) with bboxes = BoundingBoxes([[ 0, 10, 10, 20],
               [50, 50, 70, 70]], format=BoundingBoxFormat.XYXY, canvas_size=(256, 256))
label = 3
print(f"Output image shape: {out_img.shape}\nout_bboxes = {out_bboxes}\n{out_label = }")
Output image shape: torch.Size([3, 224, 224])
out_bboxes = BoundingBoxes([[224,   0, 224,   0],
               [136,   0, 173,   0]], format=BoundingBoxFormat.XYXY, canvas_size=(224, 224))
out_label = 3

注意

在代码中使用 TVTensor 类时,请务必熟悉以下部分:我有一个 TVTensor,但现在我有一个 Tensor。求助!

支持任意输入结构

在上面的部分中,我们假设您已经了解了输入的结构,并且您对在代码中硬编码此预期结构感到满意。如果您希望您的自定义转换尽可能灵活,这可能会有点限制。

内置 Torchvision V2 转换的一个关键特性是它们可以接受任意输入结构,并返回与输出相同的结构(带有转换后的条目)。例如,转换可以接受单个图像,或一个 (img, label) 元组,或一个任意嵌套字典作为输入。

structured_input = {
    "img": img,
    "annotations": (bboxes, label),
    "something_that_will_be_ignored": (1, "hello")
}
structured_output = v2.RandomHorizontalFlip(p=1)(structured_input)

assert isinstance(structured_output, dict)
assert structured_output["something_that_will_be_ignored"] == (1, "hello")
print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")
The transformed bboxes are:
BoundingBoxes([[246,  10, 256,  20],
               [186,  50, 206,  70]], format=BoundingBoxFormat.XYXY, canvas_size=(256, 256))

如果您希望在自己的转换中复制此行为,我们建议您查看我们的 代码 并根据您的需要进行调整。

简而言之,核心逻辑是使用 pytree 将输入解包成一个扁平列表,然后仅转换可以转换的条目(根据条目的 **类** 进行决策,因为所有 TVTensor 都是张量子类)加上一些此处不计分的自定义逻辑 - 检查代码以获取详细信息。然后,将(可能已转换的)条目重新打包并返回,其结构与输入相同。

我们目前没有提供面向开发人员的公共工具来实现这一点,但如果您认为这很有价值,请通过在我们的 GitHub 仓库 上创建一个问题来告知我们。

脚本的总运行时间:(0 分钟 0.006 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源