tensorflow的广播相乘理解 2018-06-06

2018-06-06  本文已影响0人  美联储

1、tensorflow

import tensorflow as tf
sess = tf.Session()
xx = tf.constant(2, shape=[5,1,1], dtype=tf.float32)
yy = tf.constant(1, shape=[1,2,8], dtype=tf.float32)
zz = xx * yy
print(sess.run([zz]))

广播相乘输出结果:

[array([[[2., 2., 2., 2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2., 2., 2., 2.]],

       [[2., 2., 2., 2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2., 2., 2., 2.]],

       [[2., 2., 2., 2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2., 2., 2., 2.]],

       [[2., 2., 2., 2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2., 2., 2., 2.]],

       [[2., 2., 2., 2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2., 2., 2., 2.]]], dtype=float32)]

两个值先各自扩展为[5,2,8]再点乘

2、numpy,跟tensorflow是一样的

a =[[[9],
  [1],
  [4],
  [7]],

 [[9],
  [6],
  [5],
  [8]]]

a= np.array(a) # shape(2,4, 1)

b= [[2,6,7,1,6]]
b= np.array(b)  # shape(1, 5)
#广播相乘
print(a*b)

输出结果:

[[[9]
[1]
[4]
[7]]

[[9]
[6]
[5]
[8]]]
[[2 6 7 1 6]]
[[[18 54 63 9 54]
[ 2 6 7 1 6]
[ 8 24 28 4 24]
[14 42 49 7 42]]

[[18 54 63 9 54]
[12 36 42 6 36]
[10 30 35 5 30]
[16 48 56 8 48]]]

运算过程等同于:

c =[[[9,9,9,9,9],
  [1,1,1,1,1],
  [4,4,4,4,4],
  [7,7,7,7,7]],

 [[9,9,9,9,9],
  [6,6,6,6,6],
  [5,5,5,5,5],
  [8,8,8,8,8]]]
c= np.array(c)  # shape(2,4, 5)

d =[[[2,6,7,1,6],
    [2,6,7,1,6],
    [2,6,7,1,6],
    [2,6,7,1,6]],

    [[2,6,7,1,6],
    [2,6,7,1,6],
    [2,6,7,1,6],
    [2,6,7,1,6]]
    ]

d= np.array(d)  # shape(2,4, 5)
print(c*d) #点对点相乘 

输出结果:

[[[18 54 63 9 54]
[ 2 6 7 1 6]
[ 8 24 28 4 24]
[14 42 49 7 42]]

[[18 54 63 9 54]
[12 36 42 6 36]
[10 30 35 5 30]
[16 48 56 8 48]]]

上一篇下一篇

猜你喜欢

热点阅读