2020-12-06

2020-12-06  本文已影响0人  轻菊不爱柠檬

evaluate.py

from __future__import print_function

import matplotlib.pyplotas plt

import argparse

from randomimport shuffle

import random

import os

from netimport *

import pandasas pd

import numpyas np

parser = argparse.ArgumentParser(description='')

parser.add_argument("--image_size", type=int, default=64, help="load image size")# 网络输入的尺度#default=256

parser.add_argument("--snapshots", default='./log', help="Path of Snapshots")# 读取训练好的模型参数的路径

parser.add_argument("--test_data_path", default='./dataset/test/', help="path of x training datas.")# 高信噪比光谱数据的训练图片路径

parser.add_argument("--out_dir_x", default='./test_output/x/', help="Output Folder")# 保存x域的输入图片与生成的y域图片的路径

parser.add_argument("--out_dir_y", default='./test_output/y/', help="Output Folder")# 保存y域的输入图片与生成的x域图片的路径

args = parser.parse_args()

hang_num=11

#df_one=pd.read_csv(r'F:/bu/data/sdss/new_experiment/new_lowtest.csv',header=None)

df_one=pd.read_csv(r'F:/bu/data/sdss/测试下载数据/low.csv',header=None)

df_one=df_one.ix[:,:4096]

x_image_resize=[]

for iin range(hang_num):

print("...low...")

data_list=df_one.ix[i,:4096]

list_df_low = []

for jin range(4096):

#print(j)

        df_low_scal = (data_list[j]-np.mean(data_list)) / (max(data_list) -min(data_list))

list_df_low.append(df_low_scal)

re_list = np.array(list_df_low).reshape(64, 64, 1)

x_image_resize.append(re_list)

print(len(x_image_resize))

df_two=pd.read_csv(r'F:/bu/data/sdss/测试下载数据/high.csv',header=None)

y_image_resize=[]

for iin range(hang_num):

print("...high...")

data_list=df_two.ix[i,:]

list_df_high = []

for jin range(4096):

df_high2 = (data_list[j] - np.mean(data_list))/ (max(data_list) -min(data_list))

list_df_high.append(df_high2)

re_list = np.array(list_df_high).reshape(64, 64, 1)

y_image_resize.append(re_list)

print(len(y_image_resize))

def make_list_number_same(x_input_images_raw, y_input_images_raw):# add_train_list函数将x域和y域的图像数量变成一致

    if len(x_input_images_raw) ==len(y_input_images_raw):# 如果x域和y域图像数量本来就一致,直接返回

        return x_input_images_raw, shuffle(y_input_images_raw)

elif len(x_input_images_raw) >len(y_input_images_raw):# 如果x域的训练图像数量大于y域的训练图像数量,则随机选择y域的图像补充y域

        add_num =int(len(x_input_images_raw) -len(y_input_images_raw))# 计算两域图像数量相差的倍数

        length =len(y_input_images_raw)

for iin range(add_num):

n = random.sample(range(length), 1)# 每训练一个epoch,就打乱一下x域图像顺序

            num = n[0]

y_input_images_raw.append(y_input_images_raw[num])

return x_input_images_raw, y_input_images_raw# 返回数量一致的x域和y域图像路径名称列表

    else:# 与elif中的逻辑一致,只是x与y互换,不再赘述

        add_num =int(len(y_input_images_raw) -len(x_input_images_raw))# 计算两域图像数量相差的倍数

        length =len(x_input_images_raw)

for iin range(add_num):

n = random.sample(range(length), 1)# 每训练一个epoch,就打乱一下x域图像顺序

            num = n[0]

x_input_images_raw.append(x_input_images_raw[num])

return x_input_images_raw, y_input_images_raw# 返回数量一致的x域和y域图像路径名称列表

def cv_inv_proc(img):# cv_inv_proc函数将读取图片时归一化的图片还原成原图

    img_rgb = (img +1.) *127.5

    return img_rgb.astype(np.float32)# bgr

def get_write_picture(x_image, y_image, fake_y, fake_x):# get_write_picture函数得到网络测试结果

    x_image = cv_inv_proc(x_image)# 还原x域的图像

    y_image = cv_inv_proc(y_image)# 还原y域的图像

    fake_y = cv_inv_proc(fake_y[0])# 还原生成的y域的图像

    fake_x = cv_inv_proc(fake_x[0])# 还原生成的x域的图像

    x_output = np.concatenate((x_image, fake_y), axis=1)# 得到x域的输入图像以及对应的生成的y域图像

    y_output = np.concatenate((y_image, fake_x), axis=1)# 得到y域的输入图像以及对应的生成的x域图像

    return x_output, y_output

batch_size =5

def main():

if not os.path.exists(args.out_dir_x):# 如果保存x域测试结果的文件夹不存在则创建

        os.makedirs(args.out_dir_x)

if not os.path.exists(args.out_dir_y):# 如果保存y域测试结果的文件夹不存在则创建

        os.makedirs(args.out_dir_y)

test_x_image = tf.placeholder(tf.float32, shape=[batch_size, 64, 64, 1], name='test_x_image')# 输入的x域图像

    test_y_image = tf.placeholder(tf.float32, shape=[batch_size, 64, 64, 1], name='test_y_image')# 输入的y域图像

    fake_y = generator(image=test_x_image, reuse=False, name='generator_x2y')# 生成的y域图像

    fake_x_ = generator(image=fake_y, reuse=False, name='generator_y2x')# 重建的x域图像

    fake_x = generator(image=test_y_image, reuse=True, name='generator_y2x')# 生成的x域图像

    fake_y_ = generator(image=fake_x, reuse=True, name='generator_x2y')# 重建的y域图像

    restore_var = [vfor vin tf.global_variables()if 'generator' in v.name]# 需要载入的已训练的模型参数

    config = tf.ConfigProto()

config.gpu_options.allow_growth =True  # 设定显存不超量使用

    sess = tf.Session(config=config)# 建立会话层

    saver = tf.train.Saver(var_list=restore_var, max_to_keep=1)# 导入模型参数时使用

    checkpoint = tf.train.latest_checkpoint(args.snapshots)# 读取模型参数

    saver.restore(sess, checkpoint)# 导入模型参数

    #total_step = len(x_image_resize)

    total_step =50

    for stepin range(total_step):

n = random.sample(range(len(x_image_resize)), batch_size)

num = n[0]

#num=x_image_resize[step]

        batch_x_image = np.array(x_image_resize)[n]

batch_y_image = np.array(y_image_resize)[n]

feed_dict = {test_x_image: batch_x_image, test_y_image: batch_y_image}# 建立feed_dict

        fake_x_value, fake_y_value = sess.run([fake_x, fake_y], feed_dict=feed_dict)# 得到生成的x域图像与y域图像

        ###三张图画在一起  去噪l_h_l

        real_x = []##画真的x

        plt.yticks([0.0,0.5,1.0,1.5,2.0,2.5])

#plt.tight_layout(1)

#plt.subplots_adjust(bottom=0.1,left=0.1,right=0.15,top=0.15)

        plt.figure(figsize=(6,2.8))

for iin range(64):

for jin range(64):

real_x.append(x_image_resize[num][i][j]+1.9)

x = [ifor iin range(4000, 8096, 1)]

plt.plot(x, real_x, color='seagreen')

# plt.savefig('F:/bu/data/sdss/new_experiment/ll/' + str(counter))

        real_yy = []##画真的y

        for iin range(64):

for jin range(64):

real_yy.append(y_image_resize[num][i][j]+1.1)

x = [ifor iin range(4000, 8096, 1)]

plt.plot(x, real_yy, color='coral')

fake_xxx = []

for iin range(64):

for jin range(64):

fake_xxx.append(fake_y_value[0][i][j]+0.5)

x = [ifor iin range(4000, 8096, 1)]

plt.plot(x, fake_xxx, color='seagreen')

#plt.savefig('F:/bu/data/sdss/new_experiment/test/l_h_l/' + str(step))

        plt.savefig('F:/bu/data/sdss/测试下载数据/pic/' +str(step))

plt.close()

'''

        real_yy = []  ##画真的y

for i in range(64):

for j in range(64):

real_yy.append(y_image_resize[num][i][j] + 1.3)

x = [i for i in range(4000, 8096, 1)]

plt.plot(x, real_yy, color='seagreen')

        real_x = []  ##画真的x

for i in range(64):

for j in range(64):

real_x.append(x_image_resize[num][i][j] + 0.6)

x = [i for i in range(4000, 8096, 1)]

plt.plot(x, real_x, color='coral')

# plt.savefig('F:/bu/data/sdss/new_experiment/ll/' + str(counter))

fake_xxx = []

for i in range(64):

for j in range(64):

fake_xxx.append(fake_x_value[0][i][j])

x = [i for i in range(4000, 8096, 1)]

plt.plot(x, fake_xxx, color='seagreen')

plt.savefig('F:/bu/data/sdss/new_experiment/test/h_l_h/' + str(step))

plt.close()

'''

        print('step {:d}'.format(step))

if __name__ =='__main__':

main()

上一篇 下一篇

猜你喜欢

热点阅读