深度学习研究所

tf.nn.conv2d中传入的第一二个参数的数据格式问题

2017-09-17  本文已影响0人  西方失败9527

看到知乎上这样一个问题:

下面这个图所示,输入数据是一个2个通道3*3的数据,过滤器是一个具有两个通道的2*2的数据,按照一般卷积过程,即如果所示结果是一个通道的2*2的数据。

但是在tensorflow中,我们如下实现:

k = tf.constant([ 1,2 ,3,4,

                            5,6,7,8], dtype=tf.float32, name='k')

i = tf.constant([

                       1, 3, 5,

                       1, 3, 5,

                       1, 3, 5,

                        2, 4, 6,

                        2, 4, 6,

                        2, 4, 6

                        ], dtype=tf.float32, name='i')

kernel = tf.reshape(k, [2, 2, 2, 1], name='kernel')

image  = tf.reshape(i, [1, 3, 3, 2], name='image')

#res = tf.nn.conv2d(image, kernel, [1, 1, 1, 1], "VALID")

res = tf.squeeze(tf.nn.conv2d(image, kernel, [1, 1, 1, 1], "VALID"))# VALID means no padding

with tf.Session() as sess:

            print(sess.run(res))

结果不对

原因原来是data_format 参数的问题,图像数据格式定义了一批图片数据的存储顺序。在调用 TensorFlow API 时会经常看到 data_format 参数:

data_format 默认值为 "NHWC",也可以手动设置为 "NCHW"。这个参数规定了 input Tensor 和 output Tensor 的排列方式。

data_format 设置为 "NHWC" 时,排列顺序为 [batch, height, width, channels];

                      设置为 "NCHW" 时,排列顺序为 [batch, channels, height, width]。

其中 N 表示这批图像有几张,H 表示图像在竖直方向有多少像素,W 表示水平方向像素数,C 表示通道数(例如黑白图像的通道数 C = 1,而 RGB 彩色图像的通道数 C = 3)。为了便于演示,我们后面作图均使用 RGB 三通道图像。两种格式的区别如下图所示:

NCHW 中,C 排列在外层,每个通道内像素紧挨在一起,即 'RRRRRRGGGGGGBBBBBB' 这种形式。

NHWC 格式,C 排列在最内层,多个通道对应空间位置的像素紧挨在一起,即 'RGBRGBRGBRGBRGBRGB' 这种形式。

于是我们的程序中将数据顺序修改即可:

k = tf.constant([

1, 5,

2, 6,

3, 7,

4, 8

], dtype=tf.float32, name='k')

i = tf.constant([

1, 2, 3,

4, 5, 6,

1, 2, 3,

4, 5, 6,

1, 2, 3,

4, 5, 6

], dtype=tf.float32, name='i')

kernel = tf.reshape(k, [2, 2, 2, 1], name='kernel')

image  = tf.reshape(i, [1, 3, 3, 2], name='image')

#res = tf.nn.conv2d(image, kernel, [1, 1, 1, 1], "VALID")

res = tf.squeeze(tf.nn.conv2d(image, kernel, [1, 1, 1, 1], "VALID"))# VALID means no padding

with tf.Session() as sess:

            print(sess.run(image))

            print("------------------")

            print(sess.run(kernel))

            print("------------------")

           print(sess.run(res))

最终能如愿以偿得到如图右边的结果。不过feature map的172应该改为174,手算也该如此

主要参考:http://mp.weixin.qq.com/s/I4Q1Bv7yecqYXUra49o7tw

上一篇下一篇

猜你喜欢

热点阅读