tensorflow cnn常用函数解析(一)

2018-09-21  本文已影响0人  YG_9013

1、conv2d

tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, name=None)
除去name参数用以指定该操作的name,与方法有关的一共五个参数:

第一个参数input:指需要做卷积的输入图像,它要求是一个Tensor,具有[batch, in_height, in_width, in_channels]这样的shape,具体含义是[训练时一个batch的图片数量, 图片高度, 图片宽度, 图像通道数],注意这是一个4维的Tensor,要求类型为float32和float64其中之一

第二个参数filter:相当于CNN中的卷积核,它要求是一个Tensor,具有[filter_height, filter_width, in_channels, out_channels]这样的shape,具体含义是[卷积核的高度,卷积核的宽度,图像通道数,卷积核个数],要求类型与参数input相同,有一个地方需要注意,第三维in_channels,就是参数input的第四维

第三个参数strides:卷积时在图像每一维的步长,这是一个一维的向量,长度4

第四个参数padding:string类型的量,只能是"SAME","VALID"其中之一,这个值决定了不同的卷积方式。VALID表示原数据不补充0,SAME表示卷积后的结果sharp通过补充0与卷积之前的相同。

第五个参数:use_cudnn_on_gpu:bool类型,是否使用cudnn加速,默认为true

结果返回一个Tensor,这个输出,就是我们常说的feature map。

import tensorflow as tf
#case 2  sharp(1,3,3,1)
input = tf.Variable(tf.random_normal([1,3,3,5]))
filter = tf.Variable(tf.random_normal([1,1,5,1]))
 
op2 = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='VALID')
#case 3  sharp(1,1,1,1)
input = tf.Variable(tf.random_normal([1,3,3,5]))
filter = tf.Variable(tf.random_normal([3,3,5,1]))
 
op3 = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='VALID')
#case 4 sharp(1,3,3,1)
input = tf.Variable(tf.random_normal([1,5,5,5]))
filter = tf.Variable(tf.random_normal([3,3,5,1]))
 
op4 = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='VALID')
#case 5 sharp(1,5,5,1)
input = tf.Variable(tf.random_normal([1,5,5,5]))
filter = tf.Variable(tf.random_normal([3,3,5,1]))
 
op5 = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='SAME')
#case 6 sharp(1,5,5,7)
input = tf.Variable(tf.random_normal([1,5,5,5]))
filter = tf.Variable(tf.random_normal([3,3,5,7]))
 
op6 = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='SAME')
#case 7 sharp(1,3,3,7)
input = tf.Variable(tf.random_normal([1,5,5,5]))
filter = tf.Variable(tf.random_normal([3,3,5,7]))
 
op7 = tf.nn.conv2d(input, filter, strides=[1, 2, 2, 1], padding='SAME')
#case 8 sharp(10,3,3,7)
input = tf.Variable(tf.random_normal([10,5,5,5]))
filter = tf.Variable(tf.random_normal([3,3,5,7]))
 
op8 = tf.nn.conv2d(input, filter, strides=[1, 2, 2, 1], padding='SAME')
 

same:
补零个数:(num_weight_zero + weight_size) % (weight_stride) = 0
conv weight shape:  向上取整weight_size/weight_stride

2、max_pool

tf.nn.max_pool(value, ksize, strides, padding, name=None)
参数是四个,和卷积很类似:
第一个参数value:需要池化的输入,一般池化层接在卷积层后面,所以输入通常是feature map,依然是[batch, height, width, channels]这样的shape

第二个参数ksize:池化窗口的大小,取一个四维向量,一般是[1, height, width, 1],因为我们不想在batch和channels上做池化,所以这两个维度设为了1

第三个参数strides:和卷积类似,窗口在每一个维度上滑动的步长,一般也是[1, stride,stride, 1]

第四个参数padding:和卷积类似,可以取'VALID' 或者'SAME'

返回一个Tensor,类型不变,shape仍然是[batch, height, width, channels]这种形式

import tensorflow as tf
 
a=tf.constant([
        [[1.0,2.0,3.0,4.0],
        [5.0,6.0,7.0,8.0],
        [8.0,7.0,6.0,5.0],
        [4.0,3.0,2.0,1.0]],
        [[4.0,3.0,2.0,1.0],
         [8.0,7.0,6.0,5.0],
         [1.0,2.0,3.0,4.0],
         [5.0,6.0,7.0,8.0]]
    ])
 
a=tf.reshape(a,[1,4,4,2])
 
pooling=tf.nn.max_pool(a,[1,2,2,1],[1,1,1,1],padding='VALID')
with tf.Session() as sess:
    print("reslut:")
    result=sess.run(pooling)
    print (result)

reslut:
[[[[ 8.  7.]
   [ 6.  6.]
   [ 7.  8.]]
 
  [[ 8.  7.]
   [ 8.  7.]
   [ 8.  7.]]
 
  [[ 4.  4.]
   [ 8.  7.]
   [ 8.  8.]]]]

3、concat

tf.concat(concat_dim, values, name='concat')
除去name参数用以指定该操作的name,与方法有关的一共两个参数:
第一个参数concat_dim:必须是一个数,表明在哪一维上连接

如果concat_dim是0,那么在某一个shape的第一个维度上连,对应到实际,就是叠放到列上

t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
tf.concat(0, [t1, t2]) == > [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]

如果concat_dim是1,那么在某一个shape的第二个维度上连

t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
tf.concat(1, [t1, t2]) ==> [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12

如果有更高维,最后连接的依然是指定那个维:

 values[i].shape = [D0, D1, ... Dconcat_dim(i), ...Dn]连接后就是:[D0, D1, ... Rconcat_dim, ...Dn]
# tensor t3 with shape [2, 3]
# tensor t4 with shape [2, 3]
tf.shape(tf.concat(0, [t3, t4])) ==> [4, 3]
tf.shape(tf.concat(1, [t3, t4])) ==> [2, 6]

【参考链接】:
1、https://blog.csdn.net/mao_xiao_feng/article/details/53444333
2、https://blog.csdn.net/mao_xiao_feng/article/details/53453926
3、https://blog.csdn.net/mao_xiao_feng/article/details/53366163

上一篇下一篇

猜你喜欢

热点阅读