快捷方式

set_return_type

class torchvision.tv_tensors.set_return_type(return_type: str)[source]

设置 torch 对 TVTensor 执行操作时的返回类型。

这只影响 torch 操作的行为。它对 torchvision 的 transforms 或 functionals 没有影响,后者总是返回与输入相同的类型作为输出。

警告

如果您使用 set_return_type("TVTensor"),建议在 transform 管道的末尾使用 ToPureTensor。这将避免模型 forward() 中的 __torch_function__ 开销。

可以作为整个程序的全局标志使用

img = tv_tensors.Image(torch.rand(3, 5, 5))
img + 2  # This is a pure Tensor (default behaviour)

set_return_type("TVTensor")
img + 2  # This is an Image

或作为上下文管理器限制范围

img = tv_tensors.Image(torch.rand(3, 5, 5))
img + 2  # This is a pure Tensor
with set_return_type("TVTensor"):
    img + 2  # This is an Image
img + 2  # This is a pure Tensor
参数:

return_type (str) – 可以是 “TVTensor” 或 “Tensor” (不区分大小写)。默认值为 “Tensor” (即纯 torch.Tensor)。

使用 set_return_type 的示例

TVTensors 常见问题

TVTensors 常见问题

文档

查阅 PyTorch 的完整开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源