python充分利用多核性能预处理ImageNet数据集

2018-12-17  本文已影响0人  JiangPQ

TL;DR:用multiprocessing库解决python单线程处理大量图片缓慢的问题。

最近想试试HSI色彩空间的图片对卷积网络有没有帮助,就在每次加载数据的时候对每张图片做RGB到HSI的色彩空间变换。跑了几个epoch之后寻思着不对头,网络训练速度比原来慢了不少。这应该是因为数据预处理太占用CPU,感觉很不爽,于是想把整个ImageNet数据集提前处理好存下来,一劳永逸。

于是简单用python写了个脚本,遍历数据集,然后每张图片做好变换后按照原来的目录结构保存到新的根目录下。

实现很简单,但是跑了一下,发现整个数据集跑完一遍竟然要17个小时......原因是脚本只有一个进程,只能用到一个CPU核心;而且受python的GIL限制,即使改成多线程,计算密集型的任务也无法并行,没法充分利用CPU的多核性能。解决思路自然是用多进程实现真正的并行,让程序把所有核心都跑起来。

python的线程(threading)受GIL限制,同一时刻只有一个线程在执行python字节码,适合用在IO密集型的场景,对这种计算密集型的任务几乎没有帮助,而标准库中的multiprocessing自然就是解决方案了。运行机制简单地说就是创建一个进程池pool,pool提供map一类的接口(下面的代码用的是imap_unordered),把处理数据的函数和待处理数据的迭代器丢进去,进程池会自动把任务分配给多个进程执行,从而达到多进程并行的目的。当然multiprocessing库不止这么简单,还有更复杂的用法,这里并不需要所以不再深入。

最后附上代码:

import os
import tqdm
import itertools
import numpy as np
import multiprocessing as mp

from PIL import Image


def rgb2hsi(rgb):
    """Convert an RGB image array to the HSI colour space.

    Args:
        rgb: array of shape (height, width, 3) holding the R, G and B
            channels with values in [0, 255]. The array is left untouched;
            integer input is converted to float32 first.

    Returns:
        A new floating-point array of the same shape whose channels are hue,
        saturation and intensity, each scaled to [0, 255]. Hue is 0 for
        achromatic pixels (r == g == b) and saturation is 0 for black
        pixels, where the plain formulas would divide by zero.
    """
    rgb = np.asarray(rgb)
    if not np.issubdtype(rgb.dtype, np.floating):
        rgb = rgb.astype(np.float32)
    # Out-of-place division so the caller's array is never rescaled.
    rgb = rgb / 255.
    # Plain indexing keeps the (height, width) shape even when one of the
    # dimensions is 1 (np.squeeze would drop it).
    r, g, b = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2]
    hsi = np.zeros_like(rgb)

    numerator = (r - g) + (r - b)
    # The radicand is non-negative in exact arithmetic; clamp it so rounding
    # error cannot feed a negative number to sqrt.
    denominator = 2 * np.sqrt(np.maximum((r - g) ** 2 + (r - b) * (g - b), 0))
    chromatic = denominator > 0
    ratio = np.divide(numerator, denominator,
                      out=np.zeros_like(numerator), where=chromatic)
    # Clip before arccos: rounding can push the ratio just outside [-1, 1].
    theta = np.where(chromatic, np.arccos(np.clip(ratio, -1, 1)), 0)

    pi_2 = 2 * np.pi
    hsi[:, :, 0] = np.where(g >= b, theta, pi_2 - theta) / pi_2

    total = np.sum(rgb, 2)
    non_black = total > 0
    min_share = np.divide(np.min(rgb, 2), total,
                          out=np.zeros_like(total), where=non_black)
    hsi[:, :, 1] = np.where(non_black, 1 - 3 * min_share, 0)
    hsi[:, :, 2] = total / 3.

    return hsi * 255.


def resize_create_hsi_img(dir_pair):
    """Shrink one image to 70% of its size and save RGB and HSI copies.

    Targets that already exist are left alone, so an interrupted run can be
    resumed. Images whose mode is not 'RGB' are saved to the HSI target
    without colour conversion.

    Args:
        dir_pair: tuple ``(src_path, target_path_rgb, target_path_hsi)``.

    Returns:
        None when nothing failed (including when both targets already
        exist), otherwise ``src_path`` so the caller can record the failure.
    """
    src_path, target_path_rgb, target_path_hsi = dir_pair
    rgb_not_exist = not os.path.exists(target_path_rgb)
    hsi_not_exist = not os.path.exists(target_path_hsi)
    if not (rgb_not_exist or hsi_not_exist):
        return None
    try:
        # The context manager closes the source file handle promptly; the
        # resized copy is an independent in-memory image.
        with Image.open(src_path) as org_pic:
            # Clamp to 1 so tiny images never get a zero-sized side.
            new_size = (max(1, int(org_pic.size[0] * 0.7)),
                        max(1, int(org_pic.size[1] * 0.7)))
            # Image.ANTIALIAS was an alias of LANCZOS and was removed in
            # Pillow 10; LANCZOS is the same filter.
            pic = org_pic.resize(new_size, Image.LANCZOS)
        if rgb_not_exist:
            pic.save(target_path_rgb, quality=75)
        if hsi_not_exist:
            if pic.mode == 'RGB':
                rgb_img = np.asarray(pic, np.float32)
                hsi_img = rgb2hsi(rgb_img).astype(np.uint8)
                pic = Image.fromarray(hsi_img)
            pic.save(target_path_hsi, quality=75)
        return None
    except Exception as exc:
        # Best effort: one unreadable picture must not stop the whole run.
        # Report which file failed and hand its path back to the caller.
        print('{}: {}'.format(src_path, exc))
        return src_path


def walk_all_pic(root=r'D:\Datasets\ImageNet\ILSVRC2017_CLS-LOC\ILSVRC\Data\CLS-LOC',
                 root_new1=r'E:\Imagenet\cls_rgb',
                 root_new2=r'E:\Imagenet\cls_hsi',
                 targets=('val', 'train')):
    """Walk the dataset and yield source/target paths for every picture.

    The layout ``root/<target>/<folder>/<file>`` is mirrored under both
    output roots; the mirrored directories are created while walking.
    Entries directly under ``root/<target>`` that are not directories, and
    entries inside a folder that are not regular files, are skipped.

    Args:
        root: root directory of the source dataset.
        root_new1: root directory for the resized RGB copies.
        root_new2: root directory for the HSI copies.
        targets: names of the sub-directories of ``root`` to walk, in order.

    Yields:
        Tuples ``(src_path, rgb_target_path, hsi_target_path)``.
    """
    # makedirs(exist_ok=True) also creates missing parents and does not fail
    # if the directory appears between a check and the creation.
    os.makedirs(root_new1, exist_ok=True)
    os.makedirs(root_new2, exist_ok=True)

    for t in targets:
        sub1 = os.path.join(root, t)
        sub1_new1 = os.path.join(root_new1, t)
        sub1_new2 = os.path.join(root_new2, t)
        folders = os.listdir(sub1)

        os.makedirs(sub1_new1, exist_ok=True)
        os.makedirs(sub1_new2, exist_ok=True)

        for subfolder in folders:
            sub2 = os.path.join(sub1, subfolder)
            if not os.path.isdir(sub2):
                continue
            sub2_new1 = os.path.join(sub1_new1, subfolder)
            sub2_new2 = os.path.join(sub1_new2, subfolder)
            os.makedirs(sub2_new1, exist_ok=True)
            os.makedirs(sub2_new2, exist_ok=True)

            for file in os.listdir(sub2):
                fpath = os.path.join(sub2, file)
                if os.path.isfile(fpath):
                    yield (fpath,
                           os.path.join(sub2_new1, file),
                           os.path.join(sub2_new2, file))

# Multiple process version
def run_multiprocess():
    """Process every picture with a pool of worker processes.

    One worker is started per CPU reported by ``mp.cpu_count()``. The paths
    of pictures that could not be processed are written to
    ``./error_pics.log``, one per line.
    """
    print('Processing pictures with multiple processors...')
    with mp.Pool(processes=mp.cpu_count()) as pool:
        # total is only used for the progress bar; presumably it is the
        # number of pictures in the dataset -- confirm before relying on it.
        results = pool.imap_unordered(
            resize_create_hsi_img,
            tqdm.tqdm(walk_all_pic(), total=1331167, ncols=65))
        # The worker returns None on success; keep only the failed paths so
        # the log never receives None entries.
        error_pics = [ep for ep in results if ep is not None]
    with open('./error_pics.log', mode='w', encoding='utf-8') as f:
        # writelines adds no separators itself, so terminate each path.
        f.writelines(ep + '\n' for ep in error_pics)
    print('All pictures cannot be processed have been writen into \'error_pics.log\'')

# Single process version
def run():
    """Process every picture one after another in the current process."""
    progress = tqdm.tqdm(walk_all_pic())
    for triple in progress:
        src, dst_rgb, dst_hsi = triple
        resize_create_hsi_img((src, dst_rgb, dst_hsi))


# Script entry point: run the multi-process version when this file is
# executed directly. The guard keeps the pool from being started when the
# module is merely imported.
# NOTE(review): presumably also required because multiprocessing workers
# re-import this module on Windows (spawn start method) -- confirm.
if __name__ == '__main__':
    run_multiprocess()

上一篇 下一篇

猜你喜欢

热点阅读