将 label 图片转成 one-hot 编码后，获取图中出现的每个 class

2019-01-27  本文已影响10人  谢小帅
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import skimage.io as sio

# Print whole arrays without "..." truncation. NumPy rejects threshold=np.nan
# ("threshold must be non-NAN"); sys.maxsize is the supported way to disable it.
np.set_printoptions(threshold=sys.maxsize, linewidth=10000)

label_path = '/temp_disk/xs/sun/train/label/'  # directory of colour label images
csv_path = '/temp_disk/xs/sun/seg37_class_dict.csv'  # table of name, r, g, b per class

def get_label_info(csv_path):
    """Read the class-colour CSV and return ``{class_name: [r, g, b]}``.

    The CSV must have ``name``, ``r``, ``g`` and ``b`` columns. Rows whose
    colour is (0, 0, 0) are skipped (treated as background), so they get no
    entry in the result. Entries keep the CSV row order.

    Args:
        csv_path: path to the class dictionary CSV file.

    Returns:
        dict mapping each non-black class name to its ``[r, g, b]`` as ints.
    """
    ann = pd.read_csv(csv_path)
    label_colors = {}
    for name, r, g, b in zip(ann['name'], ann['r'], ann['g'], ann['b']):
        color = [int(r), int(g), int(b)]
        if sum(color) == 0:  # black is the background entry; leave it out
            continue
        label_colors[name] = color
    return label_colors

def one_hot_it(label, label_info):
    """Convert a colour label image into a stack of per-class boolean masks.

    Args:
        label: label image of shape (H, W, 3). A fourth (alpha) channel, as
            produced when reading RGBA PNGs, is ignored.
        label_info: ordered mapping ``{class_name: [r, g, b]}``; channel ``k``
            of the result corresponds to the ``k``-th entry.

    Returns:
        Boolean array of shape (H, W, len(label_info)); element ``[y, x, k]``
        is True where the pixel colour exactly equals class ``k``'s colour.
        An empty ``label_info`` yields shape (H, W, 0).
    """
    label = np.asarray(label)
    if label.ndim == 3 and label.shape[-1] == 4:
        # RGBA input: drop alpha so it broadcasts against 3-value colours
        label = label[..., :3]
    masks = [np.all(label == np.asarray(color), axis=-1)
             for color in label_info.values()]
    if not masks:
        # np.stack refuses an empty list; return an empty class axis instead
        return np.zeros(label.shape[:2] + (0,), dtype=bool)
    return np.stack(masks, axis=-1)


# Load the class colours, then for the first label image in label_path show
# the colour image followed by one mask per class that appears in it.
label_info = get_label_info(csv_path)
label_names = list(label_info.keys())

# os.listdir order is arbitrary; sorting makes "the first image" reproducible.
for file_name in sorted(os.listdir(label_path)):
    label_rgb = sio.imread(os.path.join(label_path, file_name))
    plt.imshow(label_rgb)
    plt.show()
    label = one_hot_it(label_rgb, label_info).astype(np.uint8)
    for i, name in enumerate(label_names):  # one channel per class
        mask = label[:, :, i]  # binary mask for this class
        if mask.any():  # class is present in this image
            print(name, i)
            plt.imshow(mask, cmap='gray')
            plt.title(name)
            plt.show()
    break  # only inspect the first label image
wall 0
floor 1
door 7
window 8
paper 25
（配图依次为：label 彩色图，以及 wall、floor、door、window、paper 各类别的 mask）
上一篇下一篇

猜你喜欢

热点阅读