图像分类问题数据集划分的函数封装

2021-04-21  本文已影响0人  1037号森林里一段干木头

简介:做图片数据预处理真的是太麻烦太累了,说是又脏又累真是一点问题都没有,现在想想那些现成的开源的数据集直接拿来用真是很棒棒了-._.-,拿来主义真香了。在训练不同的数据集时做准备数据的工作就让人头大,然后当你发现很多东西是重复的时候就更难过了。
这里把图片分类问题中划分训练集、测试集、验证集的过程封装打包一下,造个轮子。

1.场景:
你已经收集好了不同类型的图片,每一类图片在一个文件夹下,所有的类又在一个文件夹下。结构如下:


image.png

要把图片按一定的比例分装成如下结构:


image.png

2.代码

import os
import random
import shutil

class dataset():
    def __init__(self, root, dstPath):
        self.root = root #原始分类图片保存的文件夹
        self.dstPath = dstPath #划分好的数据集保存的路径
    
    def getClassList(self):
        #root下放着不同的文件夹,表示不同的类
        classList = []
        for item in os.listdir(self.root):
            subPath = os.path.join(self.root, item)
            if os.path.isdir(subPath):
                classList.append(item)#把文件夹名当类名
        return classList
    
    def createFolder(self,path):
        if os.path.exists(path):
            print("Path already exists!")
        else:
            try:
                os.makedirs(path)
                print("makeDirs success!")
            except:
                print("makeDirs ERROR!")
            
    def makedirs(self):
        classList =  self.getClassList()
        for item in classList:
            self.createFolder(os.path.join(self.dstPath,"train",item))
            self.createFolder(os.path.join(self.dstPath,"test",item))
            self.createFolder(os.path.join(self.dstPath,"val",item))
    
    def getAllImagePath(self,path):
        #只读取一级目录下的图片
        imagePathList = []
        externName = ["jpg","jpeg","png","bmp"]
        for item in os.listdir(path):
            srcPath = os.path.join(path,item)
            if os.path.isfile(srcPath):
                if item.rsplit(".",1)[-1] in externName:
                    imagePathList.append(srcPath)
        return imagePathList
    
    def createDataset(self,train_p=0.7, test_p=0.15, val_p=0.15):
        classList = self.getClassList()
        for item in classList:
            dstPath_train = os.path.join(self.dstPath, "train", item)
            dstPath_test = os.path.join(self.dstPath, "test", item)
            dstPath_val = os.path.join(self.dstPath, "val", item)
            imagePathList = []
            srcPath = os.path.join(self.root,item)
            imagePathList = self.getAllImagePath(srcPath)
            for imagePath in imagePathList:
                random_num = random.random()
                imageName = imagePath.rsplit("\\",1)[-1]
                if random_num <train_p:
                    shutil.copy(imagePath,os.path.join(dstPath_train,imageName))
                    print("copy src:{},----> dst:{}".format(imagePath,os.path.join(dstPath_train,imageName)))
                elif random_num < train_p+test_p:
                    shutil.copy(imagePath, os.path.join(dstPath_test,imageName))
                    print("copy src:{},----> dst:{}".format(imagePath,os.path.join(dstPath_test,imageName)))
                else:
                    shutil.copy(imagePath, os.path.join(dstPath_val,imageName))
                    print("copy src:{},----> dst:{}".format(imagePath,os.path.join(dstPath_val,imageName)))
                    
        return True

if __name__ == "__main__":
    root = "K:\imageData\polarity\positive_negative"#存放不同类的图片的根目录
    dstPath = "K:\imageData\polarity\data"#生成的数据集保存的目录
    tool = dataset(root,dstPath)#初始化类,需要传入两个参数
    #先创建文件夹,然后在把root根目录下的图片按比例分发到train、test、val文件夹
    tool.makedirs()
    #设置训练测试验证集的分配概率
    train_p = 0.7
    test_p = 0.15
    val_p = 0.15
    tool.createDataset(train_p,test_p,val_p)

3.运行示例:

上一篇 下一篇

猜你喜欢

热点阅读