TF2.0:训练集、测试集的地址、标签获取完整流程!

2020-05-21  本文已影响0人  胜负55开

TF2.0获得训练集、测试集的所有文件地址和对应的标签,下面是完整的一条龙操作,简单快捷好理解,获得的结果可直接作为tf.data的输入数据。

1. 首先:训练集、测试集放到两个不同的文件夹里:
图1:训练集、测试集放两个文件夹内

每一个文件夹下又有airplane、lake两个文件夹:


图2:每个文件夹内又有2个子文件夹,分别放对应的图
2. 获得所有训练、测试数据的路径(字符串):
import glob

 # train文件夹内所有文件夹都要,所有文件夹内的所有.jpg文件都要!
train_data_path = glob.glob( 'E:/data/train/*/*.jpg' ) 
# test文件夹内所有文件夹都要,所有文件夹内的所有.jpg文件都要!
test_data_path = glob.glob( 'E:/data/test/*/*.jpg' )    

# 查看一下:数据量、类型
len(train_data_path), type(train_data_path[0])
(1400, str)

# 再查看一下:内容是什么
train_data_path[0]
'E:/data/train\\airplane\\airplane_001.jpg'
3. 将数据全部打乱:
import random 

random.shuffle( train_data_path )
random.shuffle( test_data_path )
4. 获得训练集、测试集一共有哪些标签种类:
# 纯标签有哪几种:把每个文件地址按\\分割,第2个元素就是标签!
# set()获得无序不重复元素集!
pure_train_labels = set( [ p.split('\\')[1] for p in train_data_path ] )
pure_test_labels = set( [ p.split('\\')[1] for p in test_data_path ] )

# 查看一下:
pure_train_labels, pure_test_labels
({'airplane', 'lake'}, {'airplane', 'lake'})
5. 把4获得的标签种类,转为数字索引的形式(字典):
pure_train_labels_to_index = dict( (index, name) for (name, index) in enumerate(pure_train_labels) )
pure_test_labels_to_index = dict( (index, name) for (name, index) in enumerate(pure_test_labels) )

# 查看一下:
pure_train_labels_to_index, pure_test_labels_to_index
({'airplane': 0, 'lake': 1}, {'airplane': 0, 'lake': 1})
6. 获得所有图片的标签:
# 还是把每个文件地址按\\分割,第2个元素就是标签!不需要用set(因为就要获得所有图片的)
train_labels = [ p.split('\\')[1] for p in train_data_path ]
test_labels = [ p.split('\\')[1] for p in test_data_path ]

# 查看一下:前3个
train_labels[0:3], test_labels[0:3]
(['airplane', 'airplane', 'lake'], ['airplane', 'lake', 'airplane'])
7. 把6获得的标签全转为对应的索引值:利用5中的字典!
# 获取“键”对应的“(索引)值”:
train_labels = [ pure_train_labels_to_index.get(label) for label in train_labels ]
test_labels = [ pure_test_labels_to_index.get(label) for label in test_labels ]

# 查看一下:还是前3个 —— 可与上面对比看对不对!
([0, 0, 1], [0, 1, 0])
8. 地址、标签获取完毕;最后列出所有后面会用到的变量名:
# 文件地址:后面tf.io.read_file( path )需要输入它们
train_data_path, test_data_path

# 测试集、训练集所有图像对应的标签:tf.data()需要!
train_labels, test_labels

补充:

上一篇下一篇

猜你喜欢

热点阅读