快捷方式

如何编写自己的 TVTensor 类

注意

Colab 上尝试或 转到结尾 下载完整的示例代码。

本指南面向高级用户和下游库维护人员。我们解释了如何编写自己的 TVTensor 类,以及如何使其与内置的 Torchvision v2 变换兼容。在继续之前,请确保您已阅读 TVTensor 常见问题解答

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

我们将创建一个非常简单的类,它只继承自基础 TVTensor 类。这足以涵盖您需要了解的内容,以便实现更复杂的用例。如果您需要创建一个包含元数据的类,请查看 BoundingBoxes 类的 实现方式

class MyTVTensor(tv_tensors.TVTensor):
    pass


my_dp = MyTVTensor([1, 2, 3])
my_dp
MyTVTensor([1., 2., 3.])

现在我们已经定义了自定义 TVTensor 类,我们希望它与内置的 torchvision 变换和函数式 API 兼容。为此,我们需要实现一个内核,它执行变换的核心,然后通过 register_kernel() 将其“挂钩”到我们想要支持的函数式 API。

我们在下面说明了此过程:我们为 MyTVTensor 类的“水平翻转”操作创建了一个内核,并将其注册到函数式 API。

from torchvision.transforms.v2 import functional as F


@F.register_kernel(functional="hflip", tv_tensor_cls=MyTVTensor)
def hflip_my_tv_tensor(my_dp, *args, **kwargs):
    print("Flipping!")
    out = my_dp.flip(-1)
    return tv_tensors.wrap(out, like=my_dp)

要了解为什么使用 wrap(),请查看 我有一个 TVTensor,但现在我有一个 Tensor。帮助!。暂时忽略 *args, **kwargs,我们将在下面 参数转发,并确保内核的未来兼容性 中解释。

注意

在上面的 register_kernel 调用中,我们使用了字符串 functional="hflip" 来引用我们想要挂钩的函数式 API。我们也可以使用函数式 API 本身,即 @register_kernel(functional=F.hflip, ...)

现在我们已经注册了内核,我们可以对 MyTVTensor 实例调用函数式 API

my_dp = MyTVTensor(torch.rand(3, 256, 256))
_ = F.hflip(my_dp)
Flipping!

我们还可以使用 RandomHorizontalFlip 变换,因为它在内部依赖于 hflip()

t = v2.RandomHorizontalFlip(p=1)
_ = t(my_dp)
Flipping!

注意

我们无法为变换类注册内核,只能为**函数式 API**注册内核。我们无法为变换类注册内核的原因是,一个变换可能在内部依赖于多个函数式 API,因此一般来说,我们无法为给定的类注册单个内核。

参数转发,并确保内核的未来兼容性

您正在挂钩的函数式 API 是公开的,因此**向后**兼容:我们保证这些函数式 API 的参数不会在没有适当的弃用周期的情况下被删除或重命名。但是,我们不保证**向前**兼容性,我们可能会在将来添加新参数。

想象一下,在将来的某个版本中,Torchvision 将新的 inplace 参数添加到其 hflip() 函数式 API 中。如果您已经定义并注册了自己的内核,如下所示:

def hflip_my_tv_tensor(my_dp):  # noqa
    print("Flipping!")
    out = my_dp.flip(-1)
    return tv_tensors.wrap(out, like=my_dp)

那么调用 F.hflip(my_dp) 将**失败**,因为 hflip 将尝试将新的 inplace 参数传递给您的内核,但您的内核不接受它。

因此,我们建议始终在内核签名中定义 *args, **kwargs,如上所示。这样,您的内核将能够接受我们将来可能添加的任何新参数。(从技术上讲,只添加 **kwargs 就足够了)。

脚本总运行时间:(0 分钟 0.004 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源