如何编写自己的 v2 变换¶
本指南说明如何编写与 torchvision 变换 V2 API 兼容的变换。
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2
只需创建一个 nn.Module
并覆盖 forward
方法¶
在大多数情况下,只要您已经了解变换将接受的输入结构,这将是您需要做的全部操作。例如,如果您只是进行图像分类,您的变换通常会接受单个图像作为输入,或一个 (img, label)
输入。因此,您可以直接对 forward
方法进行硬编码,以使其仅接受该输入,例如
class MyCustomTransform(torch.nn.Module):
def forward(self, img, label):
# Do some transformations
return new_img, new_label
注意
这意味着如果您有一个与 V1 变换(torchvision.transforms
中的变换)兼容的自定义变换,那么它将继续与 V2 变换一起使用,无需任何更改!
我们将在下面使用典型的检测案例更完整地说明这一点,其中我们的样本只是图像、边界框和标签
class MyCustomTransform(torch.nn.Module):
def forward(self, img, bboxes, label): # we assume inputs are always structured like this
print(
f"I'm transforming an image of shape {img.shape} "
f"with bboxes = {bboxes}\n{label = }"
)
# Do some transformations. Here, we're just passing though the input
return img, bboxes, label
transforms = v2.Compose([
MyCustomTransform(),
v2.RandomResizedCrop((224, 224), antialias=True),
v2.RandomHorizontalFlip(p=1),
v2.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
])
H, W = 256, 256
img = torch.rand(3, H, W)
bboxes = tv_tensors.BoundingBoxes(
torch.tensor([[0, 10, 10, 20], [50, 50, 70, 70]]),
format="XYXY",
canvas_size=(H, W)
)
label = 3
out_img, out_bboxes, out_label = transforms(img, bboxes, label)
I'm transforming an image of shape torch.Size([3, 256, 256]) with bboxes = BoundingBoxes([[ 0, 10, 10, 20],
[50, 50, 70, 70]], format=BoundingBoxFormat.XYXY, canvas_size=(256, 256))
label = 3
print(f"Output image shape: {out_img.shape}\nout_bboxes = {out_bboxes}\n{out_label = }")
Output image shape: torch.Size([3, 224, 224])
out_bboxes = BoundingBoxes([[224, 0, 224, 0],
[136, 0, 173, 0]], format=BoundingBoxFormat.XYXY, canvas_size=(224, 224))
out_label = 3
注意
在代码中使用 TVTensor 类时,请确保熟悉本节:我有一个 TVTensor,但现在我有一个 Tensor。求助!
支持任意输入结构¶
在上一节中,我们假设您已经了解了输入的结构,并且您对在代码中对这种预期结构进行硬编码感到满意。如果您希望自己的自定义变换尽可能灵活,这可能有点限制。
内置 Torchvision V2 变换的一个关键特性是它们可以接受任意输入结构,并以与输出相同的结构返回(包含已变换的条目)。例如,变换可以接受单个图像,或一个 (img, label)
元组,或者作为输入的任意嵌套字典
structured_input = {
"img": img,
"annotations": (bboxes, label),
"something_that_will_be_ignored": (1, "hello")
}
structured_output = v2.RandomHorizontalFlip(p=1)(structured_input)
assert isinstance(structured_output, dict)
assert structured_output["something_that_will_be_ignored"] == (1, "hello")
print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")
The transformed bboxes are:
BoundingBoxes([[246, 10, 256, 20],
[186, 50, 206, 70]], format=BoundingBoxFormat.XYXY, canvas_size=(256, 256))
如果您想在自己的变换中重现此行为,我们建议您查看我们的 代码 并根据您的需要进行调整。
简而言之,核心逻辑是使用 pytree 将输入解包成一个扁平列表,然后仅变换可以变换的条目(根据条目的**类**进行判断,因为所有 TVTensor 都是张量子类)以及此处未显示的一些自定义逻辑 - 查看代码以获取详细信息。然后对(可能已变换的)条目进行重新打包并返回,与输入的结构相同。
目前我们不提供面向公众开发的工具来实现这一点,但如果您觉得这将对您有所帮助,请在我们的 GitHub 仓库 上打开一个问题让我们知道。
脚本总运行时间:(0 分钟 0.007 秒)