• 文档 >
  • 命名张量运算符覆盖范围
快捷方式

命名张量运算符覆盖范围

请先阅读命名张量以了解命名张量的介绍。

本文档是名称推断的参考,这是一个定义了命名张量如何使用名称的过程:

  1. 使用名称提供额外的自动运行时正确性检查

  2. 将名称从输入张量传播到输出张量

以下是支持命名张量及其相关名称推断规则的所有运算符的列表。

如果您没有看到此处列出的运算符,但它对您的用例有帮助,请搜索是否已经提交了问题,如果没有,请提交一个问题

警告

命名张量 API 处于实验阶段,可能会发生变化。

支持的运算符

API

名称推断规则

Tensor.abs()torch.abs()

保留输入名称

Tensor.abs_()

保留输入名称

Tensor.acos()torch.acos()

保留输入名称

Tensor.acos_()

保留输入名称

Tensor.add()torch.add()

统一来自输入的名称

Tensor.add_()

统一来自输入的名称

Tensor.addmm()torch.addmm()

缩减掉维度

Tensor.addmm_()

缩减掉维度

Tensor.addmv()torch.addmv()

缩减掉维度

Tensor.addmv_()

缩减掉维度

Tensor.align_as()

见文档

Tensor.align_to()

见文档

Tensor.all()torch.all()

Tensor.any()torch.any()

Tensor.asin()torch.asin()

保留输入名称

Tensor.asin_()

保留输入名称

Tensor.atan()torch.atan()

保留输入名称

Tensor.atan2()torch.atan2()

统一来自输入的名称

Tensor.atan2_()

统一来自输入的名称

Tensor.atan_()

保留输入名称

Tensor.bernoulli()torch.bernoulli()

保留输入名称

Tensor.bernoulli_()

Tensor.bfloat16()

保留输入名称

Tensor.bitwise_not()torch.bitwise_not()

保留输入名称

Tensor.bitwise_not_()

Tensor.bmm()torch.bmm()

缩减掉维度

Tensor.bool()

保留输入名称

Tensor.byte()

保留输入名称

torch.cat()

统一来自输入的名称

Tensor.cauchy_()

Tensor.ceil()torch.ceil()

保留输入名称

Tensor.ceil_()

Tensor.char()

保留输入名称

Tensor.chunk()torch.chunk()

保留输入名称

Tensor.clamp(), torch.clamp()

保留输入名称

Tensor.clamp_()

Tensor.copy_()

out 函数和就地变体

Tensor.cos(), torch.cos()

保留输入名称

Tensor.cos_()

Tensor.cosh(), torch.cosh()

保留输入名称

Tensor.cosh_()

Tensor.acosh(), torch.acosh()

保留输入名称

Tensor.acosh_()

Tensor.cpu()

保留输入名称

Tensor.cuda()

保留输入名称

Tensor.cumprod(), torch.cumprod()

保留输入名称

Tensor.cumsum(), torch.cumsum()

保留输入名称

Tensor.data_ptr()

Tensor.deg2rad(), torch.deg2rad()

保留输入名称

Tensor.deg2rad_()

Tensor.detach(), torch.detach()

保留输入名称

Tensor.detach_()

Tensor.device, torch.device()

Tensor.digamma(), torch.digamma()

保留输入名称

Tensor.digamma_()

Tensor.dim()

Tensor.div(), torch.div()

统一来自输入的名称

Tensor.div_()

统一来自输入的名称

Tensor.dot(), torch.dot()

Tensor.double()

保留输入名称

Tensor.element_size()

torch.empty()

工厂函数

torch.empty_like()

工厂函数

Tensor.eq(), torch.eq()

统一来自输入的名称

Tensor.erf(), torch.erf()

保留输入名称

Tensor.erf_()

Tensor.erfc(), torch.erfc()

保留输入名称

Tensor.erfc_()

Tensor.erfinv(), torch.erfinv()

保留输入名称

Tensor.erfinv_()

Tensor.exp(), torch.exp()

保留输入名称

Tensor.exp_()

Tensor.expand()

保留输入名称

Tensor.expm1(), torch.expm1()

保留输入名称

Tensor.expm1_()

Tensor.exponential_()

Tensor.fill_()

Tensor.flatten(), torch.flatten()

见文档

Tensor.float()

保留输入名称

Tensor.floor(), torch.floor()

保留输入名称

Tensor.floor_()

Tensor.frac(), torch.frac()

保留输入名称

Tensor.frac_()

Tensor.ge(), torch.ge()

统一来自输入的名称

Tensor.get_device(), torch.get_device()

Tensor.grad

Tensor.gt(), torch.gt()

统一来自输入的名称

Tensor.half()

保留输入名称

Tensor.has_names()

见文档

Tensor.index_fill(), torch.index_fill()

保留输入名称

Tensor.index_fill_()

Tensor.int()

保留输入名称

Tensor.is_contiguous()

Tensor.is_cuda

Tensor.is_floating_point(), torch.is_floating_point()

Tensor.is_leaf

Tensor.is_pinned()

Tensor.is_shared()

Tensor.is_signed(), torch.is_signed()

Tensor.is_sparse

Tensor.is_sparse_csr

torch.is_tensor()

Tensor.item()

Tensor.itemsize

Tensor.kthvalue(), torch.kthvalue()

移除维度

Tensor.le(), torch.le()

统一来自输入的名称

Tensor.log(), torch.log()

保留输入名称

Tensor.log10(), torch.log10()

保留输入名称

Tensor.log10_()

Tensor.log1p(), torch.log1p()

保留输入名称

Tensor.log1p_()

Tensor.log2(), torch.log2()

保留输入名称

Tensor.log2_()

Tensor.log_()

Tensor.log_normal_()

Tensor.logical_not(), torch.logical_not()

保留输入名称

Tensor.logical_not_()

Tensor.logsumexp(), torch.logsumexp()

移除维度

Tensor.long()

保留输入名称

Tensor.lt(), torch.lt()

统一来自输入的名称

torch.manual_seed()

Tensor.masked_fill(), torch.masked_fill()

保留输入名称

Tensor.masked_fill_()

Tensor.masked_select(), torch.masked_select()

将掩码与输入对齐,然后统一来自输入张量的名称

Tensor.matmul(), torch.matmul()

缩减掉维度

Tensor.mean(), torch.mean()

移除维度

Tensor.median(), torch.median()

移除维度

Tensor.nanmedian(), torch.nanmedian()

移除维度

Tensor.mm(), torch.mm()

缩减掉维度

Tensor.mode(), torch.mode()

移除维度

Tensor.mul(), torch.mul()

统一来自输入的名称

Tensor.mul_()

统一来自输入的名称

Tensor.mv(), torch.mv()

缩减掉维度

Tensor.names

见文档

Tensor.narrow(), torch.narrow()

保留输入名称

Tensor.nbytes

Tensor.ndim

Tensor.ndimension()

Tensor.ne(), torch.ne()

统一来自输入的名称

Tensor.neg(), torch.neg()

保留输入名称

Tensor.neg_()

torch.normal()

保留输入名称

Tensor.normal_()

Tensor.numel(), torch.numel()

torch.ones()

工厂函数

Tensor.pow(), torch.pow()

统一来自输入的名称

Tensor.pow_()

Tensor.prod(), torch.prod()

移除维度

Tensor.rad2deg(), torch.rad2deg()

保留输入名称

Tensor.rad2deg_()

torch.rand()

工厂函数

torch.rand()

工厂函数

torch.randn()

工厂函数

torch.randn()

工厂函数

Tensor.random_()

Tensor.reciprocal(), torch.reciprocal()

保留输入名称

Tensor.reciprocal_()

Tensor.refine_names()

见文档

Tensor.register_hook()

Tensor.register_post_accumulate_grad_hook()

Tensor.rename()

见文档

Tensor.rename_()

见文档

Tensor.requires_grad

Tensor.requires_grad_()

Tensor.resize_()

仅允许不改变形状的调整大小

Tensor.resize_as_()

仅允许不改变形状的调整大小

Tensor.round(), torch.round()

保留输入名称

Tensor.round_()

Tensor.rsqrt(), torch.rsqrt()

保留输入名称

Tensor.rsqrt_()

Tensor.select(), torch.select()

移除维度

Tensor.short()

保留输入名称

Tensor.sigmoid(), torch.sigmoid()

保留输入名称

Tensor.sigmoid_()

Tensor.sign(), torch.sign()

保留输入名称

Tensor.sign_()

Tensor.sgn(), torch.sgn()

保留输入名称

Tensor.sgn_()

Tensor.sin(), torch.sin()

保留输入名称

Tensor.sin_()

Tensor.sinh(), torch.sinh()

保留输入名称

Tensor.sinh_()

Tensor.asinh(), torch.asinh()

保留输入名称

Tensor.asinh_()

Tensor.size()

Tensor.softmax(), torch.softmax()

保留输入名称

Tensor.split(), torch.split()

保留输入名称

Tensor.sqrt(), torch.sqrt()

保留输入名称

Tensor.sqrt_()

Tensor.squeeze(), torch.squeeze()

移除维度

Tensor.std(), torch.std()

移除维度

torch.std_mean()

移除维度

Tensor.stride()

Tensor.sub(), torch.sub()

统一来自输入的名称

Tensor.sub_()

统一来自输入的名称

Tensor.sum(), torch.sum()

移除维度

Tensor.tan(), torch.tan()

保留输入名称

Tensor.tan_()

Tensor.tanh(), torch.tanh()

保留输入名称

Tensor.tanh_()

Tensor.atanh(), torch.atanh()

保留输入名称

Tensor.atanh_()

torch.tensor()

工厂函数

Tensor.to()

保留输入名称

Tensor.topk(), torch.topk()

移除维度

Tensor.transpose(), torch.transpose()

置换维度

Tensor.trunc()torch.trunc()

保留输入名称

Tensor.trunc_()

Tensor.type()

Tensor.type_as()

保留输入名称

Tensor.unbind()torch.unbind()

移除维度

Tensor.unflatten()

见文档

Tensor.uniform_()

Tensor.var()torch.var()

移除维度

torch.var_mean()

移除维度

Tensor.zero_()

torch.zeros()

工厂函数

保留输入名称

所有逐元素一元函数都遵循此规则,以及其他一些一元函数。

  • 检查名称:无

  • 传播名称:输入张量的名称将传播到输出。

>>> x = torch.randn(3, 3, names=('N', 'C'))
>>> x.abs().names
('N', 'C')

移除维度

所有减少操作,如 sum() 通过在所需维度上进行减少来移除维度。其他操作,如 select()squeeze() 会移除维度。

在可以将整数维度索引传递给操作符的地方,也可以传递维度名称。接受维度索引列表的函数也可以接受维度名称列表。

  • 检查名称:如果 dimdims 被传递为名称列表,则检查这些名称是否存在于 self 中。

  • 传播名称:如果输入张量中由 dimdims 指定的维度不出现在输出张量中,则这些维度的相应名称不会出现在 output.names 中。

>>> x = torch.randn(1, 3, 3, 3, names=('N', 'C', 'H', 'W'))
>>> x.squeeze('N').names
('C', 'H', 'W')

>>> x = torch.randn(3, 3, 3, 3, names=('N', 'C', 'H', 'W'))
>>> x.sum(['N', 'C']).names
('H', 'W')

# Reduction ops with keepdim=True don't actually remove dimensions.
>>> x = torch.randn(3, 3, 3, 3, names=('N', 'C', 'H', 'W'))
>>> x.sum(['N', 'C'], keepdim=True).names
('N', 'C', 'H', 'W')

统一来自输入的名称

所有二元算术运算都遵循此规则。广播操作仍然从右侧按位置广播,以保持与无名张量的兼容性。要按名称执行显式广播,请使用 Tensor.align_as()

  • 检查名称:所有名称都必须从右侧按位置匹配。即,在 tensor + other 中,match(tensor.names[i], other.names[i]) 对于所有 i 必须为真,在 (-min(tensor.dim(), other.dim()) + 1, -1] 中。

  • 检查名称:此外,所有命名维度都必须从右侧对齐。在匹配过程中,如果我们将命名维度 A 与无名维度 None 匹配,则 A 不得出现在具有无名维度的张量中。

  • 传播名称:从两个张量中统一来自右侧的名称对以生成输出名称。

例如,

# tensor: Tensor[   N, None]
# other:  Tensor[None,    C]
>>> tensor = torch.randn(3, 3, names=('N', None))
>>> other = torch.randn(3, 3, names=(None, 'C'))
>>> (tensor + other).names
('N', 'C')

检查名称

  • match(tensor.names[-1], other.names[-1])True

  • match(tensor.names[-2], tensor.names[-2])True

  • 因为我们在 tensor 中将 None'C' 匹配,请检查以确保 'C' 不存在于 tensor 中(它不存在)。

  • 检查以确保 'N' 不存在于 other 中(它不存在)。

最后,输出名称使用 [unify('N', None), unify(None, 'C')] = ['N', 'C'] 计算。

更多示例

# Dimensions don't match from the right:
# tensor: Tensor[N, C]
# other:  Tensor[   N]
>>> tensor = torch.randn(3, 3, names=('N', 'C'))
>>> other = torch.randn(3, names=('N',))
>>> (tensor + other).names
RuntimeError: Error when attempting to broadcast dims ['N', 'C'] and dims
['N']: dim 'C' and dim 'N' are at the same position from the right but do
not match.

# Dimensions aren't aligned when matching tensor.names[-1] and other.names[-1]:
# tensor: Tensor[N, None]
# other:  Tensor[      N]
>>> tensor = torch.randn(3, 3, names=('N', None))
>>> other = torch.randn(3, names=('N',))
>>> (tensor + other).names
RuntimeError: Misaligned dims when attempting to broadcast dims ['N'] and
dims ['N', None]: dim 'N' appears in a different position from the right
across both lists.

注意

在最后两个示例中,都可以按名称对齐张量,然后执行加法。使用 Tensor.align_as() 按名称对齐张量,或使用 Tensor.align_to() 将张量对齐到自定义维度顺序。

置换维度

某些操作,如 Tensor.t(),会置换维度的顺序。维度名称与单个维度相关联,因此它们也会被置换。

如果操作符接受位置索引 dim,它也可以接受维度名称作为 dim

  • 检查名称:如果 dim 被传递为名称,则检查它是否存在于张量中。

  • 传播名称:以与被置换的维度相同的顺序置换维度名称。

>>> x = torch.randn(3, 3, names=('N', 'C'))
>>> x.transpose('N', 'C').names
('C', 'N')

收缩维度

矩阵乘法函数遵循此规则的某些变体。让我们首先介绍 torch.mm(),然后概括批量矩阵乘法的规则。

对于 torch.mm(tensor, other)

  • 检查名称:无

  • 传播名称:结果名称为 (tensor.names[-2], other.names[-1])

>>> x = torch.randn(3, 3, names=('N', 'D'))
>>> y = torch.randn(3, 3, names=('in', 'out'))
>>> x.mm(y).names
('N', 'out')

本质上,矩阵乘法对两个维度执行点积,并将其折叠。当两个张量进行矩阵乘法时,收缩的维度会消失,不会出现在输出张量中。

torch.mv()torch.dot() 的工作方式类似:名称推断不检查输入名称,并移除参与点积的维度。

>>> x = torch.randn(3, 3, names=('N', 'D'))
>>> y = torch.randn(3, names=('something',))
>>> x.mv(y).names
('N',)

现在,让我们看一下 torch.matmul(tensor, other)。假设 tensor.dim() >= 2other.dim() >= 2

  • 检查名称:检查输入的批量维度是否对齐且可广播。请参阅 统一来自输入的名称,了解输入对齐的含义。

  • 传播名称:结果名称通过统一批量维度并移除收缩维度获得:unify(tensor.names[:-2], other.names[:-2]) + (tensor.names[-2], other.names[-1])

示例

# Batch matrix multiply of matrices Tensor['C', 'D'] and Tensor['E', 'F'].
# 'A', 'B' are batch dimensions.
>>> x = torch.randn(3, 3, 3, 3, names=('A', 'B', 'C', 'D'))
>>> y = torch.randn(3, 3, 3, names=('B', 'E', 'F'))
>>> torch.matmul(x, y).names
('A', 'B', 'C', 'F')

最后,许多 matmul 函数都有融合的 add 版本。即,addmm()addmv()。这些被视为组合了例如 mm() 的名称推断和 add() 的名称推断。

工厂函数

工厂函数现在接受一个新的 names 参数,该参数将名称与每个维度相关联。

>>> torch.zeros(2, 3, names=('N', 'C'))
tensor([[0., 0., 0.],
        [0., 0., 0.]], names=('N', 'C'))

out 函数和就地变体

指定为 out= 张量的张量具有以下行为

  • 如果没有命名维度,则从操作中计算出的名称将传播到它。

  • 如果它有任何命名维度,则从操作中计算出的名称必须与现有名称完全相同。否则,操作会出错。

所有就地方法都会修改输入,使其具有与从名称推断中计算出的名称相同的名称。例如

>>> x = torch.randn(3, 3)
>>> y = torch.randn(3, 3, names=('N', 'C'))
>>> x.names
(None, None)

>>> x += y
>>> x.names
('N', 'C')

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源