如何编写自己的 TVTensor 类¶
本指南适用于高级用户和下游库维护者。我们将解释如何编写自己的 TVTensor 类,以及如何使其与内置的 Torchvision v2 变换兼容。在继续之前,请确保您已阅读 TVTensors FAQ。
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2
我们将创建一个非常简单的类,它只继承自基础 TVTensor 类。这足以涵盖实现更复杂用例所需的知识。如果需要创建携带元数据的类,可以参考 BoundingBoxes 类的实现。
class MyTVTensor(tv_tensors.TVTensor):
pass
my_dp = MyTVTensor([1, 2, 3])
my_dp
MyTVTensor([1., 2., 3.])
现在我们已经定义了自定义的 TVTensor 类,我们希望它能与内置的 torchvision 变换和函数式 API 兼容。为此,我们需要实现一个执行变换核心逻辑的 kernel,然后通过 register_kernel() 将其“挂钩”到我们想要支持的 functional 上。
我们将在下方说明此过程:为 MyTVTensor 类的“水平翻转”操作创建 kernel,并将其注册到函数式 API。
from torchvision.transforms.v2 import functional as F
@F.register_kernel(functional="hflip", tv_tensor_cls=MyTVTensor)
def hflip_my_tv_tensor(my_dp, *args, **kwargs):
print("Flipping!")
out = my_dp.flip(-1)
return tv_tensors.wrap(out, like=my_dp)
要理解为何使用 wrap(),请参阅 我有一个 TVTensor,但现在它变成了 Tensor。救命!。请暂时忽略 *args, **kwargs,我们将在下方 参数转发以及确保 kernel 的未来兼容性 中解释它。
注意
在我们上面调用 register_kernel 时,我们使用字符串 functional="hflip" 来指代我们想要挂钩的 functional。我们也可以直接使用 functional 本身,即 @register_kernel(functional=F.hflip, ...)。
现在我们已经注册了 kernel,可以在 MyTVTensor 实例上调用函数式 API。
my_dp = MyTVTensor(torch.rand(3, 256, 256))
_ = F.hflip(my_dp)
Flipping!
我们还可以使用 RandomHorizontalFlip 变换,因为它内部依赖于 hflip()。
t = v2.RandomHorizontalFlip(p=1)
_ = t(my_dp)
Flipping!
注意
我们不能为变换类注册 kernel,只能为 functional 注册 kernel。不能注册变换类的原因是,一个变换内部可能依赖多个 functional,所以通常情况下我们无法为给定类注册单个 kernel。
参数转发以及确保 kernel 的未来兼容性¶
您挂钩的函数式 API 是公开的,因此具有向后兼容性:我们保证这些 functional 的参数不会在没有适当弃用周期的情况下被移除或重命名。然而,我们不保证向前兼容性,将来我们可能会添加新参数。
假设在将来的版本中,Torchvision 的 hflip() functional 中添加了一个新的 inplace 参数。如果您已经将自己的 kernel 定义并注册为
def hflip_my_tv_tensor(my_dp): # noqa
print("Flipping!")
out = my_dp.flip(-1)
return tv_tensors.wrap(out, like=my_dp)
那么调用 F.hflip(my_dp) 将会失败,因为 hflip 会尝试将新的 inplace 参数传递给您的 kernel,但您的 kernel 不接受它。
因此,我们建议您始终按照上述示例在 kernel 的签名中包含 *args, **kwargs。这样,您的 kernel 就能接受未来可能添加的任何新参数。(技术上讲,只添加 **kwargs 应该就足够了)。
脚本总运行时间: (0 分钟 0.004 秒)