如何编写您自己的 v2 转换¶
本指南解释了如何编写与 torchvision transforms V2 API 兼容的转换。
只需创建一个 nn.Module
并覆盖 forward
方法¶
在大多数情况下,这就是您所需要的一切,只要您已经知道您的转换将期望的输入结构。 例如,如果您只是在进行图像分类,您的转换通常会接受单个图像作为输入,或者 (img, label)
输入。 因此,您可以将您的 forward
方法硬编码为仅接受该输入,例如:
class MyCustomTransform(torch.nn.Module):
def forward(self, img, label):
# Do some transformations
return new_img, new_label
注意
这意味着,如果您有一个自定义转换,它已经与 V1 转换(torchvision.transforms
中的那些转换)兼容,那么它仍然可以与 V2 转换一起使用,而无需进行任何更改!
我们将在下面使用典型的检测案例更完整地说明这一点,在检测案例中,我们的样本只是图像、边界框和标签
class MyCustomTransform(torch.nn.Module):
def forward(self, img, bboxes, label): # we assume inputs are always structured like this
print(
f"I'm transforming an image of shape {img.shape} "
f"with bboxes = {bboxes}\n{label = }"
)
# Do some transformations. Here, we're just passing though the input
return img, bboxes, label
transforms = v2.Compose([
MyCustomTransform(),
v2.RandomResizedCrop((224, 224), antialias=True),
v2.RandomHorizontalFlip(p=1),
v2.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
])
H, W = 256, 256
img = torch.rand(3, H, W)
bboxes = tv_tensors.BoundingBoxes(
torch.tensor([[0, 10, 10, 20], [50, 50, 70, 70]]),
format="XYXY",
canvas_size=(H, W)
)
label = 3
out_img, out_bboxes, out_label = transforms(img, bboxes, label)
I'm transforming an image of shape torch.Size([3, 256, 256]) with bboxes = BoundingBoxes([[ 0, 10, 10, 20],
[50, 50, 70, 70]], format=BoundingBoxFormat.XYXY, canvas_size=(256, 256))
label = 3
print(f"Output image shape: {out_img.shape}\nout_bboxes = {out_bboxes}\n{out_label = }")
Output image shape: torch.Size([3, 224, 224])
out_bboxes = BoundingBoxes([[224, 0, 224, 0],
[136, 0, 173, 0]], format=BoundingBoxFormat.XYXY, canvas_size=(224, 224))
out_label = 3
注意
在代码中使用 TVTensor 类时,请务必熟悉以下部分:我有一个 TVTensor,但现在我有一个 Tensor。 救命!
支持任意输入结构¶
在上面的部分中,我们假设您已经知道输入的结构,并且可以接受在代码中硬编码此预期结构。 如果您希望您的自定义转换尽可能灵活,这可能会受到一些限制。
内置 Torchvision V2 转换的一个关键特性是,它们可以接受任意输入结构,并返回与输出结构相同的结构(带有转换后的条目)。 例如,转换可以接受单个图像,或者 (img, label)
元组,或者任意嵌套字典作为输入。 以下是关于内置转换 RandomHorizontalFlip
的一个示例
structured_input = {
"img": img,
"annotations": (bboxes, label),
"something that will be ignored": (1, "hello"),
"another tensor that is ignored": torch.arange(10),
}
structured_output = v2.RandomHorizontalFlip(p=1)(structured_input)
assert isinstance(structured_output, dict)
assert structured_output["something that will be ignored"] == (1, "hello")
assert (structured_output["another tensor that is ignored"] == torch.arange(10)).all()
print(f"The input bboxes are:\n{structured_input['annotations'][0]}")
print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")
The input bboxes are:
BoundingBoxes([[ 0, 10, 10, 20],
[50, 50, 70, 70]], format=BoundingBoxFormat.XYXY, canvas_size=(256, 256))
The transformed bboxes are:
BoundingBoxes([[246, 10, 256, 20],
[186, 50, 206, 70]], format=BoundingBoxFormat.XYXY, canvas_size=(256, 256))
基础知识:覆盖 transform() 方法¶
为了在您的自定义转换中支持任意输入,您需要从 Transform
继承并覆盖 .transform() 方法(而不是 forward() 方法!)。 以下是一个基本示例
class MyCustomTransform(v2.Transform):
def transform(self, inpt: Any, params: Dict[str, Any]):
if type(inpt) == torch.Tensor:
print(f"I'm transforming an image of shape {inpt.shape}")
return inpt + 1 # dummy transformation
elif isinstance(inpt, tv_tensors.BoundingBoxes):
print(f"I'm transforming bounding boxes! {inpt.canvas_size = }")
return tv_tensors.wrap(inpt + 100, like=inpt) # dummy transformation
my_custom_transform = MyCustomTransform()
structured_output = my_custom_transform(structured_input)
assert isinstance(structured_output, dict)
assert structured_output["something that will be ignored"] == (1, "hello")
assert (structured_output["another tensor that is ignored"] == torch.arange(10)).all()
print(f"The input bboxes are:\n{structured_input['annotations'][0]}")
print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")
I'm transforming an image of shape torch.Size([3, 256, 256])
I'm transforming bounding boxes! inpt.canvas_size = (256, 256)
The input bboxes are:
BoundingBoxes([[ 0, 10, 10, 20],
[50, 50, 70, 70]], format=BoundingBoxFormat.XYXY, canvas_size=(256, 256))
The transformed bboxes are:
BoundingBoxes([[100, 110, 110, 120],
[150, 150, 170, 170]], format=BoundingBoxFormat.XYXY, canvas_size=(256, 256))
需要注意的重要一点是,当我们在 structured_input
上调用 my_custom_transform
时,输入会被展平,然后每个单独的部分都会传递给 transform()
。 也就是说,transform()`
接收输入图像,然后接收边界框等。 在 transform()
中,您可以根据输入的类型决定如何转换每个输入。
如果您好奇为什么另一个张量 (torch.arange()
) 没有传递给 transform()
,请参阅 此注释 了解更多详细信息。
高级:make_params()
方法¶
make_params()
方法在对每个输入调用 transform()
之前在内部调用。 这通常对于生成随机参数值很有用。 在下面的示例中,我们使用它以 0.5 的概率随机应用转换
class MyRandomTransform(MyCustomTransform):
def __init__(self, p=0.5):
self.p = p
super().__init__()
def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
apply_transform = (torch.rand(size=(1,)) < self.p).item()
params = dict(apply_transform=apply_transform)
return params
def transform(self, inpt: Any, params: Dict[str, Any]):
if not params["apply_transform"]:
print("Not transforming anything!")
return inpt
else:
return super().transform(inpt, params)
my_random_transform = MyRandomTransform()
torch.manual_seed(0)
_ = my_random_transform(structured_input) # transforms
_ = my_random_transform(structured_input) # doesn't transform
I'm transforming an image of shape torch.Size([3, 256, 256])
I'm transforming bounding boxes! inpt.canvas_size = (256, 256)
Not transforming anything!
Not transforming anything!
注意
对于此类随机参数生成,在 make_params()
而不是在 transform()
中发生非常重要,这样,对于给定的转换调用,相同的 RNG 以相同的方式应用于所有输入。 如果我们在 transform()
中执行 RNG,我们可能会冒例如转换图像,但不转换边界框的风险。
make_params()
方法将所有输入的列表作为参数(此列表中的每个元素稍后将传递给 transform()
)。 您可以使用 flat_inputs
来例如计算输入的维度,使用 query_chw()
或 query_size()
。
make_params()
应该返回一个 dict(或者实际上,您想要的任何内容),然后它将被传递给 transform()
。
脚本总运行时间: (0 分钟 0.009 秒)