面对不同维度大小矩阵乘法操作的处理(Tensorflow)
遇到的问题:
面对矩阵的大小不同的两个矩阵,其中一个矩阵如何根据另一个矩阵的要求实现相应的行或列缩放。目标效果如下所示:
x:(2,2,3)
[[[ 1. 2. 3.],
[ 4. 5. 6.]],
[[ 7. 8. 9.],
[10. 11. 12.]]]
w:(2,2)
[[0.5, 0.4],
[0.1, 0.2]]
x*w:(2,2,3)
[[[0.5 1. 1.5]
[1.6 2. 2.4]]
[[0.7 0.8 0.9]
[2. 2.2 2.4]]]
上面的效果,如果只利用点乘(w * x)
和乘法(tf.matmul(w, x))
操作是无法完成的,需要利用到矩阵的维度变换。具体处理流程为:
- 对w进行维度扩张
w(2,2) --> w(2,1,2) - 将x的第二维和第三维变换
x(2,2,3) -->x(2,3,2) - 这时候再进行矩阵
点乘
操作,才能得到上面的效果。
具体代码为:
# tensorflow的点乘
def test2():
# a = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
# [2, 2, 3] [2, 2]
a1 = np.array([[[1.0,2.0,3.0],[4.0,5.0,6.0]],
[[7.0,8.0,9.0], [10.0,11.0,12.0]]])
w = np.array([[0.5, 0.4],
[0.1, 0.2]])
a1 = tf.convert_to_tensor(a1)
w = tf.convert_to_tensor(w)
#
# y = w * a1
a_trans = tf.transpose(a1, [0, 2, 1])
w = tf.expand_dims(w, 1)
y = tf.multiply(a_trans, w)
y = tf.transpose(y, [0, 2, 1])
# y = a_trans*w
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
print("x:")
print(sess.run(a1))
print("w")
print(sess.run(w))
print("x*w:")
print(sess.run(y))
test2()
注意事项:
(1) 点乘,只有在w的列为1或与x的列相等时,才能进行点乘运算;
(2) 乘法, 只有前一个矩阵的最后一维和后面一个矩阵的第一维相等时,才能进行乘法操作;
打个小广告: 欢迎关注本人github: https://github.com/wuxiaoxiaoer
随时会有新想法,或技术更新,尤其是假新闻方面的研究。