图像分割

2020-11-26 json_to_dataset DONG2

2020-11-26  本文已影响0人  智能之心
import shutil,base64,io,os,json,glob,math,warnings
import numpy as np
import PIL
from PIL import (ExifTags, Image, ImageOps, ImageDraw, ImageDraw, ImageFont)
from skimage import img_as_ubyte
import tensorflow as tf
import os.path as osp
import matplotlib.pyplot as plt
# labelme 制作的数据包括 Annotations 和 JPEGImages，重点是 Annotations，因为其中通常自带图片数据（imageData 为空时才回退读取图片文件）
def mk_dataset(json_file, image_file, out, label_name_to_value):
    """Convert a folder of labelme JSON annotations into a VOC-style dataset.

    For every ``*.json`` file directly inside ``json_file`` (processed in
    sorted order) this writes, under the output directory:

    * ``JPEGImages/<name>.jpg`` - the source image;
    * ``SegmentationClassPNG/<name>.png`` - class indices saved as a
      palette ("P" mode) PNG;
    * ``SegmentationClassRaw/<name>.png`` - the same class indices as a plain
      uint8 PNG;
    * ``SegmentationClassVisualization/<name>.png`` - the labels blended over
      a grey copy of the image, with a legend;
    * ``Annotation/<name>.json`` - a copy of the annotation file;

    plus ``label_names.txt`` with one class name per line, ordered by value.

    Args:
        json_file: directory containing the labelme ``*.json`` files.
        image_file: directory holding the images; only read when a JSON file
            has an empty/missing ``imageData`` field.
        out: output directory. If None, a directory named after ``json_file``
            (with '.' replaced by '_') is created next to it.
        label_name_to_value: dict mapping class name -> integer label value.
            Labels found in the JSON files but absent from this dict are added
            with value ``len(label_name_to_value)``. The dict is mutated in
            place, so additions are visible to the caller afterwards.
    """

    def img_bytes_to_arr(img_bytes):
        """Decode encoded image bytes (as stored by labelme) into an ndarray."""
        with PIL.Image.open(io.BytesIO(img_bytes)) as img_pil:
            return np.array(img_pil)

    def shape_to_mask(img_shape, points, shape_type=None,
                      line_width=10, point_size=5):
        """Rasterise one labelme shape into a boolean mask of size img_shape[:2].

        ``shape_type`` selects circle/rectangle/line/linestrip/point drawing;
        any other value (including None) is drawn as a filled polygon.
        """
        mask = PIL.Image.fromarray(np.zeros(img_shape[:2], dtype=np.uint8))
        draw = PIL.ImageDraw.Draw(mask)
        xy = [tuple(point) for point in points]
        if shape_type == 'circle':
            assert len(xy) == 2, 'Shape of shape_type=circle must have 2 points'
            # First point is the centre, the second lies on the circumference.
            (cx, cy), (px, py) = xy
            d = math.sqrt((cx - px) ** 2 + (cy - py) ** 2)
            draw.ellipse([cx - d, cy - d, cx + d, cy + d], outline=1, fill=1)
        elif shape_type == 'rectangle':
            assert len(xy) == 2, 'Shape of shape_type=rectangle must have 2 points'
            # labelme stores the two corners in drawing order; Pillow rejects
            # boxes whose second corner is above/left of the first.
            (x0, y0), (x1, y1) = xy
            draw.rectangle([min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1)],
                           outline=1, fill=1)
        elif shape_type == 'line':
            assert len(xy) == 2, 'Shape of shape_type=line must have 2 points'
            draw.line(xy=xy, fill=1, width=line_width)
        elif shape_type == 'linestrip':
            draw.line(xy=xy, fill=1, width=line_width)
        elif shape_type == 'point':
            assert len(xy) == 1, 'Shape of shape_type=point must have 1 points'
            cx, cy = xy[0]
            r = point_size
            draw.ellipse([cx - r, cy - r, cx + r, cy + r], outline=1, fill=1)
        else:
            assert len(xy) > 2, 'Polygon must have points more than 2'
            draw.polygon(xy=xy, outline=1, fill=1)
        return np.array(mask, dtype=bool)

    def shapes_to_label(img_shape, shapes, label_name_to_value):
        """Rasterise labelme shapes into an int32 class-index map.

        Pixels covered by no shape stay 0; where shapes overlap, the later
        shape in ``shapes`` wins.
        """
        cls = np.zeros(img_shape[:2], dtype=np.int32)
        for shape in shapes:
            mask = shape_to_mask(img_shape[:2], shape['points'],
                                 shape.get('shape_type', None))
            cls[mask] = label_name_to_value[shape['label']]
        return cls

    def label_colormap(N=256):
        """Return an (N, 3) float32 colour map in [0, 1] (VOC-style bit interleaving)."""

        def bitget(byteval, idx):
            return (byteval & (1 << idx)) != 0

        cmap = np.zeros((N, 3))
        for i in range(N):
            c = i
            r = g = b = 0
            for j in range(8):
                r |= int(bitget(c, 0)) << (7 - j)
                g |= int(bitget(c, 1)) << (7 - j)
                b |= int(bitget(c, 2)) << (7 - j)
                c >>= 3
            cmap[i] = (r, g, b)
        return cmap.astype(np.float32) / 255

    def _validate_colormap(colormap, n_labels):
        """Return ``colormap`` after sanity checks, or a default one if None."""
        if colormap is None:
            return label_colormap(n_labels)
        assert colormap.shape == (colormap.shape[0], 3), \
            'colormap must be sequence of RGB values'
        assert 0 <= colormap.min() and colormap.max() <= 1, \
            'colormap must ranges 0 to 1'
        return colormap

    # Similar in spirit to skimage.color.label2rgb.
    def label2rgb(lbl, img=None, n_labels=None, alpha=0.5,
                  thresh_suppress=0, colormap=None):
        """Map label values to uint8 colours, optionally blended over a grey ``img``.

        Pixels labelled -1 are painted black. ``thresh_suppress`` is accepted
        for signature compatibility but unused.
        """
        if n_labels is None:
            n_labels = len(np.unique(lbl))
        colormap = (_validate_colormap(colormap, n_labels) * 255).astype(np.uint8)

        lbl_viz = colormap[lbl]
        lbl_viz[lbl == -1] = (0, 0, 0)  # unlabeled

        if img is not None:
            img_gray = PIL.Image.fromarray(img).convert('LA')
            img_gray = np.asarray(img_gray.convert('RGB'))
            lbl_viz = (alpha * lbl_viz + (1 - alpha) * img_gray).astype(np.uint8)
        return lbl_viz

    def draw_label(label, img=None, label_names=None, colormap=None, **kwargs):
        """Colourise a label map (optionally over ``img``) and add a legend.

        label: ndarray of label values (H, W).
        img: optional image drawn (in grey) underneath the colours.
        label_names: sequence indexed by label value; defaults to "0".."max".
        Returns an RGB ndarray resized to the width/height of ``label``.
        """
        # Render off-screen with Agg, then restore the caller's backend.
        backend_org = plt.rcParams['backend']
        plt.switch_backend('agg')

        plt.subplots_adjust(left=0, right=1, top=1, bottom=0,
                            wspace=0, hspace=0)
        plt.margins(0, 0)
        plt.gca().xaxis.set_major_locator(plt.NullLocator())
        plt.gca().yaxis.set_major_locator(plt.NullLocator())

        if label_names is None:
            label_names = [str(v) for v in range(label.max() + 1)]

        colormap = _validate_colormap(colormap, len(label_names))

        label_viz = label2rgb(
            label, img, n_labels=len(label_names), colormap=colormap, **kwargs
        )
        plt.imshow(label_viz)
        plt.axis('off')

        # Legend entries only for label values that actually occur.
        plt_handlers = []
        plt_titles = []
        for label_value, label_name in enumerate(label_names):
            if label_value not in label:
                continue
            plt_handlers.append(
                plt.Rectangle((0, 0), 1, 1, fc=colormap[label_value]))
            plt_titles.append('{value}: {name}'
                              .format(value=label_value, name=label_name))
        plt.legend(plt_handlers, plt_titles, loc='lower right', framealpha=.5)

        f = io.BytesIO()
        plt.savefig(f, bbox_inches='tight', pad_inches=0)
        plt.cla()
        plt.close()

        plt.switch_backend(backend_org)

        out_size = (label_viz.shape[1], label_viz.shape[0])
        out_img = PIL.Image.open(f).resize(out_size, PIL.Image.BILINEAR).convert('RGB')
        return np.asarray(out_img)

    def lblsave(filename, lbl):
        """Save a label map as a palette PNG (appending '.png' if missing).

        Raises ValueError when values fall outside [-1, 254], which cannot be
        represented as uint8 palette indices.
        """
        if os.path.splitext(filename)[1] != '.png':
            filename += '.png'
        # Assume label ranges [-1, 254] for int32,
        # and [0, 255] for uint8 as VOC.
        if lbl.min() >= -1 and lbl.max() < 255:
            lbl_pil = PIL.Image.fromarray(lbl.astype(np.uint8), mode='P')
            colormap = label_colormap(255)
            lbl_pil.putpalette((colormap * 255).astype(np.uint8).flatten())
            lbl_pil.save(filename)
        else:
            raise ValueError(
                '[%s] Cannot save the pixel-wise class label as PNG. '
                'Please consider using the .npy format.' % filename
            )

    # ---- output directory ---------------------------------------------------
    if out is None:
        src = os.path.normpath(json_file)
        out_dir = os.path.join(os.path.dirname(src),
                               os.path.basename(src).replace('.', '_'))
    else:
        out_dir = out
    # makedirs also creates missing parents (e.g. ./result); a genuine failure
    # raises here rather than surfacing later as a confusing error.
    os.makedirs(out_dir, exist_ok=True)

    paths = sorted(glob.glob(os.path.join(json_file, '*.json')))
    if not paths:
        print("json is NULL")
        return

    jpeg_dir = os.path.join(out_dir, 'JPEGImages')
    class_png_dir = os.path.join(out_dir, 'SegmentationClassPNG')
    class_raw_dir = os.path.join(out_dir, 'SegmentationClassRaw')
    viz_dir = os.path.join(out_dir, 'SegmentationClassVisualization')
    ann_dir = os.path.join(out_dir, 'Annotation')
    for sub_dir in (jpeg_dir, class_png_dir, class_raw_dir, viz_dir, ann_dir):
        os.makedirs(sub_dir, exist_ok=True)

    label_names = []
    n = 0
    for path in paths:
        if not os.path.isfile(path):
            # glob can also match a directory named *.json; skip it.
            continue
        n += 1
        with open(path) as fp:
            data = json.load(fp)
        out_name = os.path.splitext(os.path.basename(path))[0]

        # Prefer the image embedded in the JSON; otherwise read it from
        # image_file using the file name recorded in imagePath.
        if data.get('imageData'):
            img_bytes = base64.b64decode(data['imageData'])
        else:
            # imagePath may have been written on Windows with '\' separators.
            file_name = os.path.basename(data['imagePath'].replace('\\', '/'))
            with open(os.path.join(image_file, file_name), 'rb') as fp:
                img_bytes = fp.read()
        img = img_bytes_to_arr(img_bytes)

        # Register any new label with the next value (len of the mapping).
        for shape in data['shapes']:
            label_name_to_value.setdefault(shape['label'], len(label_name_to_value))

        # Class names ordered by their label value.
        ordered = sorted(label_name_to_value.items(), key=lambda kv: kv[1])
        label_names = [name for name, _ in ordered]
        label_values = [value for _, value in ordered]
        if label_values != list(range(len(label_values))):
            # Non-contiguous values break name lookup by index and the colour map.
            print("assert label_values")

        lbl = shapes_to_label(img.shape, data['shapes'], label_name_to_value)
        lbl_viz = draw_label(lbl, img, label_names)

        print("第{:}张图".format(n))

        # Raw image. JPEG cannot store alpha, so e.g. RGBA sources become RGB.
        img_pil = PIL.Image.fromarray(img)
        if img_pil.mode not in ('L', 'RGB'):
            img_pil = img_pil.convert('RGB')
        img_pil.save(os.path.join(jpeg_dir, out_name + '.jpg'))

        # Palette-coloured class label.
        lblsave(os.path.join(class_png_dir, out_name + '.png'), lbl)

        # Single-channel class-index label (PNG bytes, so open in binary mode).
        raw_path = os.path.join(class_raw_dir, out_name + '.png')
        with tf.io.gfile.GFile(raw_path, mode='wb') as fp:
            PIL.Image.fromarray(lbl.astype(np.uint8)).save(fp, 'PNG')

        # Visualisation with legend.
        PIL.Image.fromarray(lbl_viz).save(os.path.join(viz_dir, out_name + '.png'))

        # Copy of the annotation itself.
        shutil.copyfile(path, os.path.join(ann_dir, os.path.basename(path)))

    if n:
        with open(os.path.join(out_dir, 'label_names.txt'), 'w') as fp:
            fp.writelines(name + '\n' for name in label_names)

def main_multiple():
    """Run ``mk_dataset`` on every sub-directory of ``data2``.

    Each sub-directory is used both as the folder of labelme ``*.json`` files
    and as the image folder. Results are written to
    ``./result/<sub-directory name>``. One label mapping is shared across all
    folders, so labels added while processing one folder keep their values
    for the next.
    """
    # Class name -> label value; the background is always 0.
    label_name_to_value = {'_background_': 0, "hub": 1, "valve": 2}
    # Create the common parent up front so per-folder output dirs can be made.
    os.makedirs("./result", exist_ok=True)
    for src_dir in sorted(glob.glob("data2/*")):
        if not os.path.isdir(src_dir):
            # mk_dataset expects a directory of *.json files.
            continue
        json_file = src_dir   # folder with the JSON annotations
        image_file = src_dir  # folder with the images
        out = os.path.join("./result", os.path.basename(src_dir))
        print(json_file, image_file, out)
        mk_dataset(json_file, image_file, out, label_name_to_value)

def main(src_dir="JPEGImages"):
    """Run ``mk_dataset`` on a single folder of labelme annotations.

    NOTE(review): assumes ``src_dir`` itself holds the labelme ``*.json``
    files and their images — confirm this is the intended layout.
    Results are written to ``./result/<basename of src_dir>``.
    """
    # Class name -> label value; the background is always 0.
    label_name_to_value = {'_background_': 0, "hub": 1, "valve": 2}
    json_file = src_dir   # folder with the JSON annotations
    image_file = src_dir  # folder with the images
    out = os.path.join("./result", os.path.basename(os.path.normpath(src_dir)))
    os.makedirs("./result", exist_ok=True)
    print(json_file, image_file, out)
    mk_dataset(json_file, image_file, out, label_name_to_value)

if __name__ == "__main__":
    # Script entry point: convert every dataset folder under data2/.
    main_multiple()
上一篇下一篇

猜你喜欢

热点阅读