Keras使用问题记录

2019-10-21  本文已影响0人  点点渔火

1, 模型结构保存:

    def get_config(self):
        config = super(DilatedGatedConv1D, self).get_config()
        config.update(
            {
                'o_dim': self.o_dim,
                'k_size': self.k_size,
                'rate': self.rate,
                'skip_connect': self.skip_connect,
                'drop_gate': self.drop_gate
             }
        )
        return config

如果初始化参数也是一个Layer网络层, Layer对象本身不能序列化, 这就要求重新实现get_config()和from_config()两个方法,实现包含层Layer的序列化反序列化, 参考Bidirectional,Wrapper的实现

    def get_config(self):
        """
        参数的序列化操作
        :return:
        """
        config = super(OurBidirectional, self).get_config()
        config.update(
            {
                'layer': {       # 参照Wrapper 不能直接保留类对象
                    'class_name': self.layer.__class__.__name__,
                    'config': self.layer.get_config()
                }
            }
        )
        return config

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """
        自定义从字典config恢复实例参数
        :param config:
        :param custom_objects:
        :return:
        """
        layer = deserialize_layer(config.pop('layer'),
                                  custom_objects=custom_objects)
        return cls(layer, **config)
def get_custom_objects(self):
        """
        自定义的层或者函数
        :return:
        """
        custom_objects = self.embedding.get_custom_objects()
        custom_objects['OurMasking'] = OurMasking
        custom_objects['CRF'] = CRF
        return custom_objects

keras.models.model_from_json(
            model_json_str,
            custom_objects=model.get_custom_objects()
        )
上一篇 下一篇

猜你喜欢

热点阅读