数据集处理工具

2019-10-18  本文已影响0人  张亦风

划分训练集与测试集（随机打乱后按 9:1 划分）

import os
import random
full_path = "full.txt"
train_path = "save_dir/train.txt"
test_path = "save_dir/test.txt"
# Fraction of the shuffled lines that go to the training split; the rest
# become the test split.
train_ratio = 0.9

# Read every sample line without its newline so it can be re-added uniformly.
with open(full_path, "r") as f1:
    line = [i.replace("\n", "") for i in f1.readlines()]
random.shuffle(line)
# Number of lines in the training split (the same name is used for slicing).
train_num = int(len(line) * train_ratio)

# Open each output file once instead of once per line. Append mode is kept,
# so re-running the script adds to existing files rather than replacing them.
with open(train_path, "a+") as f:
    f.writelines(i + "\n" for i in line[:train_num])
with open(test_path, "a+") as f:
    f.writelines(i + "\n" for i in line[train_num:])

修改json文件

import os,sys
import json
# Path of the JSON file that the __main__ block below loads and patches
# via get_new_json().
m_path='conf.json'
def get_new_json(filepath):
    """Load the JSON document at *filepath* and patch two entries in it.

    Sets ``data["conf"][1]["value"]`` to 0.99 and ``data["conf"][2]["value"]``
    to 0.75 on the parsed object and returns that object. The file on disk
    is only read, never modified.
    """
    with open(filepath, 'rb') as src:
        document = json.load(src)
        entries = document["conf"]
        for index, new_value in ((1, 0.99), (2, 0.75)):
            entries[index]["value"] = new_value
    return document
def rewrite_json_file(filepath, json_data):
    """Overwrite *filepath* with *json_data* serialized as JSON (default settings)."""
    with open(filepath, mode='w') as out_file:
        json.dump(json_data, fp=out_file)
if __name__ == '__main__':
    # Load m_path, patch it, and save the result to a fixed output path.
    patched = get_new_json(m_path)
    rewrite_json_file("/disk1/1.json", patched)

dict to json

import json

# Sample nested dict, serialized to 2.json in the current working directory.
data = {
    'name': {"0": {"name": 'ACME', "age": 16}},
    'shares': 100,
    'price': 542.23,
}
serialized = json.dumps(data)
with open("2.json", "w") as out_file:
    out_file.write(serialized)

base 格式 txt 标注转 YOLO 格式标签（base.txt2yolo）

import os
import cv2
from random import randint

base_labels_dir = '/home/zhangyong/zhangyong/ssd/caffe/data/VOCdevkit/VOC2007/base'
images_dir = '/home/zhangyong/zhangyong/ssd/caffe/data/VOCdevkit/VOC2007/JPEGImages'
# Destination directory for the generated YOLO label files (one .txt per image).
yolo_labels_dir = '/home/zhangyong/zhangyong/ssd/caffe/data/VOCdevkit/VOC2007/Label'


def _base_box_to_yolo(fields, w, h):
    """Convert one base-format box to a YOLO label line.

    *fields* holds ints ``[cls, ymin, xmin, ymax, xmax, ...]``: x columns are
    clamped against the width and y columns against the height. The box is
    clamped to the ``w`` x ``h`` image, converted to a normalized
    centre/size box, and the class id is shifted down by one.
    """
    cls, ymin, xmin, ymax, xmax = fields[:5]
    # Read all columns before clamping so no coordinate is overwritten
    # before it is used (in-place box[i] updates mixed x and y values).
    x1 = max(0, xmin)
    y1 = max(0, ymin)
    x2 = min(w - 1, xmax)
    y2 = min(h - 1, ymax)
    # Float division so the centre is not floored under Python 2.
    cx = (x1 + x2) / 2.0 / w
    cy = (y1 + y2) / 2.0 / h
    cw = float(x2 - x1) / w
    ch = float(y2 - y1) / h
    return '%d %.6f %.6f %.6f %.6f\n' % (cls - 1, cx, cy, cw, ch)


# Set for O(1) membership tests inside the image loop.
alllabels = set(os.listdir(base_labels_dir))
allimages = os.listdir(images_dir)

with open('train3.txt', 'a+') as trainf, open('test3.txt', 'a+') as testf:
    for jpg in allimages:
        # Swap only the extension (replace('jpg', 'txt') also hit the stem).
        txt = os.path.splitext(jpg)[0] + '.txt'
        if txt not in alllabels:
            continue

        jpg_path = os.path.join(images_dir, jpg)
        img = cv2.imread(jpg_path)
        if img is None:
            # cv2.imread returns None for unreadable files; skip rather than
            # crash on .shape, and do not list the image in either split.
            continue
        h, w = img.shape[:2]

        # randint is inclusive on both ends, so about 1 in 11 images go to test.
        split_file = testf if randint(0, 10) == 1 else trainf
        split_file.write('%s\n' % jpg_path)

        with open(os.path.join(base_labels_dir, txt), 'r') as src:
            with open(os.path.join(yolo_labels_dir, txt), 'w') as labf:
                for line in src:
                    fields = [int(v) for v in line.split()]
                    if not fields:
                        continue  # tolerate blank / trailing lines
                    labf.write(_base_box_to_yolo(fields, w, h))

base 格式 txt 标注转 VOC XML 标注（txt2xml）

import os
from PIL import Image
import cv2

# Header of a VOC-style annotation XML, up to and including <size>.
# Filled with %-formatting from a dict with keys 'name' (str), 'width' and
# 'height' (ints, via %d).
out0 ='''<?xml version="1.0" encoding="utf-8"?>
<annotation>
    <folder>None</folder>
    <filename>%(name)s</filename>
    <source>
        <database>None</database>
        <annotation>None</annotation>
        <image>None</image>
        <flickrid>None</flickrid>
    </source>
    <owner>
        <flickrid>None</flickrid>
        <name>None</name>
    </owner>
    <segmented>0</segmented>
    <size>
        <width>%(width)d</width>
        <height>%(height)d</height>
        <depth>3</depth>
    </size>
'''
# Template for one <object> element, filled with %-formatting from a dict
# with keys 'class' (%s) and 'xmin', 'ymin', 'xmax', 'ymax' (%d).
# The opening tag uses the same four-space nesting as out0, so the generated
# XML is consistently indented.
out1 = '''    <object>
        <name>%(class)s</name>
        <pose>Unspecified</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>%(xmin)d</xmin>
            <ymin>%(ymin)d</ymin>
            <xmax>%(xmax)d</xmax>
            <ymax>%(ymax)d</ymax>
        </bndbox>
    </object>
'''

# Closing tag written after all <object> elements.
out2 = '''</annotation>
'''
def translate(lists):
    """Write a VOC-style XML annotation for every ``.jpg`` path in *lists*.

    For each entry with a ``.jpg`` extension:

    * the path is rewritten ``darknet`` -> ``ssd/caffe`` and the image is
      read with OpenCV to get its size;
    * the box file is that path with ``ImageSets`` -> ``base`` and the
      extension -> ``.txt``;
    * the XML is written to that path with ``ImageSets`` -> ``Annotations``
      and the extension -> ``.xml``.

    Each box line is space-separated ints ``cls ymin xmin ymax xmax``. Boxes
    are clamped to the image, and boxes that still fall outside it are
    skipped. Images that OpenCV cannot read are skipped with a message.
    A missing box file still raises, but it is now read before the XML is
    opened, so no half-written XML is left behind.

    Args:
        lists: iterable of image paths; entries whose extension is not
            ``.jpg`` are ignored.
    """
    for jpg in lists:
        if os.path.splitext(jpg)[1] != '.jpg':
            continue
        print(jpg)
        jpg = jpg.replace('darknet', 'ssd/caffe')
        img = cv2.imread(jpg)
        if img is None:
            # cv2.imread returns None for missing/unreadable files.
            print('skip unreadable image: %s' % jpg)
            continue
        h, w = img.shape[:2]

        # Swap only the extension; str.replace('.jpg', ...) hit every occurrence.
        stem = os.path.splitext(jpg)[0]
        txt_path = stem.replace('ImageSets', 'base') + '.txt'
        xml_path = stem.replace('ImageSets', 'Annotations') + '.xml'

        with open(txt_path, 'r') as f:
            lines = [i.replace('\n', '') for i in f.readlines()]
            print(lines)

        source = {'name': jpg.split('/')[-1], 'width': w, 'height': h}

        # `with` guarantees the XML is flushed and closed even on error.
        with open(xml_path, 'w') as fxml:
            fxml.write(out0 % source)
            for box in lines:
                fields = box.split()
                if not fields:
                    continue  # tolerate blank / trailing lines
                label = {
                    'class': int(fields[0]),
                    'xmin': max(int(fields[2]), 0),
                    'ymin': max(int(fields[1]), 0),
                    'xmax': min(int(fields[4]), w - 1),
                    'ymax': min(int(fields[3]), h - 1),
                }
                # After clamping, only these conditions can still place the
                # box outside the image.
                if (label['xmin'] >= w or label['ymin'] >= h
                        or label['xmax'] < 0 or label['ymax'] < 0):
                    continue
                fxml.write(out1 % label)
            fxml.write(out2)

if __name__ == '__main__':
    # Newline-separated list of image paths to convert.
    list_path = '/home/zhangyong/newdisk/company/week/codes/tools/train.txt'
    with open(list_path, 'r') as list_file:
        image_paths = [entry.rstrip('\n') for entry in list_file]
    translate(image_paths)

上一篇下一篇

猜你喜欢

热点阅读