0001 - 用规范的类封装 Keras 基本模型

2019-08-20  本文已影响0人  小新学算法

代码如下：用 `Arg` 类集中管理参数，用 `LinearModel` 类封装模型的构建、编译、训练、预测、保存与加载。

from keras.layers import Input, Dense
from keras.models import Model
from keras import models
class Arg:
    """Hyper-parameter container read by ``LinearModel``.

    All settings are plain class attributes, so the class itself and any
    instance (``Arg()``) expose the same values. Override a value by
    assigning to an instance or by subclassing.
    """

    # --- Model parameters ---
    # Length of each input feature vector (the model input has shape
    # ``(input_size,)``).
    input_size = 784

    # --- Training parameters ---
    # Number of passes over the training data in ``LinearModel.fit``.
    epochs = 1000

    # --- Metric parameters ---
    # (none defined yet)
class LinearModel:
    """Thin wrapper that builds, trains, saves and restores a small Keras classifier.

    Despite its name, the network is not linear. It is a multilayer
    perceptron with two ReLU hidden layers and a softmax output layer.

    Args:
        args: Configuration object. Must provide ``input_size`` (length of
            each input vector) and ``epochs`` (default epoch count for
            ``fit``). It may optionally provide ``hidden_units`` (width of
            each hidden layer, default 64) and ``num_classes`` (size of the
            softmax output, default 10).
    """

    def __init__(self, args):
        # Required hyper-parameters.
        self.input_size = args.input_size
        self.epochs = args.epochs
        # Optional architecture settings. The defaults reproduce the
        # previously hard-coded 64-unit hidden layers and 10-way output, so
        # existing ``args`` objects keep working unchanged.
        self.hidden_units = getattr(args, "hidden_units", 64)
        self.num_classes = getattr(args, "num_classes", 10)
        # Build and compile the Keras model up front.
        self.model = self._build_model()

    def _build_model(self):
        """Create and compile the network: input -> 2 x Dense(ReLU) -> Dense(softmax)."""
        inputs = Input(shape=(self.input_size,))
        x = Dense(self.hidden_units, activation='relu')(inputs)
        x = Dense(self.hidden_units, activation='relu')(x)
        predictions = Dense(self.num_classes, activation='softmax')(x)
        model = Model(inputs=inputs, outputs=predictions)
        # categorical_crossentropy expects one-hot targets with
        # ``num_classes`` columns (e.g. produced by keras.utils.to_categorical).
        model.compile(optimizer='rmsprop',
                      loss='categorical_crossentropy',
                      metrics=['accuracy'])
        return model

    def fit(self, X, y, **fit_kwargs):
        """Train the wrapped model on ``X`` / ``y`` and return the Keras ``History``.

        Extra keyword arguments are forwarded to ``Model.fit`` (for example
        ``batch_size`` or ``validation_split``). ``epochs`` defaults to the
        value taken from ``args`` but may be overridden per call.
        """
        # setdefault avoids a duplicate-keyword error if the caller passes epochs.
        fit_kwargs.setdefault('epochs', self.epochs)
        return self.model.fit(X, y, **fit_kwargs)

    def predict(self, X):
        """Return the softmax output (one probability row per sample) for ``X``."""
        return self.model.predict(X)

    def save_model(self, path):
        """Save the wrapped model to ``path`` and return whatever ``Model.save`` returns."""
        return self.model.save(path)

    def load_model(self, path):
        """Load a saved model from ``path``, install it as the wrapped model, and return it.

        The restored Keras model is still returned, as before, so existing
        callers are unaffected.
        """
        # Replace self.model so that subsequent self.predict / self.fit calls
        # use the restored network instead of the freshly built one.
        self.model = models.load_model(path)
        return self.model

    def load_model_structure(self):
        """Print the layer summary of the wrapped model.

        Keras ``Model.summary()`` prints the table (to stdout by default)
        and, in standard Keras, returns ``None``. This method passes that
        return value through.
        """
        return self.model.summary()
        
if __name__ == "__main__":
    # Smoke test: train on synthetic data, then check a save/load round trip.
    import os
    import tempfile

    import numpy as np
    from keras.utils import to_categorical

    # 100 samples x 784 features, with random labels for 10 classes.
    # Seeded so repeated runs use the same labels.
    rng = np.random.RandomState(0)
    X = np.linspace(start=20, stop=100, num=100 * 784).reshape(100, 784)
    # One-hot encode the labels, as categorical_crossentropy requires.
    y = to_categorical(rng.randint(10, size=(100,)), num_classes=10)

    args = Arg()
    model = LinearModel(args)
    history = model.fit(X, y)
    print("final training loss:", history.history["loss"][-1])
    y_hat = model.predict(X)

    # Save under the platform temp directory rather than a hard-coded
    # Windows drive ("H:/"), which does not exist on most machines.
    model_path = os.path.join(tempfile.gettempdir(), "model.h5")
    model.save_model(model_path)

    # Keep the wrapper and the restored Keras model under separate names
    # instead of rebinding ``model`` to a different type.
    restored = model.load_model(model_path)
    y_restored = restored.predict(X)
    # After a save/load round trip the predictions should match.
    print("max |y_hat - y_restored| =", np.abs(y_hat - y_restored).max())
    
上一篇 下一篇

猜你喜欢

热点阅读