tensorflow的广播相乘理解 2018-06-06
2018-06-06 本文已影响0人
美联储
1、tensorflow
import tensorflow as tf
sess = tf.Session()
xx = tf.constant(2, shape=[5,1,1], dtype=tf.float32)
yy = tf.constant(1, shape=[1,2,8], dtype=tf.float32)
zz = xx * yy
print(sess.run([zz]))
广播相乘输出结果:
[array([[[2., 2., 2., 2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2., 2., 2., 2.]],
[[2., 2., 2., 2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2., 2., 2., 2.]],
[[2., 2., 2., 2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2., 2., 2., 2.]],
[[2., 2., 2., 2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2., 2., 2., 2.]],
[[2., 2., 2., 2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2., 2., 2., 2.]]], dtype=float32)]
两个值先各自扩展为[5,2,8]再点乘
2、numpy,跟tensorflow是一样的
a =[[[9],
[1],
[4],
[7]],
[[9],
[6],
[5],
[8]]]
a= np.array(a) # shape(2,4, 1)
b= [[2,6,7,1,6]]
b= np.array(b) # shape(1, 5)
#广播相乘
print(a*b)
输出结果:
[[[9]
[1]
[4]
[7]]
[[9]
[6]
[5]
[8]]]
[[2 6 7 1 6]]
[[[18 54 63 9 54]
[ 2 6 7 1 6]
[ 8 24 28 4 24]
[14 42 49 7 42]]
[[18 54 63 9 54]
[12 36 42 6 36]
[10 30 35 5 30]
[16 48 56 8 48]]]
运算过程等同于:
c =[[[9,9,9,9,9],
[1,1,1,1,1],
[4,4,4,4,4],
[7,7,7,7,7]],
[[9,9,9,9,9],
[6,6,6,6,6],
[5,5,5,5,5],
[8,8,8,8,8]]]
c= np.array(c) # shape(2,4, 5)
d =[[[2,6,7,1,6],
[2,6,7,1,6],
[2,6,7,1,6],
[2,6,7,1,6]],
[[2,6,7,1,6],
[2,6,7,1,6],
[2,6,7,1,6],
[2,6,7,1,6]]
]
d= np.array(d) # shape(2,4, 5)
print(c*d) #点对点相乘
输出结果:
[[[18 54 63 9 54]
[ 2 6 7 1 6]
[ 8 24 28 4 24]
[14 42 49 7 42]]
[[18 54 63 9 54]
[12 36 42 6 36]
[10 30 35 5 30]
[16 48 56 8 48]]]