补充:tensor之间进行矩阵相乘的方法总结
利用@进行简单的矩阵乘
@符号在tensor中就表示矩阵相乘,@符号的矩阵相乘性质在numpy中依然适用。
- 首先矩阵相乘的双方必须满足可以矩阵相乘的条件
- @只会关注两个矩阵最里面的两个维度是否符合条件,外面的维度都只表示矩阵运算的次数,甚至两个矩阵只要满足广播的条件和里面两个维度可以进行矩阵乘,二者的维度都可以不一样。
torch.mul
一定要注意这个函数是陷阱!其与*的作用是完全一样的,其不管相乘的双方维度如何,执行的都是对位相乘的操作, *与torch.mul均不能实现矩阵相乘的规则。
torch.mm
torch.mm是阉割版的@,其只能对二维的tensor进行矩阵相乘,高了的维度其不会进行广播 ↓
a=torch.ones((2,3))
b=torch.ones((3,4))
print(torch.mm(a,b))
'''
tensor([[3., 3., 3., 3.],
[3., 3., 3., 3.]])
'''
torch.matmul
其作用与@完全相同 ↓
a=torch.ones((1,2,2,3))
b=torch.ones((2,1,3,4))
print(torch.matmul(a,b))
'''
tensor([[[[3., 3., 3., 3.],
[3., 3., 3., 3.]],
[[3., 3., 3., 3.],
[3., 3., 3., 3.]]],
[[[3., 3., 3., 3.],
[3., 3., 3., 3.]],
[[3., 3., 3., 3.],
[3., 3., 3., 3.]]]])
'''
总结:就认准@就可以了