torch.chain_matmul¶
- torch.chain_matmul(*matrices, out=None)[源代码]¶
返回 个 2D 张量的矩阵乘积。此乘积使用矩阵链排序算法有效地计算,该算法选择导致算术运算成本最低的顺序([CLRS])。请注意,由于这是一个计算乘积的函数,因此 需要大于或等于 2;如果等于 2,则返回一个简单的矩阵-矩阵乘积。如果 等于 1,则这是一个无操作 - 原矩阵按原样返回。
警告
torch.chain_matmul()
已弃用,将在未来的 PyTorch 版本中删除。请改用torch.linalg.multi_dot()
,它接受一个包含两个或多个张量的列表,而不是多个参数。- 参数
matrices (张量...) – 要确定其乘积的两个或多个 2D 张量的序列。
out (张量, 可选) – 输出张量。如果
out
=None
,则忽略。
- 返回值
如果第 个张量维度为 ,则乘积的维度将为 。
- 返回类型
示例
>>> a = torch.randn(3, 4) >>> b = torch.randn(4, 5) >>> c = torch.randn(5, 6) >>> d = torch.randn(6, 7) >>> # will raise a deprecation warning >>> torch.chain_matmul(a, b, c, d) tensor([[ -2.3375, -3.9790, -4.1119, -6.6577, 9.5609, -11.5095, -3.2614], [ 21.4038, 3.3378, -8.4982, -5.2457, -10.2561, -2.4684, 2.7163], [ -0.9647, -5.8917, -2.3213, -5.2284, 12.8615, -12.2816, -2.5095]])