TF - Data Generator
Generator
- The ADEChallengeData dataset
Download link: http://sceneparsing.csail.mit.edu/results2016.html
ADEChallengeData
|
| - images
| |
| | - training     20210
| | - validation    2000
|
| - annotations
| |
| | - training     20210
| | - validation    2000
The training set and the validation set are kept in separate directories.
- Core idea of batch data generation
The images are too large to read into memory all at once, so instead we read only the image names and the corresponding annotation (mask) names into an array, put the names into a task queue, shuffle the queue, randomly draw image names from it, and only then open and read the actual image files by name (a minimal sketch follows the three steps below).
Step 1: build the task queue
Image_Obj_1 | Image_Obj_2 | ... | Image_Obj_N
* Image_Obj: {'image': '1.jpg', 'annotation': '1.png'}
Step 2: randomly draw batch_size records
Image_Obj_x | ... | Image_Obj_x+n
Step 3: read the actual images for the batch
image = read(Image_Obj_x.image)
annotation = read(Image_Obj_x.annotation)
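As a rough end-to-end illustration of these three steps, here is a minimal sketch (illustrative only; the helper names build_task_queue and read_batch are made up and are not part of the actual reader listed below):

import glob
import os
import random

def build_task_queue(image_dir, annotation_dir):
    # Step 1: store only file names, never the pixel data
    queue = []
    for image_path in glob.glob(os.path.join(image_dir, '*.jpg')):
        name = os.path.splitext(os.path.basename(image_path))[0]
        queue.append({'image': image_path,
                      'annotation': os.path.join(annotation_dir, name + '.png')})
    random.shuffle(queue)
    return queue

def read_batch(queue, batch_size, read_fn):
    # Step 2: draw batch_size records at random
    records = random.sample(queue, batch_size)
    # Step 3: only now open and read the actual files
    images = [read_fn(r['image']) for r in records]
    annotations = [read_fn(r['annotation']) for r in records]
    return images, annotations

The actual reader used for training is listed below: the first part builds and caches the file lists, the second part batches the images.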
# -*- coding:utf-8 -*-
import numpy as np
import os
import random
from six.moves import cPickle as pickle
from tensorflow.python.platform import gfile
import glob
import TensorflowUtils as utils
# DATA_URL = 'http://sceneparsing.csail.mit.edu/data/ADEChallengeData2016.zip'
DATA_URL = 'http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip'
# input_dir = FLAGS.data_dir = MIT_SceneParsing
# data_dir = MIT_SceneParsing/
def read_dataset(data_dir):
    pickle_filename = "MITSceneParsing.pickle"
    pickle_filepath = os.path.join(data_dir, pickle_filename)
    print(pickle_filepath)
    if not os.path.exists(pickle_filepath):
        utils.maybe_download_and_extract(data_dir, DATA_URL, is_zipfile=True)
        # splitext separates the file name from its extension
        SceneParsing_folder = os.path.splitext(DATA_URL.split("/")[-1])[0]
        # the resulting input path is MIT_SceneParsing/ADEChallengeData2016
        # result is a dict made up of the training and validation record lists
        result = create_image_lists(os.path.join(data_dir, SceneParsing_folder))
        print("Pickling ...")
        with open(pickle_filepath, 'wb') as f:
            pickle.dump(result, f, pickle.HIGHEST_PROTOCOL)
    else:
        print("Found pickle file!")

    with open(pickle_filepath, 'rb') as f:
        result = pickle.load(f)
        training_records = result['training']
        validation_records = result['validation']
        del result

    # training_records is the list of record dicts for the training set
    # validation_records is the list of record dicts for the validation set
    return training_records, validation_records
def create_image_lists(image_dir):
    '''
    Build the training and validation file lists.
    image_list['training'] = [{'image': f, 'annotation': annotation_file, 'filename': filename}, ...]
    :param image_dir: root of the extracted ADEChallengeData2016 folder
    :return: dict with a record list for 'training' and 'validation'
    '''
    if not gfile.Exists(image_dir):
        print("Image directory '" + image_dir + "' not found.")
        return None
    directories = ['training', 'validation']
    image_list = {}
    for directory in directories:
        file_list = []
        image_list[directory] = []
        file_glob = os.path.join(image_dir, "images", directory, '*.' + 'jpg')
        file_list.extend(glob.glob(file_glob))
        if not file_list:
            print('No files found')
        else:
            for f in file_list:
                filename = os.path.splitext(f.split("/")[-1])[0]
                annotation_file = os.path.join(image_dir, "annotations", directory, filename + '.png')
                if os.path.exists(annotation_file):
                    record = {'image': f, 'annotation': annotation_file, 'filename': filename}
                    image_list[directory].append(record)
                else:
                    print("Annotation file not found for %s - Skipping" % filename)
        random.shuffle(image_list[directory])
        no_of_images = len(image_list[directory])
        print('No. of %s files: %d' % (directory, no_of_images))
    return image_list
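A quick usage sketch for the file-list reader (the MIT_SceneParsing/ directory name comes from the comments above; the printed record is only an example of the naming convention in ADEChallengeData2016):

train_records, valid_records = read_dataset("MIT_SceneParsing/")
print(len(train_records), len(valid_records))  # expect 20210 and 2000
print(train_records[0])
# e.g. {'image': '.../images/training/ADE_train_00000001.jpg',
#       'annotation': '.../annotations/training/ADE_train_00000001.png',
#       'filename': 'ADE_train_00000001'}

The second module below then takes these record lists and turns them into in-memory batches.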
import numpy as np
import scipy.misc as misc
class BatchDatset:
    files = []
    images = []
    annotations = []
    image_options = {}
    batch_offset = 0
    epochs_completed = 0

    def __init__(self, records_list, image_options={}):
        """
        Initialize a generic file reader with batching for a list of files
        :param records_list: list of file records to read -
        sample record: {'image': f, 'annotation': annotation_file, 'filename': filename}
        :param image_options: A dictionary of options for modifying the output image
        Available options:
        resize = True/False
        resize_size = # size of output image - does bilinear resize
        color = True/False
        """
        print("Initializing Batch Dataset Reader...")
        print(image_options)
        self.files = records_list
        self.image_options = image_options
        self._read_images()

    def _read_images(self):
        self.__channels = True
        self.images = np.array([self._transform(filename['image']) for filename in self.files])
        self.__channels = False
        # annotations are single-channel masks; add a channel axis to get shape (N, h, w, 1)
        self.annotations = np.array(
            [np.expand_dims(self._transform(filename['annotation']), axis=2) for filename in self.files])
        print(self.images.shape)
        print(self.annotations.shape)

    def _transform(self, filename):
        image = misc.imread(filename)
        if self.__channels and len(image.shape) < 3:
            # replicate the grayscale channel so images have shape (h, w, 3)
            image = np.stack([image] * 3, axis=-1)
        if self.image_options.get("resize", False) and self.image_options["resize"]:
            resize_size = int(self.image_options["resize_size"])
            resize_image = misc.imresize(image,
                                         [resize_size, resize_size], interp='nearest')
        else:
            resize_image = image
        return np.array(resize_image)

    def get_records(self):
        return self.images, self.annotations

    def reset_batch_offset(self, offset=0):
        self.batch_offset = offset

    def next_batch(self, batch_size):
        start = self.batch_offset
        self.batch_offset += batch_size
        if self.batch_offset > self.images.shape[0]:
            # Finished epoch
            self.epochs_completed += 1
            print("****************** Epochs completed: " + str(self.epochs_completed) + " ******************")
            # Shuffle the data
            perm = np.arange(self.images.shape[0])
            np.random.shuffle(perm)
            self.images = self.images[perm]
            self.annotations = self.annotations[perm]
            # Start next epoch
            start = 0
            self.batch_offset = batch_size
        end = self.batch_offset
        return self.images[start:end], self.annotations[start:end]

    def get_random_batch(self, batch_size):
        indexes = np.random.randint(0, self.images.shape[0], size=[batch_size]).tolist()
        return self.images[indexes], self.annotations[indexes]
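Putting the two modules together, a training loop could drive the reader like this (a sketch only; IMAGE_SIZE, the batch size, and the loop structure are assumptions, not part of the code above):

IMAGE_SIZE = 224
image_options = {'resize': True, 'resize_size': IMAGE_SIZE}

train_records, valid_records = read_dataset("MIT_SceneParsing/")
train_dataset = BatchDatset(train_records, image_options)
validation_dataset = BatchDatset(valid_records, image_options)

for step in range(1000):
    train_images, train_annotations = train_dataset.next_batch(batch_size=2)
    # feed train_images / train_annotations into the network here
    if step % 100 == 0:
        valid_images, valid_annotations = validation_dataset.get_random_batch(batch_size=2)
        # evaluate the model on this random validation batch

next_batch walks through the shuffled records sequentially and reshuffles at the end of each epoch, while get_random_batch samples with replacement, which is convenient for periodic validation.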