语义分割-筛选ADE数据并重新赋值标签

2020-05-17  本文已影响0人  su945

筛选ADE数据并重新赋值标签

ADE20K(ADEChallengeData2016)的标注共有150个类别,而本项目主要针对室内场景的语义分割,因此只需保留其中一部分类别,并对保留下来的类别重新编号。原始label图像中,每个像素的值就是该像素所属类别的编号(0表示背景)。重新赋值时,选定的类别按其在所选列表中的顺序依次映射为1、2、3……,其余类别的像素值一律设为0,即背景类。

#-*- coding:utf-8 -*-
# author:suyuan
# datetime:2020/5/6 上午10:18
# software: PyCharm

import glob
import os
from PIL import Image
import cv2 as cv
import numpy as np
import json

#### Parse odgt data (one JSON object per line)
def parse_input_list(odgt, max_sample=-1, start_idx=-1, end_idx=-1):
    """Load a list of samples from an in-memory list or an odgt file.

    An odgt file is read as one JSON object per line; blank lines are
    skipped.

    Args:
        odgt: Either a list of samples (used as is) or the path of an odgt
            file to parse.
        max_sample: When > 0, keep only the first ``max_sample`` samples.
        start_idx: When both ``start_idx`` and ``end_idx`` are >= 0, keep
            only ``samples[start_idx:end_idx]``. The slice is applied after
            ``max_sample``.
        end_idx: See ``start_idx``.

    Returns:
        The (possibly truncated) list of samples.

    Raises:
        TypeError: If ``odgt`` is neither a list nor a str.
        AssertionError: If no sample is left after truncation.
    """
    if isinstance(odgt, list):
        list_sample = odgt
    elif isinstance(odgt, str):
        # The context manager closes the file even if a line fails to parse.
        with open(odgt, 'r') as odgt_fp:
            list_sample = [
                json.loads(line) for line in odgt_fp if line.strip()
            ]
    else:
        raise TypeError(
            'odgt must be a list of samples or a path string, got {}'.format(
                type(odgt).__name__))

    if max_sample > 0:
        list_sample = list_sample[0:max_sample]
    if start_idx >= 0 and end_idx >= 0:  # divide file list
        list_sample = list_sample[start_idx:end_idx]

    num_sample = len(list_sample)
    assert num_sample > 0, 'no samples left after filtering'
    print('# samples: {}'.format(num_sample))
    return list_sample

### Set the categories listed in lable_list to background (0); all other categories are kept
def select_ADEcategory(lable_list,src_lable_path,dst_lable_path):
    """Zero out the listed categories in every label PNG of a directory.

    Each ``*.png`` in ``src_lable_path`` is read as a grayscale image.
    Pixels whose value appears in ``lable_list`` are set to 0; every other
    pixel keeps its value. The result is written to ``dst_lable_path`` under
    the same base name with a ``.png`` extension, and the written path is
    printed.

    Args:
        lable_list: Iterable of category ids to reset to 0.
        src_lable_path: Directory containing the source label PNGs.
        dst_lable_path: Directory receiving the rewritten label PNGs; it is
            created when missing.

    Raises:
        IOError: If a label image cannot be read.
    """
    # cv.imwrite only returns False when the target directory is missing,
    # so make sure it exists up front.
    if dst_lable_path:
        os.makedirs(dst_lable_path, exist_ok=True)

    # np.isin needs a sequence; a set would be treated as a single object.
    ids_to_clear = list(lable_list)

    for label_file in glob.glob(os.path.join(src_lable_path, '*.png')):
        label = cv.imread(label_file, cv.IMREAD_GRAYSCALE)
        if label is None:
            raise IOError('failed to read label image: {}'.format(label_file))

        # One vectorised mask instead of a Python loop over every pixel.
        label[np.isin(label, ids_to_clear)] = 0

        # splitext keeps dots inside the stem ("a.b.png" stays "a.b.png").
        img_id = os.path.splitext(os.path.basename(label_file))[0]
        rewrite_lable_file_path = os.path.join(dst_lable_path, img_id + '.png')

        # Save the label image with the updated ids.
        cv.imwrite(rewrite_lable_file_path, label)
        print(rewrite_lable_file_path)

### Count, for each category, how many label images contain it
# Variant that takes the image list from an odgt file
def count_category_odgt(root_dataset,odgt_file):
    """Print, per category, how many odgt-listed label images contain it.

    For every sample returned by ``parse_input_list(odgt_file)`` the label
    image at ``root_dataset/<sample['fpath_segm']>`` is read as grayscale,
    and each pixel value present in it is counted once for that image.
    Categories 1..150 are then printed with their image count and their
    share of the summed counts of categories 1..150.

    Args:
        root_dataset: Directory that the ``fpath_segm`` entries are joined to.
        odgt_file: Sample list or odgt path accepted by ``parse_input_list``.

    Raises:
        IOError: If a label image cannot be read.
        IndexError: If a label image contains a pixel value above 150.
    """
    list_sample = parse_input_list(odgt_file)
    count_list = np.zeros([151], dtype=np.int32)
    for sample in list_sample:
        segm_path = os.path.join(root_dataset, sample['fpath_segm'])
        label = cv.imread(segm_path, cv.IMREAD_GRAYSCALE)
        if label is None:
            raise IOError('failed to read label image: {}'.format(segm_path))
        # np.unique yields each category present in the image exactly once,
        # so every image adds at most 1 to a category's counter.
        count_list[np.unique(label)] += 1

    # Background (index 0) is excluded from the total and from the report.
    count = np.sum(count_list[1:])
    for i in range(1, 151):
        # Avoid a 0/0 division when no foreground category was found.
        ratio = count_list[i] / count if count else 0.0
        print('类别:', i, ',数量:', count_list[i],',占比:',ratio)




# Variant that reads every PNG found directly under a directory
def count_category(src_lable_path):
    """Print, per category, how many label PNGs in a directory contain it.

    Every ``*.png`` in ``src_lable_path`` is read as a grayscale image and
    each pixel value present in it is counted once for that image. One line
    is printed for each of the 151 ids (0..150) with its image count.

    Args:
        src_lable_path: Directory containing the label PNGs.

    Raises:
        IOError: If a label image cannot be read.
        IndexError: If a label image contains a pixel value above 150.
    """
    count_list = np.zeros([151], dtype=np.int32)
    for label_file in glob.glob(os.path.join(src_lable_path, '*.png')):
        label = cv.imread(label_file, cv.IMREAD_GRAYSCALE)
        if label is None:
            raise IOError('failed to read label image: {}'.format(label_file))
        # np.unique yields each category present in the image exactly once,
        # so every image adds at most 1 to a category's counter.
        count_list[np.unique(label)] += 1

    # Report the number of images per category.
    for i in range(151):
        print('类别:',i,',数量:',count_list[i])


def select_20_category(odgt_file,save_path,select_id,root_dataset=None):
    """Keep only the selected categories and renumber them from 1.

    For every sample returned by ``parse_input_list(odgt_file)`` the label
    image at ``root_dataset/<sample['fpath_segm']>`` is read as grayscale.
    A pixel whose value equals ``select_id[k]`` becomes ``k + 1``; every
    other pixel becomes 0. The result is written to ``save_path`` under the
    original file name.

    Args:
        odgt_file: Sample list or odgt path accepted by ``parse_input_list``.
        save_path: Directory receiving the remapped label images; it is
            created when missing.
        select_id: Sequence of integer category ids to keep, in the order
            that defines their new ids. If an id is repeated, its last
            position wins.
        root_dataset: Directory that the ``fpath_segm`` entries are joined
            to. When omitted, the module-level ``root_dataset`` variable is
            used, which is how the function found it before this parameter
            existed.

    Raises:
        NameError: If ``root_dataset`` is neither passed nor defined at
            module level.
        IOError: If a label image cannot be read.
    """
    print('选择类别数:',len(select_id))

    if root_dataset is None:
        # Backward compatibility with callers that only set the global.
        root_dataset = globals().get('root_dataset')
        if root_dataset is None:
            raise NameError(
                "root_dataset is not defined: pass it as an argument or "
                "set a module-level 'root_dataset'")

    list_sample = parse_input_list(odgt_file)

    # Lookup table old id -> new id; ids that are not selected map to 0.
    # Grayscale pixels are 8-bit, so ids outside 0..255 can never match.
    lut = np.zeros(256, dtype=np.uint8)
    for new_id, old_id in enumerate(select_id, start=1):
        if 0 <= old_id <= 255:
            lut[old_id] = new_id

    if save_path:
        # cv.imwrite does not create missing directories.
        os.makedirs(save_path, exist_ok=True)

    for sample in list_sample:
        segm_path = os.path.join(root_dataset, sample['fpath_segm'])
        label = cv.imread(segm_path, cv.IMREAD_GRAYSCALE)
        if label is None:
            raise IOError('failed to read label image: {}'.format(segm_path))

        # Remap every pixel in one vectorised lookup.
        remapped = lut[label]

        img_save_path = os.path.join(save_path, os.path.basename(segm_path))
        # Save the remapped label image.
        cv.imwrite(img_save_path, remapped)




if __name__ == '__main__':
    # Category ids passed to select_ADEcategory (calls currently disabled).
    lable_list = [1, 13, 136]

    root_path = "/media/suyuan/U/data/segment/ADE20K/ADEChallengeData2016"

    # Training split: source annotations and output directory.
    train_label_path = os.path.join(root_path, "annotations/training")
    select_train_lable_path = os.path.join(
        root_path, "annotations_select0506/training")
    #select_ADEcategory(lable_list, train_label_path, select_train_lable_path)

    # Validation split: source annotations and output directory.
    val_lable_path = os.path.join(root_path, "annotations/validation")
    select_val_lable_path = os.path.join(
        root_path, "annotations_select0506/validation")
    #select_ADEcategory(lable_list, val_lable_path, select_val_lable_path)

    # Category statistics, reading the images of a directory directly.
    src_lable_path = '/media/suyuan/U/data/segment/ADE20K/ADEChallengeData2016/test'
    #count_category(src_lable_path)

    # Category statistics driven by an odgt file.
    # NOTE: select_20_category reads this module-level root_dataset.
    root_dataset = '/media/suyuan/U/data/segment/ADE20K'
    odgt_file = '/media/suyuan/U/data/segment/ADE20K/ADEChallengeData2016/validation_select.odgt'
    #count_category_odgt(root_dataset,odgt_file)

    # Category selection: keep these ids and renumber them.
    save_path = '/media/suyuan/U/data/segment/ADE20K/training'
    select_id = [
        8, 9, 11, 13, 15, 16, 18, 19, 20, 23,
        24, 25, 28, 34, 43, 45, 48, 58, 83, 90,
    ]
    select_20_category(odgt_file, save_path, select_id)

    # Share of selected categories, from previously saved per-class counts.
    class_counts = np.loadtxt(
        os.path.join(root_dataset, 'ADEChallengeData2016/temo.txt'))
    count = np.sum(class_counts[1:])
    list_id = [
        1, 8, 9, 10, 11, 13, 15, 16, 18, 19, 20, 23, 24,
        25, 28, 34, 38, 45, 48, 58, 64, 82, 83, 88, 90, 146,
    ]
    percent_sum = 0
    for class_id in list_id:
        share = 100 * class_counts[class_id] / 9755
        percent_sum += share
        print('id:', class_id, ',占比', share, '%')
    print(percent_sum, '%')

上一篇 下一篇

猜你喜欢

热点阅读