将 label 图片转成 one-hot 编码后，获取其中每个 class 的掩码
2019-01-27 本文已影响10人
谢小帅
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import skimage.io as sio
# Print whole arrays without "..." truncation. NumPy rejects threshold=np.nan
# (ValueError since 1.14); sys.maxsize is the documented "never summarize" value.
np.set_printoptions(threshold=sys.maxsize, linewidth=10000)

# Directory holding the colour-coded label images (trailing slash kept).
label_path = '/temp_disk/xs/sun/train/label/'
# CSV mapping each class name to its r, g, b colour.
csv_path = '/temp_disk/xs/sun/seg37_class_dict.csv'
def get_label_info(csv_path):
    """Read the class/colour table and return the colour of every non-background class.

    Args:
        csv_path: path (or file-like object) of a CSV with at least the
            columns ``name``, ``r``, ``g`` and ``b``.

    Returns:
        dict mapping each class name to its ``[r, g, b]`` colour as Python
        ints, in file order. Rows whose colour is pure black ``(0, 0, 0)``
        are treated as background and skipped.
    """
    ann = pd.read_csv(csv_path)
    label = {}
    # Zip the columns rather than iterrows(): same rows, no per-row Series,
    # and no shadowing of the builtin ``iter``.
    for name, r, g, b in zip(ann['name'], ann['r'], ann['g'], ann['b']):
        rgb = [int(r), int(g), int(b)]
        if rgb == [0, 0, 0]:  # background colour -> not a class channel
            continue
        label[name] = rgb
    return label
def one_hot_it(label, label_info):
    """Turn a colour-coded label image into one boolean mask per class.

    Args:
        label: image array of shape ``[H, W, C]`` whose pixels are class colours.
            If ``C`` exceeds the colour length (e.g. an RGBA PNG), the extra
            trailing channels are ignored.
        label_info: dict ``{class_name: [r, g, b]}``; channel order of the
            result follows the dict's iteration order.

    Returns:
        bool array of shape ``[H, W, len(label_info)]``; channel ``i`` is True
        where the pixel equals the ``i``-th colour exactly.
    """
    label = np.asarray(label)
    colors = list(label_info.values())
    if not colors:
        # np.stack refuses an empty list; return an empty class axis instead.
        return np.zeros(label.shape[:-1] + (0,), dtype=bool)

    n_channels = len(colors[0])
    if label.ndim == 3 and label.shape[-1] > n_channels:
        # Drop extra channels (typically alpha) so pixels line up with colours.
        label = label[..., :n_channels]

    masks = [np.all(np.equal(label, color), axis=-1) for color in colors]
    return np.stack(masks, axis=-1)
label_info = get_label_info(csv_path)
label_names = list(label_info)

# Preview only the first label image (sorted, so "first" is reproducible):
# show the image, then show and print every class that occurs in it.
for file_name in sorted(os.listdir(label_path)):
    label_img = sio.imread(os.path.join(label_path, file_name))
    plt.imshow(label_img)
    plt.show()

    one_hot = one_hot_it(label_img, label_info).astype(np.uint8)
    for class_idx, class_name in enumerate(label_names):
        mask = one_hot[:, :, class_idx]  # binary mask of this class
        if mask.any():
            print(class_name, class_idx)
            plt.imshow(mask, cmap='gray')
            plt.title(class_name)
            plt.show()
    break  # one image is enough for the preview
wall 0
floor 1
door 7
window 8
paper 25
label color
wall
floor
door
window
paper