tf.transpose()
2018-02-07 本文已影响0人
Perry_Wu
tf.transpose()
为转置函数,其中参数perm
用来设置需要转置的维度和顺序
img = np.array([
[[1, 2, 3], [4, 5, 6]],
[[7, 8, 9], [10, 11, 12]]
])
# img = img[np.newaxis, :]
l1 = tf.convert_to_tensor(img)
l2 = tf.contrib.layers.flatten(l1)
l3 = tf.transpose(l1, (1, 0, 2))
l4=tf.contrib.layers.flatten(l3)
with tf.Session() as sess:
out = sess.run(l4)
print out, out.shape
img
是一个2*2*3 (row*col*channel)
的图像矩阵,在内存中的存储顺序为:channel
=>col
=>row
,即从shape
的最后一个维度往前开始存储,对应的perm
为(0,1,2)
如果进行l3 = tf.transpose(l1, (0, 1, 2))
则矩阵不变
如果进行l3 = tf.transpose(l1, (1, 0, 2))
则对row
和col
进行转置,转置后,内存中的存储顺序改为:channel
=>row
=>col
,shape=(2,2,3)
如果进行l3 = tf.transpose(l1, (2, 0, 1))
则对先对row
和col
进行转置,再对col
和channel
进行转置,内存中的存储顺序改为:col
=>row
=>channel
,shape=(3,2,2)