torch.overrides¶
此模块为 __torch_function__
协议公开了各种辅助函数。有关 __torch_function__
协议的更多详细信息,请参阅 扩展 torch Python API。
函数¶
- torch.overrides.get_ignored_functions()[source]¶
返回不能被
__torch_function__
覆盖的公共函数。- 返回值
一个包含在 torch API 中公开可用但不能使用
__torch_function__
覆盖的函数的元组。这主要是因为这些函数的参数都不是张量或类张量。- 返回类型
Set[Callable]
示例
>>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions() True >>> torch.add in torch.overrides.get_ignored_functions() False
- torch.overrides.get_overridable_functions()[source]¶
列出可以通过 __torch_function__ 覆盖的函数
- 返回值
一个字典,将包含可覆盖函数的命名空间映射到该命名空间中可以覆盖的函数。
- 返回类型
Dict[Any, List[Callable]]
- torch.overrides.resolve_name(f)[source]¶
获取传递给 __torch_function__ 的函数的可读字符串名称
- 参数
f (Callable) – 要解析其名称的函数。
- 返回值
函数的名称;如果被求值,它应该返回输入函数。
- 返回类型
- torch.overrides.get_testing_overrides()[source]¶
返回一个字典,其中包含所有可覆盖函数的虚拟覆盖
- 返回值
一个字典,将 PyTorch API 中的可覆盖函数映射到 lambda 函数,这些函数与真实函数具有相同的签名,并无条件地返回 -1。这些 lambda 函数对于测试定义了
__torch_function__
的类型的 API 覆盖范围很有用。- 返回类型
Dict[Callable, Callable]
示例
>>> import inspect >>> my_add = torch.overrides.get_testing_overrides()[torch.add] >>> inspect.signature(my_add) <Signature (input, other, out=None)>
- torch.overrides.handle_torch_function(public_api, relevant_args, *args, **kwargs)[source]¶
实现一个带有
__torch_function__
覆盖检查的函数。请参阅 torch::autograd::handle_torch_function,了解 C++ 实现中此函数的等效项。
- 参数
- 返回值
调用
implementation
或__torch_function__
方法的结果(如适用)。- 返回类型
:raises TypeError : 如果未找到实现。
示例
>>> def func(a): ... if has_torch_function_unary(a): ... return handle_torch_function(func, (a,), a) ... return a + 0
- torch.overrides.has_torch_function()¶
检查可迭代对象中的元素是否有 __torch_function__ 实现,或者是否启用了 __torch_function__ 模式。将精确的
Tensor
和Parameter
视为不可分派。使用它来保护对handle_torch_function()
的调用;不要使用它来测试某物是否类似 Tensor,而是使用is_tensor_like()
。:param relevant_args: 要检查 __torch_function__ 方法的可迭代对象或参数。:type relevant_args: 可迭代对象- 返回值
如果 relevant_args 的任何元素都有 __torch_function__ 实现,则为 True,否则为 False。
- 返回类型
另请参阅
torch.is_tensor_like
检查某物是否类似 Tensor,包括精确的
Tensor
。
- torch.overrides.is_tensor_like(inp)[source]¶
如果传入的输入类似 Tensor,则返回
True
。目前,只要输入类型的属性上存在
__torch_function__
,就会发生这种情况。示例
Tensor 的子类通常类似 Tensor。
>>> class SubTensor(torch.Tensor): ... >>> is_tensor_like(SubTensor([0])) True
内置或用户类型通常不类似 Tensor。
>>> is_tensor_like(6) False >>> is_tensor_like(None) False >>> class NotATensor: ... >>> is_tensor_like(NotATensor()) False
但是,可以通过实现 __torch_function__ 使它们类似 Tensor。
>>> class TensorLike: ... @classmethod ... def __torch_function__(cls, func, types, args, kwargs): ... return -1 >>> is_tensor_like(TensorLike()) True
- torch.overrides.is_tensor_method_or_property(func)[source]¶
如果传入的函数是
torch.Tensor
的方法或属性的处理程序,则返回 True,如传递到__torch_function__
中一样。注意
对于属性,必须传入其
__get__
方法。这可能特别需要以下原因
方法/属性有时不包含 __module__ 插槽。
它们要求第一个传入的参数是
torch.Tensor
的实例。
示例
>>> is_tensor_method_or_property(torch.Tensor.add) True >>> is_tensor_method_or_property(torch.add) False
- 返回类型
- torch.overrides.wrap_torch_function(dispatcher)[source]¶
用
__torch_function__
相关功能包装给定函数。- 参数
dispatcher (Callable) – 返回传递给函数的类似 Tensor 可迭代对象的函数。
注意
此装饰器可能会降低代码性能。通常,将代码表示为一系列本身支持 __torch_function__ 的函数就足够了。如果你发现自己处于极少数情况下需要这种情况,例如,如果你正在包装一个低级库,并且你还需要它对类似 Tensor 的东西工作,那么此函数可用。
示例
>>> def dispatcher(a): # Must have the same signature as func ... return (a,) >>> @torch.overrides.wrap_torch_function(dispatcher) >>> def func(a): # This will make func dispatchable by __torch_function__ ... return a + 0