Running a CNN on your own dataset: reference blogs and code for tuning

2019-01-02 · sunnylxs

Reference blog: https://blog.csdn.net/bryant_meng/article/details/81077196

References for CNNs and dynamic learning rates in Keras: https://blog.csdn.net/zzc15806/article/details/79711114

https://blog.csdn.net/Xwei1226/article/details/81297878

https://blog.csdn.net/xjcvip007/article/details/52801216
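
The linked posts discuss changing the learning rate during training. As a minimal sketch (illustrative values of my own, not code taken from those posts), Keras exposes this through callbacks: `LearningRateScheduler` applies a schedule defined as a function of the epoch index, and `ReduceLROnPlateau` shrinks the rate when a monitored metric stops improving.

from keras.callbacks import LearningRateScheduler, ReduceLROnPlateau

def step_decay(epoch):
    # Halve an initial rate of 0.001 every 10 epochs (illustrative numbers)
    return 0.001 * (0.5 ** (epoch // 10))

lr_schedule = LearningRateScheduler(step_decay, verbose=1)

# Alternative: cut the learning rate in half when validation loss plateaus
lr_plateau = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, min_lr=1e-6)

# Either callback can be appended to the callbacks list passed to model.fit below, e.g.
# model.fit(..., callbacks=[early_stopping, lr_schedule])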

#-*-coding:utf-8-*-

from __future__ import print_function

import numpy as np

seed=7

np.random.seed(seed)

import time

from keras.callbacks import EarlyStopping

import keras

from keras.utils.vis_utils import plot_model

import matplotlib.pyplot as plt

from sklearn.utils import shuffle

from sklearn.model_selection import train_test_split

from keras.datasets import mnist

from keras.models import Sequential

from keras.layers import Dense, Dropout, Activation, Flatten

from keras.optimizers import SGD, RMSprop, Adam

from keras.layers import Conv2D, MaxPooling2D

from keras.wrappers.scikit_learn import KerasClassifier

from PIL import Image

from sklearn.model_selection import GridSearchCV

from keras import backend as K

K.clear_session()

import os

os.environ['TF_CPP_MIN_LOG_LEVEL']='2'

from keras.layers.normalization import BatchNormalization

from numpy import *
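
# Training configuration: 3 classes, batch size 64, up to 50 epochs; chanDim=-1 selects the channels-last axis for BatchNormalization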

num_classes = 3

batch_size=64

epochs=50

chanDim=-1

start = time.time()
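
# Load the grayscale images from ./newgray (sorted by their numeric filenames) and flatten each one into a row of immatrix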

path2='./newgray/'

img_rows,img_cols=140,105

imlist=os.listdir(path2)

imlist.sort(key = lambda x:int(x[:-4])) 

im1 = array(Image.open(os.path.join(path2, imlist[0])))

m,n=im1.shape[0:2]

imnbr=len(imlist)

immatrix = array([array(Image.open(os.path.join(path2, im2))).flatten()
                  for im2 in imlist], 'f')

num_samples=size(imlist)
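
# Label images by index: the first 1800 are class 0 (best), the next 1800 class 1 (general), the rest class 2 (bad)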

label=np.ones((num_samples,),dtype='int64')

label[0:1800]=0

label[1800:3600]=1

label[3600:]=2

names=['best','general','bad']

data,label=shuffle(immatrix,label,random_state=2)

train_data=[data,label]

#img=immatrix[167].reshape(img_rows,img_cols)

#plt.imshow(img)

#plt.imshow(img,cmap='gray')

print(train_data[0].shape)

print(train_data[1].shape)

(x,y)=(train_data[0],train_data[1])

x_train,x_test,y_train,y_test=train_test_split(x,y,test_size=0.3,random_state=4)

if K.image_data_format() == 'channels_first':

    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)

    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)

    input_shape = (1, img_rows, img_cols)

else:

    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)

    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)

    input_shape = (img_rows, img_cols, 1)

x_train = x_train.astype('float32')

x_test = x_test.astype('float32')

x_train /= 255

x_test /= 255

# convert class vectors to binary class matrices

y_train = keras.utils.to_categorical(y_train, num_classes)

y_test = keras.utils.to_categorical(y_test, num_classes)
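
# Build the CNN: two Conv/ReLU/MaxPool/Dropout blocks, then a dense layer with BatchNormalization and a softmax classifier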

model = Sequential()

model.add(Conv2D(16, [6,5],

                padding='valid',

                input_shape=input_shape))

model.add(Activation('relu'))

model.add(MaxPooling2D(pool_size=(1,2)))

model.add(Dropout(0.5))

model.add(Conv2D(632, [6,5],

                padding='valid'))                       

model.add(Activation('relu'))

model.add(MaxPooling2D(pool_size=(1,2)))

model.add(Dropout(0.5))

model.add(Flatten())

model.add(Dense(64))

model.add(Activation('relu'))

model.add(BatchNormalization(axis=chanDim))

model.add(Dropout(0.5))

model.add(Dense(num_classes))

model.add(Activation('softmax'))

model.compile(loss='categorical_crossentropy',

              optimizer='adam',       

              metrics=['accuracy'])

model.summary()
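
# Train with early stopping on training accuracy (patience 2); 20% of the training set is held out for validation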

early_stopping = EarlyStopping(monitor='acc', patience=2,mode='max')

history=model.fit(x_train,y_train,batch_size=batch_size,epochs=epochs,

          callbacks=[early_stopping],shuffle=True,verbose=1,validation_split=0.2)

plot_model(model, to_file='model1.png',show_shapes=True) 

# Evaluating the model

score=model.evaluate(x_test,y_test,verbose=0)

print('Test Loss:',score[0])

print('Test Accuracy:',score[1])

plt.figure(1)

plt.plot(history.history['acc'])

plt.plot(history.history['val_acc'])

plt.title('model accuracy')

plt.ylabel('accuracy')

plt.xlabel('epochs')

plt.legend(['train','validation'],loc='upper left')

plt.figure(2) 

plt.plot(history.history['loss'])

plt.plot(history.history['val_loss'])

plt.title('model loss')

plt.ylabel('loss')

plt.xlabel('epochs')

plt.show()

#%%

# Printing the confusion matrix

from sklearn.metrics import classification_report,confusion_matrix

import itertools

Y_pred = model.predict(x_test)

print(Y_pred)

y_pred = np.argmax(Y_pred, axis=1)

print(y_pred)

#y_pred = model.predict_classes(X_test)

#print(y_pred)

target_names = ['class 0(best)', 'class 1(general)', 'class 2(bad)']

print(classification_report(np.argmax(y_test,axis=1), y_pred,target_names=target_names))

print(confusion_matrix(np.argmax(y_test,axis=1), y_pred))

# Plotting the confusion matrix

def plot_confusion_matrix(cm, classes,

                          normalize=False,

                          title='Confusion matrix',

                          cmap=plt.cm.Blues):

    """

    This function prints and plots the confusion matrix.

    Normalization can be applied by setting `normalize=True`.

    """

    plt.imshow(cm, interpolation='nearest', cmap=cmap)

    plt.title(title)

    plt.colorbar()

    tick_marks = np.arange(len(classes))

    plt.xticks(tick_marks, classes, rotation=45)

    plt.yticks(tick_marks, classes)

    if normalize:

        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

        print("Normalized confusion matrix")

    else:

        print('Confusion matrix, without normalization')

    print(cm)

    thresh = cm.max() / 2.

    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):

        plt.text(j, i, cm[i, j],

                horizontalalignment="center",

                color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()

    plt.ylabel('True label')

    plt.xlabel('Predicted label')

# Compute confusion matrix

cnf_matrix = (confusion_matrix(np.argmax(y_test,axis=1), y_pred))

np.set_printoptions(precision=2)

plt.figure()

# Plot non-normalized confusion matrix

plot_confusion_matrix(cnf_matrix, classes=target_names,

                      title='Confusion matrix')

#plt.figure()

# Plot normalized confusion matrix

#plot_confusion_matrix(cnf_matrix, classes=target_names, normalize=True,

#                      title='Normalized confusion matrix')

#plt.figure()

plt.show()

end = time.time()

print('Running time:%s seconds' %(end-start))
