补充:tensor之间进行矩阵相乘的方法总结

利用@进行简单的矩阵乘

@符号在tensor中就表示矩阵相乘,@符号的矩阵相乘性质在numpy中依然适用。

  • 首先矩阵相乘的双方必须满足可以矩阵相乘的条件
  • @只会关注两个矩阵最里面的两个维度是否符合条件,外面的维度都只表示矩阵运算的次数,甚至两个矩阵只要满足广播的条件和里面两个维度可以进行矩阵乘,二者的维度都可以不一样。

torch.mul

定要注意这个函数是陷阱!其与*的作用是完全一样的,其不管相乘的双方维度如何,执行的都是对位相乘的操作, *与torch.mul均不能实现矩阵相乘的规则。

torch.mm

torch.mm是阉割版的@,其只能对二维的tensor进行矩阵相乘,高了的维度其不会进行广播 ↓

a=torch.ones((2,3))
b=torch.ones((3,4))
print(torch.mm(a,b))
'''
tensor([[3., 3., 3., 3.],
        [3., 3., 3., 3.]])
'''

torch.matmul

其作用与@完全相同 ↓

a=torch.ones((1,2,2,3))
b=torch.ones((2,1,3,4))
print(torch.matmul(a,b))
'''
tensor([[[[3., 3., 3., 3.],
          [3., 3., 3., 3.]],

         [[3., 3., 3., 3.],
          [3., 3., 3., 3.]]],


        [[[3., 3., 3., 3.],
          [3., 3., 3., 3.]],

         [[3., 3., 3., 3.],
          [3., 3., 3., 3.]]]])
          '''

总结:就认准@就可以了