tensorflow加载以目录标识的数据集

2020-10-13  本文已影响0人  WYCWGTDDR

对于图片分类网络的训练,往往将照片按类别标签存放在相应目录下,在神经网络训练时,可以使用tensorflow提供的flow_from_directory方法加载,但为了提高数据加载的性能,及使用更加强大的图像增强方法,尝试使用tensorflow.data.Dataset进行数据加载和预处理。

import tensorflow as tf
import glob
import io

gpus = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_visible_devices(gpus[1], 'GPU')# 指定第2块GPU可用  
tf.config.experimental.set_memory_growth(device=gpus[1], enable=True)# 按需取用显存

train_dir = "/ai/jzclass/train"
valid_dir = "/ai/jzclass/valid"

#读取文件列表和对应的标签目录列表
train_image_path=glob.glob(train_dir + '/*/*.jpg')
train_image_label=[p.split("/")[4] for p in train_image_path]

#通过目录读取类别列表,并转换为字典
label_names = os.listdir(train_dir)
label_to_index = dict((name, index) for index, name in enumerate(label_names))

#将以文本标识的列表转换为数字标签(y值)
all_image_labels = [label_to_index[path] for path in train_image_label]

#数据加载及增强方法
def load_preprosess_image(path,label):
    image=tf.io.read_file(path)
    image=tf.image.decode_jpeg(image,channels=3) #有坑,可以用opencv代替
    image=tf.image.resize(image,[360,360])
    image=tf.image.random_crop(image,[224,224,3])
    image=tf.image.random_flip_left_right(image)
    image=tf.image.random_flip_up_down(image)
    image=tf.image.random_brightness(image,0.5)
    image=tf.image.random_contrast(image,0,1)
    image=tf.cast(image,tf.float32)
    image=image/255
    label=tf.reshape(label,[-1])
    return image,label

#数据加载
train_image_ds=tf.data.Dataset.from_tensor_slices((train_image_path,all_image_labels))
AUTOTUNE=tf.data.experimental.AUTOTUNE
train_image_ds=train_image_ds.map(load_preprosess_image,num_parallel_calls=AUTOTUNE)
train_count=len(train_image_path)

#乱序和预处理
train_image_ds=train_image_ds.shuffle(train_count).batch(BATCH_SIZE)
train_image_ds=train_image_ds.prefetch(AUTOTUNE)

后续就可以直接使用model.fit(train_image_ds, epochs = 100)进行调用。

上一篇下一篇

猜你喜欢

热点阅读