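tl.layers.SlimNetsLayer 使用 (Using TensorLayer's `tl.layers.SlimNetsLayer` to wrap a TF-Slim ResNet-v1-50 for training and validation)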

2018-07-19 · 作者 (author): yalesaleng · 本文已影响36人 (read by 36 people)
from tensorflow.contrib.slim.python.slim.nets.resnet_v1 import (resnet_v1_50, resnet_arg_scope)
from tensorflow.contrib.slim.python.slim.nets.vgg import (vgg_16, vgg_arg_scope)

# Build a TF-Slim ResNet-v1-50 classifier through TensorLayer's SlimNetsLayer
# twice: once for training and once for validation, with the validation copy
# reusing the training copy's variables.
#
# NOTE(review): `tra_patch`, `val_patch`, `num_classes`, `tf`, `tl` and `slim`
# are not defined in this snippet; they are assumed to come from earlier in the
# script (input image batches, class count, and the tensorflow / tensorlayer /
# tf.contrib.slim modules) — TODO confirm.

# --- Training network -------------------------------------------------------
# SlimNetsLayer takes a TensorLayer layer as `layer`, so wrap the raw tensor.
tra_patch = tl.layers.InputLayer(tra_patch, name='tra_Inputlayer')
with slim.arg_scope(resnet_arg_scope()):
    tra_network = tl.layers.SlimNetsLayer(
        layer=tra_patch,
        slim_layer=resnet_v1_50,
        slim_args={
            'num_classes': num_classes,
            'is_training': True,
            'global_pool': True,
        },
        name='resnet_v1_50',
    )

# --- Validation network (shares weights with the training network) ----------
# TensorLayer refuses a second layer named 'resnet_v1_50' unless name reuse is
# switched on. The switch must be *called* on its own: the original passed its
# return value into resnet_arg_scope(), whose first positional parameter is
# `weight_decay`, handing a non-number to the L2 regularizer.
tl.layers.set_name_reuse(True)
val_patch = tl.layers.InputLayer(val_patch, name='val_Inputlayer')
with slim.arg_scope(resnet_arg_scope()):
    val_network = tl.layers.SlimNetsLayer(
        layer=val_patch,
        slim_layer=resnet_v1_50,
        slim_args={
            'num_classes': num_classes,
            'is_training': False,
            'global_pool': True,
            'reuse': True,  # reuse the TF variables created by the training net
        },
        name='resnet_v1_50',
    )

# Discard the size-1 spatial dims. With global_pool=True and num_classes set,
# the slim network outputs rank-4 logits [batch, 1, 1, num_classes]. Squeeze
# only axes 1 and 2: a bare tf.squeeze() would also drop the batch dim when the
# batch size is 1, yielding logits of the wrong rank.
tra_logits = tf.squeeze(tra_network.outputs, axis=[1, 2])
val_logits = tf.squeeze(val_network.outputs, axis=[1, 2])
上一篇 下一篇

猜你喜欢

热点阅读