set_return_type¶
- class torchvision.tv_tensors.set_return_type(return_type: str)[source]¶
设置 torch 对
TVTensor
执行操作时的返回类型。这只影响 torch 操作的行为。它对
torchvision
的 transforms 或 functionals 没有影响,后者总是返回与输入相同的类型作为输出。警告
如果您使用
set_return_type("TVTensor")
,建议在 transform 管道的末尾使用ToPureTensor
。这将避免模型forward()
中的__torch_function__
开销。可以作为整个程序的全局标志使用
img = tv_tensors.Image(torch.rand(3, 5, 5)) img + 2 # This is a pure Tensor (default behaviour) set_return_type("TVTensor") img + 2 # This is an Image
或作为上下文管理器限制范围
img = tv_tensors.Image(torch.rand(3, 5, 5)) img + 2 # This is a pure Tensor with set_return_type("TVTensor"): img + 2 # This is an Image img + 2 # This is a pure Tensor
- 参数:
return_type (str) – 可以是 “TVTensor” 或 “Tensor” (不区分大小写)。默认值为 “Tensor” (即纯
torch.Tensor
)。
使用
set_return_type
的示例