tf.transpose()

2018-02-07  本文已影响0人  Perry_Wu

tf.transpose()为转置函数,其中参数perm用来设置需要转置的维度和顺序

img = np.array([
    [[1, 2, 3], [4, 5, 6]],
    [[7, 8, 9], [10, 11, 12]]
])

# img = img[np.newaxis, :]
l1 = tf.convert_to_tensor(img)
l2 = tf.contrib.layers.flatten(l1)
l3 = tf.transpose(l1, (1, 0, 2))
l4=tf.contrib.layers.flatten(l3)

with tf.Session() as sess:
    out = sess.run(l4) 
    print out, out.shape

img是一个2*2*3 (row*col*channel)的图像矩阵,在内存中的存储顺序为:channel=>col=>row,即从shape的最后一个维度往前开始存储,对应的perm(0,1,2)

如果进行l3 = tf.transpose(l1, (0, 1, 2))则矩阵不变

如果进行l3 = tf.transpose(l1, (1, 0, 2))则对rowcol进行转置,转置后,内存中的存储顺序改为:channel=>row=>colshape=(2,2,3)

如果进行l3 = tf.transpose(l1, (2, 0, 1))则对先对rowcol进行转置,再对colchannel进行转置,内存中的存储顺序改为:col=>row=>channelshape=(3,2,2)

上一篇 下一篇

猜你喜欢

热点阅读