大数据,机器学习,人工智能人工智能微刊人工智能(语言识别&图像识别)

Tensorflow CNN卷积的理解

2018-08-20  本文已影响1人  YANWeichuan

Tensorflow中的卷积函数

tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, name=None)

CNN中卷积的理解

这篇文章对CNN卷积的介绍较为详细, https://buptldy.github.io/2016/10/01/2016-10-01-im2col/

文中的两个图比较直观,图二中,input features,可以理解为一个3x3图片的3通道数据,卷积核为2x2大小三通道 2个卷积核。单通道的图片数据转化为矩阵时,是按照卷积核的大小展开为一维向量。图一的理解,则需要交换input和kernel的位置,output map逆时针旋转90度。

图一
图二

实践

用tensorflow的conv2d来验证图2的数据,其关键是输入和filter的数据的构造。

filter的维数[2, 2, 3, 2], 2x2维,3通道,2个卷积核,构造出的数据,不能直接按照卷积核的原始数据reshape得道,而是要重新排列,排列顺序如图。如果要到如图中第一个卷积核的数据则用filter_arg[:, :, 0, 0](也可见代码中的注释部分,通过直接赋值构建filter的值)。同理可以构造输入图片的数据,最后执行的结果和图中一致。

数据构建
import numpy as np
import tensorflow as tf

input_arg = tf.constant([
    [[1, 0, 1], [2, 2, 2], [0, 1, 1]],
    [[1, 0, 0], [1, 3, 1], [3, 2, 3]],
    [[0, 1, 3], [2, 1, 3], [2, 0, 2]]], dtype = tf.float32)

# [filter_height, filter_width, in_channels, out_channels]
#filter_arg = np.array(np.arange(24)).astype("float32")
#filter_arg = np.reshape(filter_arg, [2, 2, 3, 2]).astype("float32")
#filter_arg[:, :, 0, 0] =  [[1, 1], [2, 2]]
#filter_arg[:, :, 1, 0] =  [[1, 1], [1, 1]] 
#filter_arg[:, :, 2, 0] =  [[0, 1], [1, 0]]
#filter_arg[:, :, 0, 1] =  [[1, 0], [0, 1]]
#filter_arg[:, :, 1, 1] =  [[2, 1], [2, 1]] 
#filter_arg[:, :, 2, 1] =  [[1, 2], [2, 0]] 
filter_arg = tf.constant([
    [[[1, 1],
    [1, 2],
    [0, 1]],

    [[1, 0],
    [1, 1],
    [1, 2]]],

    [[[2, 0],
    [1, 2],
    [1, 2]],

    [[2, 1],
    [1, 1],
    [0, 0]]]], dtype = tf.float32)

input_arg_normal = tf.reshape(input_arg, [1, 3, 3, 3])
filter_arg_normal = tf.reshape(filter_arg, [2, 2, 3, 2])
op1 = tf.squeeze(tf.nn.conv2d(input_arg_normal, filter_arg_normal, strides=[1,1,1,1], use_cudnn_on_gpu=False, padding='VALID'))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    input_out = sess.run(input_arg_normal)
    print(input_out.shape)
    print(input_out)
    print("*" * 20)
    filter_out = sess.run(filter_arg_normal)
    print(filter_out.shape)
    print(filter_out)
    print("*" * 20)
    print(sess.run(op1))

运行结果

(1, 3, 3, 3)
[[[[1. 0. 1.]
   [2. 2. 2.]
   [0. 1. 1.]]

  [[1. 0. 0.]
   [1. 3. 1.]
    [0, 1]],
   [3. 2. 3.]]

  [[0. 1. 3.]
   [2. 1. 3.]
   [2. 0. 2.]]]]
********************
(2, 2, 3, 2)
[[[[1. 1.]
   [1. 2.]
   [0. 1.]]

  [[1. 0.]
   [1. 1.]
   [1. 2.]]]


 [[[2. 0.]
   [1. 2.]
   [1. 2.]]

  [[2. 1.]
   [1. 1.]
   [0. 0.]]]]
********************
[[[14. 12.]
  [20. 24.]]

 [[15. 17.]
  [24. 26.]]]

参考文章

http://cs231n.github.io/convolutional-networks/

经典卷积计算图
上一篇下一篇

猜你喜欢

热点阅读