快捷方式

torch.einsum

torch.einsum(equation, *operands) 张量[source][source]

根据爱因斯坦求和约定表示法,沿指定维度对输入 operands 的元素乘积求和。

Einsum 允许使用基于爱因斯坦求和约定的简写格式(由 equation 指定)来计算许多常见的多维线性代数数组操作。这种格式的详细信息将在下面描述,但其总体思想是为输入 operands 的每个维度标记一个下标,并定义哪些下标属于输出。然后,通过对 operands 元素沿着不属于输出的下标维度求积再求和,来计算输出。例如,矩阵乘法可以使用 einsum 表示为 torch.einsum(“ij,jk->ik”, A, B)。这里,j 是求和下标,i 和 k 是输出下标(关于原因的更多详细信息请参见下面的章节)。

方程表达式

equation 字符串指定了输入 operands 的每个维度所对应的下标([a-zA-Z] 中的字母),顺序与维度顺序一致,使用逗号 (‘,’) 分隔每个操作数的下标,例如 ‘ij,jk’ 指定了两个 2D 操作数的下标。标记有相同下标的维度必须是可广播的,也就是说,它们的尺寸必须匹配或为 1。例外情况是,如果一个下标在同一个输入操作数中重复出现,则此操作数中标记该下标的维度尺寸必须匹配,并且该操作数将沿这些维度被其对角线替换。在 equation 中只出现一次的下标将成为输出的一部分,并按字母升序排列。输出是通过将输入 operands 元素逐个相乘(根据下标对齐维度),然后对不属于输出的下标维度求和计算得出的。

此外,可以通过在方程末尾添加箭头 (‘->’) 并跟随输出下标来显式定义输出下标。例如,以下方程计算矩阵乘积的转置:‘ij,jk->ki’。输出下标必须至少在某个输入操作数中出现一次,且在输出中最多出现一次。

可以使用省略号 (‘…’) 代替下标,以广播省略号所覆盖的维度。每个输入操作数最多可以包含一个省略号,它将覆盖未被下标覆盖的维度,例如,对于一个 5 维的输入操作数,方程 ‘ab…c’ 中的省略号覆盖第三和第四维。省略号在不同操作数中不必覆盖相同数量的维度,但省略号的“形状”(它们覆盖的维度尺寸)必须能够一起广播。如果未使用箭头 (‘->’) 表示法显式定义输出,则省略号将首先出现在输出中(最左边的维度),然后才是输入操作数中只出现一次的下标标签。例如,以下方程实现了批次矩阵乘法 ‘…ij,…jk’

最后几点注意事项:方程中可以在不同元素(下标、省略号、箭头和逗号)之间包含空格,但类似 ‘…’ 的写法是无效的。空字符串 ‘’ 对于标量操作数是有效的。

注意

torch.einsum 对省略号 (‘…’) 的处理与 NumPy 不同,它允许对省略号覆盖的维度进行求和,也就是说,省略号不强制要求成为输出的一部分。

注意

请安装 opt-einsum (https://optimized-einsum.readthedocs.io/en/stable/) 以获得性能更好的 einsum。您可以在安装 torch 时一起安装:pip install torch[opt-einsum],或者单独安装:pip install opt-einsum

如果 opt-einsum 可用,此函数将通过我们的 opt_einsum 后端 torch.backends.opt_einsum(我知道 _ 和 - 容易混淆)优化收缩顺序,从而自动加速计算和/或减少内存消耗。当输入至少有三个时才会进行此优化,否则顺序无关紧要。请注意,找到最优路径是 NP 难问题,因此 opt-einsum 依赖于不同的启发式方法来获得接近最优的结果。如果 opt-einsum 不可用,默认顺序是从左到右收缩。

要绕过此默认行为,添加以下代码以禁用 opt_einsum 并跳过路径计算:torch.backends.opt_einsum.enabled = False

要指定 opt_einsum 计算收缩路径的策略,添加以下代码:torch.backends.opt_einsum.strategy = 'auto'。默认策略是 ‘auto’,我们也支持 ‘greedy’ 和 ‘optimal’。请注意,‘optimal’ 策略的运行时间是输入数量的阶乘!有关更多详细信息,请参阅 opt-einsum 文档 (https://optimized-einsum.readthedocs.io/en/stable/path_finding.html)。

注意

自 PyTorch 1.10 起,torch.einsum() 也支持子列表格式(参见下面的示例)。在这种格式中,每个操作数的下标由子列表([0, 52) 范围内的整数列表)指定。这些子列表跟在其操作数后面,并且可以在输入末尾出现一个额外的子列表来指定输出的下标,例如 torch.einsum(op1, sublist1, op2, sublist2, …, [subslist_out])。Python 的 Ellipsis 对象可以在子列表中提供,以实现上面方程表达式部分描述的广播功能。

参数
  • equation (str) – 爱因斯坦求和的下标表达式。

  • operands (List[张量]) – 用于计算爱因斯坦求和的张量列表。

返回类型

张量

示例

>>> # trace
>>> torch.einsum('ii', torch.randn(4, 4))
tensor(-1.2104)

>>> # diagonal
>>> torch.einsum('ii->i', torch.randn(4, 4))
tensor([-0.1034,  0.7952, -0.2433,  0.4545])

>>> # outer product
>>> x = torch.randn(5)
>>> y = torch.randn(4)
>>> torch.einsum('i,j->ij', x, y)
tensor([[ 0.1156, -0.2897, -0.3918,  0.4963],
        [-0.3744,  0.9381,  1.2685, -1.6070],
        [ 0.7208, -1.8058, -2.4419,  3.0936],
        [ 0.1713, -0.4291, -0.5802,  0.7350],
        [ 0.5704, -1.4290, -1.9323,  2.4480]])

>>> # batch matrix multiplication
>>> As = torch.randn(3, 2, 5)
>>> Bs = torch.randn(3, 5, 4)
>>> torch.einsum('bij,bjk->bik', As, Bs)
tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
        [-1.6706, -0.8097, -0.8025, -2.1183]],

        [[ 4.2239,  0.3107, -0.5756, -0.2354],
        [-1.4558, -0.3460,  1.5087, -0.8530]],

        [[ 2.8153,  1.8787, -4.3839, -1.2112],
        [ 0.3728, -2.1131,  0.0921,  0.8305]]])

>>> # with sublist format and ellipsis
>>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2])
tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
        [-1.6706, -0.8097, -0.8025, -2.1183]],

        [[ 4.2239,  0.3107, -0.5756, -0.2354],
        [-1.4558, -0.3460,  1.5087, -0.8530]],

        [[ 2.8153,  1.8787, -4.3839, -1.2112],
        [ 0.3728, -2.1131,  0.0921,  0.8305]]])

>>> # batch permute
>>> A = torch.randn(2, 3, 4, 5)
>>> torch.einsum('...ij->...ji', A).shape
torch.Size([2, 3, 5, 4])

>>> # equivalent to torch.nn.functional.bilinear
>>> A = torch.randn(3, 5, 4)
>>> l = torch.randn(2, 5)
>>> r = torch.randn(2, 4)
>>> torch.einsum('bn,anm,bm->ba', l, A, r)
tensor([[-0.3430, -5.2405,  0.4494],
        [ 0.3311,  5.5201, -3.0356]])

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源