torch.matmul¶
- torch.matmul(input, other, *, out=None) Tensor ¶
两个张量的矩阵乘积。
行为取决于张量的维度,具体如下:
如果两个张量都是 1 维,则返回点积(标量)。
如果两个参数都是 2 维,则返回矩阵-矩阵乘积。
如果第一个参数是 1 维,第二个参数是 2 维,则在矩阵乘法之前,在第一个参数的维度前面添加一个 1。矩阵乘法之后,移除添加的维度。
如果第一个参数是 2 维,第二个参数是 1 维,则返回矩阵-向量乘积。
如果两个参数至少都是 1 维,并且至少有一个参数是 N 维(N > 2),则返回批处理矩阵乘法。如果第一个参数是 1 维,则在批处理矩阵乘法之前,在其维度前面添加一个 1,之后移除。如果第二个参数是 1 维,则在批处理矩阵乘法之前,在其维度后面添加一个 1,之后移除。非矩阵(即批处理)维度会被广播(因此必须是可广播的)。例如,如果
input
是一个 形状的张量,other
是一个 形状的张量,则out
将是 形状的张量。请注意,广播逻辑在确定输入是否可广播时,仅查看批处理维度,而不查看矩阵维度。例如,如果
input
是一个 形状的张量,other
是一个 形状的张量,即使最后两个维度(即矩阵维度)不同,这些输入对于广播也是有效的。out
将是 形状的张量。
此操作支持具有稀疏布局的参数。特别是矩阵-矩阵乘法(两个参数都是 2 维)支持具有与
torch.mm()
相同限制的稀疏参数。警告
稀疏支持是 Beta 特性,某些 layout(s)/dtype/device 组合可能不受支持,或者可能不支持 autograd。如果您发现缺少功能,请提交 feature request。
此操作符支持 TensorFloat32。
在某些 ROCm 设备上,当使用 float16 输入时,此模块将在 backward 过程中使用不同的精度。
注意
此函数的 1 维点积版本不支持
out
参数。示例
>>> # vector x vector >>> tensor1 = torch.randn(3) >>> tensor2 = torch.randn(3) >>> torch.matmul(tensor1, tensor2).size() torch.Size([]) >>> # matrix x vector >>> tensor1 = torch.randn(3, 4) >>> tensor2 = torch.randn(4) >>> torch.matmul(tensor1, tensor2).size() torch.Size([3]) >>> # batched matrix x broadcasted vector >>> tensor1 = torch.randn(10, 3, 4) >>> tensor2 = torch.randn(4) >>> torch.matmul(tensor1, tensor2).size() torch.Size([10, 3]) >>> # batched matrix x batched matrix >>> tensor1 = torch.randn(10, 3, 4) >>> tensor2 = torch.randn(10, 4, 5) >>> torch.matmul(tensor1, tensor2).size() torch.Size([10, 3, 5]) >>> # batched matrix x broadcasted matrix >>> tensor1 = torch.randn(10, 3, 4) >>> tensor2 = torch.randn(4, 5) >>> torch.matmul(tensor1, tensor2).size() torch.Size([10, 3, 5])