多标签训练数据准备工作

2019-04-16  本文已影响0人  求索_700e

1. 作用:读取文件夹下的图片,类名就是最后一级文件夹的名称,名称只能是数字。

2. 将文件夹路径,类名写到txt中,类是one-hot编码后的。

import pickle as p

import numpy as np

import matplotlib.pyplot as plt

import matplotlib.image as plimg

from PIL import Image

import os

args={"train_test_val_dir": os.getcwd()+'/train/',

      "output_fname": "./txt/train_mul.txt",

      "delimiter":" ",

      "no_label":False,

      "num_class": 4

    }

def gen_binary(cls, is_all_zeros):

    binary=""

    if is_all_zeros==False:

        for i in range(0, args["num_class"]):

            if i==cls:

                if i==args["num_class"]-1:

                    binary=binary+"1"

                else:

                    binary=binary+"1"+","

            else:

                if i==args["num_class"]-1:

                    binary=binary+"0"

                else:

                    binary=binary+"0"+","

    else:

      for i in range(0, args["num_class"]):

            if i ==args["num_class"]-1:

                binary=binary+"0"

            else:

                binary=binary+"0"+","

    return binary

if __name__ == "__main__":

  if args["no_label"]==False:

      #pathes=os.listdir(args["train_test_val_dir"])

      f = open(args["output_fname"], "w")

      for root , dirs, files in os.walk(args["train_test_val_dir"]):

        print("'root=",root, "dirs=",dirs,"files=",files)

        for name in files:

            cls=root.split("/")[-1]

            srcfile=os.path.join(root, name)

            print(srcfile)

            if cls == "blank": 

                binaries=gen_binary(0, True) ##could be any number, here is 0

            else:

                binaries=gen_binary(int(cls), False)

                print(binaries)

            f.write(srcfile+args["delimiter"]+binaries+"\n")

      f.close()

      print ("保存完毕.")

      '''

      else:

        imgs=os.listdir(args["train_test_val_dir"])

        f = open(args["output_fname"], "w")

        for i in imgs:##imgX.shape的第一个维度是batch

              path=args["train_test_val_dir"]+i

              print(path)

              f.write(path+" "+"\n")

        f.close()

        print ("保存完毕.")

      '''

上一篇 下一篇

猜你喜欢

热点阅读