快捷方式

torch.overrides

此模块为 __torch_function__ 协议提供了各种帮助器函数。有关 __torch_function__ 协议的更多详细信息,请参阅 Extending torch Python API

函数

torch.overrides.get_ignored_functions()[source]

返回无法被 __torch_function__ 覆盖的公共函数。

返回

在 torch API 中公开可用但无法使用 __torch_function__ 覆盖的函数元组。这主要是因为这些函数的任何参数都不是张量或类张量。

返回类型

Set[Callable]

示例

>>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions()
True
>>> torch.add in torch.overrides.get_ignored_functions()
False
torch.overrides.get_overridable_functions()[source]

列出可通过 __torch_function__ 覆盖的函数

返回

一个字典,它将包含可覆盖函数的命名空间映射到该命名空间中可被覆盖的函数。

返回类型

Dict[Any, List[Callable]]

torch.overrides.resolve_name(f)[source]

获取传递给 __torch_function__ 的函数的可读字符串名称

参数

f (Callable) – 要解析其名称的函数。

返回

函数的名称;如果经过 eval,它应返回输入函数。

返回类型

str

torch.overrides.get_testing_overrides()[source]

返回一个包含所有可覆盖函数的虚拟覆盖的字典

返回

一个字典,它将 PyTorch API 中的可覆盖函数映射到具有与真实函数相同的签名的 lambda 函数,并无条件地返回 -1。这些 lambda 函数对于测试定义 __torch_function__ 的类型的 API 覆盖范围很有用。

返回类型

Dict[Callable, Callable]

示例

>>> import inspect
>>> my_add = torch.overrides.get_testing_overrides()[torch.add]
>>> inspect.signature(my_add)
<Signature (input, other, out=None)>
torch.overrides.handle_torch_function(public_api, relevant_args, *args, **kwargs)[source]

使用 __torch_function__ 覆盖实现一个函数并进行检查。

请参阅 torch::autograd::handle_torch_function,了解 C++ 实现中此函数的等效项。

参数
  • public_api (function) – 由公共 torch API 公开的函数,最初像 public_api(*args, **kwargs) 那样被调用,现在正在检查其参数。

  • relevant_args (iterable) – 用于检查 __torch_function__ 方法的参数的可迭代对象。

  • args (tuple) – 最初传递给 public_api 的任意位置参数。

  • kwargs (tuple) – 最初传递给 public_api 的任意关键字参数。

返回

从调用 implementation__torch_function__ 方法(视情况而定)得到的结果。

返回类型

对象

:raises TypeError : 如果找不到实现。

示例

>>> def func(a):
...     if has_torch_function_unary(a):
...         return handle_torch_function(func, (a,), a)
...     return a + 0
torch.overrides.has_torch_function()

检查可迭代元素中的 __torch_function__ 实现或是否启用了 __torch_function__ 模式。将确切的 TensorParameter 视为不可分发的。使用此方法来保护对 handle_torch_function() 的调用;不要使用它来测试某个内容是否类似于 Tensor,而是使用 is_tensor_like()。 :param relevant_args: 要检查 __torch_function__ 方法的可迭代或参数。 :type relevant_args: 可迭代

返回

如果 relevant_args 的任何元素具有 __torch_function__ 实现,则为 True,否则为 False。

返回类型

布尔值

另请参阅

torch.is_tensor_like

检查某个内容是否类似于 Tensor,包括确切的 Tensor

torch.overrides.is_tensor_like(inp)[源代码]

如果传入的输入是类似于 Tensor 的内容,则返回 True

目前,每当输入的类型上存在 __torch_function__ 属性时,就会发生这种情况。

示例

张量的子类通常类似于 Tensor。

>>> class SubTensor(torch.Tensor): ...
>>> is_tensor_like(SubTensor([0]))
True

内置或用户类型通常不像 Tensor。

>>> is_tensor_like(6)
False
>>> is_tensor_like(None)
False
>>> class NotATensor: ...
>>> is_tensor_like(NotATensor())
False

但是,可以通过实现 __torch_function__ 使它们类似于 Tensor。

>>> class TensorLike:
...     @classmethod
...     def __torch_function__(cls, func, types, args, kwargs):
...         return -1
>>> is_tensor_like(TensorLike())
True
torch.overrides.is_tensor_method_or_property(func)[源代码]

如果传入的函数是属于 torch.Tensor 的方法或属性的处理程序(如传入 __torch_function__ 的),则返回 True。

注意

对于属性,必须传入它们的 __get__ 方法。

可能需要这样做,特别是出于以下原因

  1. 方法/属性有时不包含 __module__ 插槽。

  2. 它们要求第一个传入的参数是 torch.Tensor 的实例。

示例

>>> is_tensor_method_or_property(torch.Tensor.add)
True
>>> is_tensor_method_or_property(torch.add)
False
返回类型

布尔值

torch.overrides.wrap_torch_function(dispatcher)[source]

__torch_function__ 相关功能包装给定函数。

参数

dispatcher (可调用对象) – 返回传递给函数的张量类似物的可迭代对象的调用对象。

注意

此装饰器可能会降低代码的性能。通常,将代码表示为一系列函数就足够了,这些函数本身支持 __torch_function__。如果您发现自己处于罕见的情况下,例如,如果您正在包装一个低级库,并且还需要它适用于张量类似物,那么可以使用此函数。

示例

>>> def dispatcher(a): # Must have the same signature as func
...     return (a,)
>>> @torch.overrides.wrap_torch_function(dispatcher)
>>> def func(a): # This will make func dispatchable by __torch_function__
...     return a + 0

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源