tf.concat()

2017-09-21  sterio
tf.concat(
    values,
    axis,  # called concat_dim (and passed first) before TensorFlow 1.0
    name='concat'
)

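axis = 0 concatenates along the 1st dimension.
axis = 1 concatenates along the 2nd dimension.
axis = n-1 concatenates along the nth dimension.
The inputs must have the same rank, and their sizes must match in every dimension except axis.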
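To make the axis rule concrete, here is a minimal pure-Python sketch, assuming equally shaped nested lists and a valid non-negative axis. The helper concat_lists is hypothetical (not a TensorFlow or NumPy function) and does no shape checking: axis k means descending k levels into the nesting and joining the lists found there.

```python
# Hypothetical helper (not part of TensorFlow or NumPy): a pure-Python model of
# concatenation along `axis` for equally shaped nested lists.
# Assumes a valid non-negative axis; shapes are NOT checked.
def concat_lists(values, axis):
    if axis == 0:
        # Joining along the outermost dimension: chain the top-level items.
        result = []
        for v in values:
            result.extend(v)
        return result
    # Otherwise descend one level: the i-th sub-lists of every input are
    # concatenated along axis - 1.
    return [concat_lists(list(parts), axis - 1) for parts in zip(*values)]


t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
print(concat_lists([t1, t2], 0))  # [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
print(concat_lists([t1, t2], 1))  # [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]
```

The recursion mirrors the list above: axis 0 joins rows, axis 1 joins each pair of rows element-wise, and deeper axes simply recurse further.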

import tensorflow as tf  # TensorFlow 1.x API (uses tf.Session)

t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
sess = tf.Session()

a1 = tf.concat([t1, t2], 0)  # concatenate along the 1st dim: stack rows
a2 = tf.concat([t1, t2], 1)  # concatenate along the 2nd dim: join columns
print(sess.run(a1))
print("\n")
print(sess.run(a2))

# Concatenate along the 3rd dim, which 2-D inputs do not have;
# this raises ValueError as soon as the op is created.
a3 = tf.concat([t1, t2], 2)
[[ 1  2  3]
 [ 4  5  6]
 [ 7  8  9]
 [10 11 12]]


[[ 1  2  3  7  8  9]
 [ 4  5  6 10 11 12]]

As expected, a3 raises an error, because a 2-D tensor has no 3rd dimension:

ValueError: Shape must be at least rank 3 but is rank 2 for 'concat_14' (op: 'ConcatV2') with input shapes: [2,3], [2,3], [].
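The shape rule behind both the outputs above and this error can be sketched as a small shape check. concat_shape is a hypothetical helper, not a TensorFlow API, and its error messages are made up; it assumes the valid axis range [-rank, rank) that the TensorFlow documentation gives for tf.concat.

```python
# Hypothetical helper (not a TensorFlow API): computes the output shape of a
# concatenation and rejects the cases that make tf.concat fail.
def concat_shape(shapes, axis):
    rank = len(shapes[0])
    if any(len(s) != rank for s in shapes):
        raise ValueError("all inputs must have the same rank")
    # Assumed valid range, as documented for tf.concat: [-rank, rank).
    if not -rank <= axis < rank:
        raise ValueError(f"axis {axis} is out of range for rank {rank}")
    axis %= rank  # map a negative axis onto [0, rank)
    for s in shapes[1:]:
        for d in range(rank):
            if d != axis and s[d] != shapes[0][d]:
                raise ValueError(f"dimension {d} differs: {s[d]} vs {shapes[0][d]}")
    out = list(shapes[0])
    out[axis] = sum(s[axis] for s in shapes)  # sizes add up along `axis`
    return out


print(concat_shape([[2, 3], [2, 3]], 0))  # [4, 3], the shape of a1
print(concat_shape([[2, 3], [2, 3]], 1))  # [2, 6], the shape of a2
try:
    concat_shape([[2, 3], [2, 3]], 2)  # the a3 case
except ValueError as e:
    print("ValueError:", e)
```

Only the size along axis changes (2 + 2 = 4 rows, or 3 + 3 = 6 columns); every other dimension must already agree.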

numpy.concatenate() gives the same results as tf.concat():

import numpy as np

t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]

b1 = np.concatenate((t1, t2), 0)  # along the 1st dim (rows), like a1
b2 = np.concatenate((t1, t2), 1)  # along the 2nd dim (columns), like a2

print(b1)
print("\n")
print(b2)
[[ 1  2  3]
 [ 4  5  6]
 [ 7  8  9]
 [10 11 12]]


[[ 1  2  3  7  8  9]
 [ 4  5  6 10 11 12]]
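NumPy also accepts a negative axis and rejects an axis that does not exist, mirroring the a3 case. A short sketch (the name b3 is just an illustration); NumPy reports the bad axis as AxisError, which subclasses ValueError in recent NumPy versions, so catching ValueError is assumed to be enough here.

```python
import numpy as np

t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]

# A negative axis counts from the last dimension, so for 2-D inputs
# axis=-1 is the same as axis=1.
b3 = np.concatenate((t1, t2), -1)
print(np.array_equal(b3, np.concatenate((t1, t2), 1)))  # True

# Counterpart of the a3 case: axis 2 does not exist for 2-D inputs.
# NumPy raises AxisError, a subclass of ValueError.
try:
    np.concatenate((t1, t2), 2)
except ValueError as e:
    print(type(e).__name__, e)
```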