快捷方式

TVTensors FAQ

注意

Colab 上试用或转到末尾下载完整的示例代码。

TVTensors 是与 torchvision.transforms.v2 一起引入的 Tensor 子类。此示例展示了这些 TVTensors 是什么以及它们如何工作。

警告

目标受众 除非您正在编写自己的转换或自己的 TVTensors,否则您可能不需要阅读本指南。这是一个相当底层的topic,大多数用户不需要担心:您无需了解 TVTensors 的内部原理即可高效地依赖 torchvision.transforms.v2。但是,对于尝试实现自己的数据集、转换或直接使用 TVTensors 的高级用户来说,它可能很有用。

import PIL.Image

import torch
from torchvision import tv_tensors

什么是 TVTensors?

TVTensors 是零拷贝张量子类

tensor = torch.rand(3, 256, 256)
image = tv_tensors.Image(tensor)

assert isinstance(image, torch.Tensor)
assert image.data_ptr() == tensor.data_ptr()

在底层,torchvision.transforms.v2 中需要它们来正确地为输入数据分派到适当的函数。

torchvision.tv_tensors 支持四种类型的 TVTensors

我可以使用 TVTensor 做什么?

TVTensors 看起来和感觉起来就像普通的张量 - 它们就是张量。普通 torch.Tensor 上支持的所有内容,如 .sum() 或任何 torch.* 运算符,也适用于 TVTensors。有关一些注意事项,请参阅 我曾经有一个 TVTensor,但现在我有一个 Tensor。 怎么办?

如何构建 TVTensor?

使用构造函数

每个 TVTensor 类都接受任何可以转换为 Tensor 的类张量数据

image = tv_tensors.Image([[[[0, 1], [1, 0]]]])
print(image)
Image([[[[0, 1],
         [1, 0]]]], )

与其他 PyTorch 创建操作类似,构造函数还接受 dtypedevicerequires_grad 参数。

float_image = tv_tensors.Image([[[0, 1], [1, 0]]], dtype=torch.float32, requires_grad=True)
print(float_image)
Image([[[0., 1.],
        [1., 0.]]], grad_fn=<AliasBackward0>, )

此外,ImageMask 也可以直接接受 PIL.Image.Image

image = tv_tensors.Image(PIL.Image.open("../assets/astronaut.jpg"))
print(image.shape, image.dtype)
torch.Size([3, 512, 512]) torch.uint8

某些 TVTensor 需要传入额外的元数据才能构建。例如,BoundingBoxes 需要坐标格式以及相应图像的大小 (canvas_size) 以及实际值。这些元数据是正确转换边界框所必需的。

bboxes = tv_tensors.BoundingBoxes(
    [[17, 16, 344, 495], [0, 10, 0, 10]],
    format=tv_tensors.BoundingBoxFormat.XYXY,
    canvas_size=image.shape[-2:]
)
print(bboxes)
BoundingBoxes([[ 17,  16, 344, 495],
               [  0,  10,   0,  10]], format=BoundingBoxFormat.XYXY, canvas_size=torch.Size([512, 512]))

使用 tv_tensors.wrap()

您还可以使用 wrap() 函数将张量对象包装到 TVTensor 中。当您已经拥有所需类型的对象时,这非常有用,这通常发生在编写转换时:您只想包装输出,就像输入一样。

new_bboxes = torch.tensor([0, 20, 30, 40])
new_bboxes = tv_tensors.wrap(new_bboxes, like=bboxes)
assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)
assert new_bboxes.canvas_size == bboxes.canvas_size

new_bboxes 的元数据与 bboxes 相同,但您可以将其作为参数传递以覆盖它。

我曾经有一个 TVTensor,但现在我有一个 Tensor。 怎么办?

默认情况下,对 TVTensor 对象的操作将返回纯 Tensor

assert isinstance(bboxes, tv_tensors.BoundingBoxes)

# Shift bboxes by 3 pixels in both H and W
new_bboxes = bboxes + 3

assert isinstance(new_bboxes, torch.Tensor)
assert not isinstance(new_bboxes, tv_tensors.BoundingBoxes)

注意

此行为仅影响原生的 torch 操作。如果您使用的是内置的 torchvision 转换或函数,您将始终获得与您作为输入传递的类型相同的输出(纯 TensorTVTensor)。

但我想要返回一个 TVTensor!

您可以通过调用 TVTensor 构造函数,或者使用 wrap() 函数(有关更多详细信息,请参见上面的 如何构建 TVTensor?)将纯张量重新包装到 TVTensor 中

new_bboxes = bboxes + 3
new_bboxes = tv_tensors.wrap(new_bboxes, like=bboxes)
assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)

或者,您可以将 set_return_type() 用作整个程序的全局配置设置,或用作上下文管理器(阅读其文档以了解有关注意事项的更多信息)

with tv_tensors.set_return_type("TVTensor"):
    new_bboxes = bboxes + 3
assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)

为什么会这样?

出于性能原因TVTensor 类是 Tensor 子类,因此任何涉及 TVTensor 对象的操作都将通过 __torch_function__ 协议。这会产生少量开销,我们希望尽可能避免这种情况。这对于内置的 torchvision 转换来说无关紧要,因为我们可以避免那里的开销,但这在您模型的 forward 中可能会成为问题。

无论如何,另一种选择也好不了多少。 对于每个保留 TVTensor 类型有意义的操作,也有许多操作最好返回纯 Tensor:例如,img.sum() 仍然是 Image 吗? 如果我们一路保留 TVTensor 类型,即使是模型的 logits 或损失函数的输出最终也会成为 Image 类型,而且这肯定是不希望看到的。

注意

这种行为是我们正在积极寻求反馈的事情。如果您对此感到惊讶,或者您对如何更好地支持您的用例有任何建议,请通过此 issue 与我们联系:https://github.com/pytorch/vision/issues/7319

例外情况

此“解包”规则有一些例外:clone()to()torch.Tensor.detach()requires_grad_() 保留 TVTensor 类型。

对 TVTensors 的就地操作(如 obj.add_())将保留 obj 的类型。但是,就地操作的返回值将是纯张量

image = tv_tensors.Image([[[0, 1], [1, 0]]])

new_image = image.add_(1).mul_(2)

# image got transformed in-place and is still a TVTensor Image, but new_image
# is a Tensor. They share the same underlying data and they're equal, just
# different classes.
assert isinstance(image, tv_tensors.Image)
print(image)

assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, tv_tensors.Image)
assert (new_image == image).all()
assert new_image.data_ptr() == image.data_ptr()
Image([[[2, 4],
        [4, 2]]], )

脚本的总运行时间: (0 分钟 0.009 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源