[tf]incompatible with the layer:

2019-02-19  本文已影响2人  VanJordan

解决方法,用个tf.reshape指定一下形状就可以了,原来报错的地方时dense_1需要确定的维度,但是没有,可能的原因是因为我用tf.data.from_generator进行读取的时候获得不了指定的维度(tf的一个bug),所以我们在传进全连接层的时候做一个reshape操作return self.mlp(tf.reshape(h,[-1,self.embedding_size])),就可以了。

class CNN(object):
  def __init__(self,embedding_size,keep_prob):
    out_unit = embedding_size//3
    self.cnn_w3 = tf.layers.Conv2D(out_unit,kernel_size=[3,1],strides=[1,1],padding='same',use_bias=False)
    self.cnn_w4 = tf.layers.Conv2D(out_unit,kernel_size=[4,1],strides=[1,1],padding='same',use_bias=False)
    self.cnn_w5 = tf.layers.Conv2D(out_unit,kernel_size=[5,1],strides=[1,1],padding='same',use_bias=False)
    self.mlp = tf.layers.Dense(units=embedding_size,activation='relu')
    self.keep_prob = keep_prob
    self.embedding_size = embedding_size
  
  def __call__(self,x,is_training=True):
    x = tf.expand_dims(x,2)
    # ddaa = self.cnn_w3(x)
    # print(ddaa)
    h_w3 = tf.squeeze(tf.reduce_max(self.cnn_w3(x),axis=1))
    # print(h_w3)
    # raise ValueError('asdfjasdkf')
    h_w4 = tf.squeeze(tf.reduce_max(self.cnn_w4(x),axis=1))
    h_w5 = tf.squeeze(tf.reduce_max(self.cnn_w5(x),axis=1))
    h = tf.concat([h_w3,h_w4,h_w5],1)
    h = tf.nn.relu(h)
    if self.keep_prob < 1 and is_training:
      h = tf.nn.dropout(h,self.keep_prob)
    # print(h)
    # raise ValueError('asdfjkasdf')
    return self.mlp(tf.reshape(h,[-1,self.embedding_size]))
    # return h
上一篇下一篇

猜你喜欢

热点阅读