tf.concat用法浅析

2018-10-31  本文已影响0人  陈晓峥

tf.concat(concat_dim, values, name='concat')

第一个参数concat_dim:必须是一个数,表明在哪一维上连接

values 代表你要连接的矩阵

废话不多说 直接上代码

t1 = [[[[1, 2, 3], [4, 5, 6],[20, 21, 22]]]]

t2 = [[[[7, 8, 9], [10, 11, 12], [25, 26, 27]]]]

t3 = [[[[13, 14, 15], [16, 17, 18], [28, 29, 20]]]]

print(tf.Variable(t1).shape)   #(1, 1, 3, 3)

print(tf.Variable(t2).shape)   #(1, 1, 3, 3)

print(tf.Variable(t3).shape)  #(1, 1, 3, 3)

data = tf.Variable(tf.concat([t1, t2, t3], 0))     # 0 代表我将t1的第一个index相加 为 3

data1 = tf.Variable(tf.concat([t1, t2, t3], 3))   # 3 代表我将t1的第四个index相加  为9

init = tf.global_variables_initializer()

with tf.Session() as sess:

sess.run(init)

print("test:\n", sess.run(data), "data =", data.shape)    # (3, 1, 3, 3)

print("test:\n", sess.run(data1), "data1 =", data1.shape) #(1, 1, 3, 9)


如果tf.concat 超过了矩阵的长度如将data1 = tf.Variable(tf.concat([t1, t2, t3], 3)) 改为

data1 = tf.Variable(tf.concat([t1, t2, t3], 4)) 则会报错

报错信息为

ValueError: Shape must be at least rank 5 but is rank 4 for 'concat_1' (op: 'ConcatV2') with input shapes: [1,1,3,3], [1,1,3,3], [1,1,3,3], [] and with computed input tensors: input[3] = <4>.

大概意思实 你的t1 t2 t3 的张量最大是到第四个维度, 而你输入4 代表要对张量的第五个维度进行相加,类似于数组越界,所以报错

如有问题欢迎大家指正,谢谢

上一篇下一篇

猜你喜欢

热点阅读