线程池方案

Java技术专题「并发编程专题」Fork/Join框架基本使用和

2021-09-10  本文已影响0人  洛神灬殇

前提概述

Java 7开始引入了一种新的Fork/Join线程池,它可以执行一种特殊的任务:把一个大任务拆成多个小任务并行执行。

我们举个例子:如果要计算一个超大数组的和,最简单的做法是用一个循环在一个线程内完成:

算法原理介绍

相信大家此前或多或少有了解到ForkJoin框架,ForkJoin框架其实就是一个线程池ExecutorService的实现,通过工作窃取(work-stealing)算法,获取其他线程中未完成的任务来执行。可以充分利用机器的多处理器优势,利用空闲的线程去并行快速完成一个可拆分为小任务的大任务,类似于分治算法

实现达成目标

基本使用

入门例子,用Fork/Join框架使用示例,在这个示例中我们计算了1-5000累加后的值

public class TestForkAndJoinPlus {
    private static final Integer MAX = 400;
    static class WorkTask extends RecursiveTask<Integer> {
        // 子任务开始计算的值
        private Integer startValue;
        // 子任务结束计算的值
        private Integer endValue;
        public WorkTask(Integer startValue , Integer endValue) {
            this.startValue = startValue;
            this.endValue = endValue;
        }
        @Override
        protected Integer compute() {
            // 如果小于最小分片阈值,则说明要进行相关的数据操作
            // 可以正式进行累加计算了
            if(endValue - startValue < MAX) {
                System.out.println("开始计算的部分:startValue = " + startValue + ";endValue = " + endValue);
                Integer totalValue = 0;
                for(int index = this.startValue ; index <= this.endValue  ; index++) {
                    totalValue += index;
                }
                return totalValue;
            }
            // 否则再进行任务拆分,拆分成两个任务
            else {
                 // 因为采用二分法,拆分,所以进行1/2切分数据量
                WorkTask subTask1 = new WorkTask(startValue, (startValue + endValue) / 2);
                subTask1.fork();//进行拆分机制控制
                WorkTask subTask2 = new WorkTask((startValue + endValue) / 2 + 1 , endValue);
                subTask2.fork();
                return subTask1.join() + subTask2.join();
            }
        }
    }
    public static void main(String[] args) {
        // 这是Fork/Join框架的线程池
        ForkJoinPool pool = new ForkJoinPool();
        ForkJoinTask<Integer> taskFuture =  pool.submit(new MyForkJoinTask(1,1001));
        try {
            Integer result = taskFuture.get();
            System.out.println("result = " + result);
        } catch (InterruptedException | ExecutionException e) {
            e.printStackTrace(System.out);
        }
    }
}

对此我封装了一个框架集合,基于JDK1.8+中的Fork/Join框架实现,参考的Fork/Join框架主要源代码也基于JDK1.8+。

WorkTaskCallable实现抽象模型层次操作转换

@Accessors(chain = true)
public class WorkTaskCallable<T> extends RecursiveTask<T> {

    /**
     * 断言操作控制
     */
    @Getter
    private Predicate<T> predicate;

    /**
     * 执行参数化分割条件
     */
    @Getter
    private T splitParam;

    /**
     * 操作拆分方法操作机制
     */
    @Getter
    private Function<Object,Object[]> splitFunction;

    /**
     * 操作合并方法操作机制
     */
    @Getter
    private BiFunction<Object,Object,T> mergeFunction;

    /**
     * 操作处理机制
     */
    @Setter
    @Getter
    private Function<T,T> processHandler;


    /**
     * 构造器是否进行分割操作
     * @param predicate 判断是否进行下一步分割的条件关系
     * @param splitParam 分割参数
     * @param splitFunction 分割方法
     * @param mergeFunction 合并数据操作
     */
    public WorkTaskCallable(Predicate predicate,T splitParam,Function<Object,Object[]> splitFunction,BiFunction<Object,Object,T> mergeFunction,Function<T,T> processHandler){
        this.predicate = predicate;
        this.splitParam = splitParam;
        this.splitFunction = splitFunction;
        this.mergeFunction = mergeFunction;
        this.processHandler = processHandler;
    }

    /**
     * 实际执行调用操作机制
     * @return
     */
    @Override
    protected T compute() {
        if(predicate.test(splitParam)){
            Object[] result = splitFunction.apply(splitParam);
            WorkTaskCallable workTaskCallable1 = new WorkTaskCallable(predicate,result[0],splitFunction,mergeFunction,processHandler);
            workTaskCallable1.fork();
            WorkTaskCallable workTaskCallable2 = new WorkTaskCallable(predicate,result[1],splitFunction,mergeFunction,processHandler);
            workTaskCallable2.fork();
            return mergeFunction.apply(workTaskCallable1.join(),workTaskCallable2.join());
        }else{
            return processHandler.apply(splitParam);
        }
    }
}

ArrayListWorkTaskCallable实现List集合层次操作转换


/**
 * @project-name:wiz-shrding-framework
 * @package-name:com.wiz.sharding.framework.boot.common.thread.forkjoin
 * @author:LiBo/Alex
 * @create-date:2021-09-09 17:26
 * @copyright:libo-alex4java
 * @email:liboware@gmail.com
 * @description:
 */
public class ArrayListWorkTaskCallable extends WorkTaskCallable<List>{



    static Predicate<List> predicateFunction = param->param.size() > 3;


    static Function<List,List[]> splitFunction = (param)-> {
        if(predicateFunction.test(param)){
            return new List[]{param.subList(0,param.size()/ 2),param.subList(param.size()/2,param.size())};
        }else{
            return new List[]{param.subList(0,param.size()+1),Lists.newArrayList()};
        }
    };

    static BiFunction<List,List,List> mergeFunction = (param1,param2)->{
        List datalist = Lists.newArrayList();
        datalist.addAll(param2);
        datalist.addAll(param1);
        return datalist;
    };


    /**
     * 构造器是否进行分割操作
     * @param predicate     判断是否进行下一步分割的条件关系
     * @param splitParam    分割参数
     * @param splitFunction 分割方法
     * @param mergeFunction 合并数据操作
     */
    public ArrayListWorkTaskCallable(Predicate<List> predicate, List splitParam, Function splitFunction, BiFunction mergeFunction,
                                     Function<List,List> processHandler) {
        super(predicate, splitParam, splitFunction, mergeFunction,processHandler);
    }




    public ArrayListWorkTaskCallable(List splitParam, Function splitFunction, BiFunction mergeFunction,
                                     Function<List,List> processHandler) {
        super(predicateFunction, splitParam, splitFunction, mergeFunction,processHandler);
    }


    public ArrayListWorkTaskCallable(Predicate<List> predicate,List splitParam,Function<List,List> processHandler) {
        this(predicate, splitParam, splitFunction, mergeFunction,processHandler);
    }


    public ArrayListWorkTaskCallable(List splitParam,Function<List,List> processHandler) {
        this(predicateFunction, splitParam, splitFunction, mergeFunction,processHandler);
    }



    public static void main(String[] args){
        List dataList = Lists.newArrayList(0,1,2,3,4,5,6,7,8,9);
        ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
        ForkJoinTask<List> forkJoinResult = forkJoinPool.submit(new ArrayListWorkTaskCallable(dataList,param->Lists.newArrayList(param.size())));
        try {
            System.out.println(forkJoinResult.get());
        } catch (InterruptedException e) {
            e.printStackTrace();
        } catch (ExecutionException e) {
            e.printStackTrace();
        }
    }

ForkJoin代码分析

ForkJoinPool构造函数
  /**
     * Creates a {@code ForkJoinPool} with parallelism equal to {@link
     * java.lang.Runtime#availableProcessors}, using the {@linkplain
     * #defaultForkJoinWorkerThreadFactory default thread factory},
     * no UncaughtExceptionHandler, and non-async LIFO processing mode.
     *
     * @throws SecurityException if a security manager exists and
     *         the caller is not permitted to modify threads
     *         because it does not hold {@link
     *         java.lang.RuntimePermission}{@code ("modifyThread")}
     */
    public ForkJoinPool() {
        this(Math.min(MAX_CAP, Runtime.getRuntime().availableProcessors()),
             defaultForkJoinWorkerThreadFactory, null, false);
    }

    /**
     * Creates a {@code ForkJoinPool} with the indicated parallelism
     * level, the {@linkplain
     * #defaultForkJoinWorkerThreadFactory default thread factory},
     * no UncaughtExceptionHandler, and non-async LIFO processing mode.
     *
     * @param parallelism the parallelism level
     * @throws IllegalArgumentException if parallelism less than or
     *         equal to zero, or greater than implementation limit
     * @throws SecurityException if a security manager exists and
     *         the caller is not permitted to modify threads
     *         because it does not hold {@link
     *         java.lang.RuntimePermission}{@code ("modifyThread")}
     */
    public ForkJoinPool(int parallelism) {
        this(parallelism, defaultForkJoinWorkerThreadFactory, null, false);
    }

    /**
     * Creates a {@code ForkJoinPool} with the given parameters.
     *
     * @param parallelism the parallelism level. For default value,
     * use {@link java.lang.Runtime#availableProcessors}.
     * @param factory the factory for creating new threads. For default value,
     * use {@link #defaultForkJoinWorkerThreadFactory}.
     * @param handler the handler for internal worker threads that
     * terminate due to unrecoverable errors encountered while executing
     * tasks. For default value, use {@code null}.
     * @param asyncMode if true,
     * establishes local first-in-first-out scheduling mode for forked
     * tasks that are never joined. This mode may be more appropriate
     * than default locally stack-based mode in applications in which
     * worker threads only process event-style asynchronous tasks.
     * For default value, use {@code false}.
     * @throws IllegalArgumentException if parallelism less than or
     *         equal to zero, or greater than implementation limit
     * @throws NullPointerException if the factory is null
     * @throws SecurityException if a security manager exists and
     *         the caller is not permitted to modify threads
     *         because it does not hold {@link
     *         java.lang.RuntimePermission}{@code ("modifyThread")}
     */
    public ForkJoinPool(int parallelism,
                        ForkJoinWorkerThreadFactory factory,
                        UncaughtExceptionHandler handler,
                        boolean asyncMode) {
        this(checkParallelism(parallelism),
             checkFactory(factory),
             handler,
             (asyncMode ? FIFO_QUEUE : LIFO_QUEUE),
             "ForkJoinPool-" + nextPoolId() + "-worker-");
        checkPermission();
    }

    /**
     * Creates a {@code ForkJoinPool} with the given parameters, without
     * any security checks or parameter validation.  Invoked directly by
     * makeCommonPool.
     */
    private ForkJoinPool(int parallelism,
                         ForkJoinWorkerThreadFactory factory,
                         UncaughtExceptionHandler handler,
                         int mode,
                         String workerNamePrefix) {
        this.workerNamePrefix = workerNamePrefix;
        this.factory = factory;
        this.ueh = handler;
        this.mode = (short)mode;
        this.parallelism = (short)parallelism;
        long np = (long)(-parallelism); // offset ctl counts
        this.ctl = ((np << AC_SHIFT) & AC_MASK) | ((np << TC_SHIFT) & TC_MASK);
    }

在Fork/Join框架中有一个默认的ForkJoinWorkerThreadFactory接口实现:DefaultForkJoinWorkerThreadFactory。


需要注意点


image
ForkJoinPool类的属性介绍
WorkQueue类

image

ForkJoinTask是能够在ForkJoinPool中执行的任务抽象类,父类是Future,具体实现类有很多,这里主要关注RecursiveAction和RecursiveTask。

ForkJoinTask类属性的介绍

status: 任务的状态,对其他工作线程和pool可见,运行正常则status为负数,异常情况为正数。

ForkJoinTask功能介绍

只需要实现其compute()方法,在compute()中做最小任务控制,任务分解(fork)和结果合并(join)。

image

ForkJoinPool中执行的默认线程是ForkJoinWorkerThread,由默认工厂产生,可以自己重写要实现的工作线程。同时会将ForkJoinPool引用放在每个工作线程中,供工作窃取时使用。

ForkJoinWorkerThread类属性介绍

简易执行图

image

实际上Fork/Join框架的内部工作过程要比这张图复杂得多,例如如何决定某一个recursive task是使用哪条线程进行运行;再例如如何决定当一个任务/子任务提交到Fork/Join框架内部后,是创建一个新的线程去运行还是让它进行队列等待。

逻辑模型图(盗一张图:)

盗一张图:

()

fork方法和join方法

Fork/Join框架中提供的fork方法和join方法,可以说是该框架中提供的最重要的两个方法,它们和parallelism“可并行任务数量”配合工作。

Fork方法介绍

当一个ForkJoinTask任务调用fork()方法时,当前线程会把这个任务放入到queue数组的queueTop位置,然后执行以下两句代码:

if ((s -= queueBase) <= 2)
    pool.signalWork();
else if (s == m)
    growQueue();

当调用signalWork()方法。signalWork()方法做了两件事:1、唤配当前线程;2、当没有活动线程时或者线程数较少时,添加新的线程。


Join方法介绍

Join是一个不断等待,获取任务执行结果的过程。

private int doJoin() {
    Thread t; ForkJoinWorkerThread w; int s; boolean completed;
    if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) {
        if ((s = status) < 0)
            return s;
        if ((w = (ForkJoinWorkerThread)t).unpushTask(this)) {
            try {
                completed = exec();
            } catch (Throwable rex) {
                return setExceptionalCompletion(rex);
            }
            if (completed)
                return setCompletion(NORMAL);
        }
        return w.joinTask(this);
    }
    else
        return externalAwaitDone();
}
final int joinTask(ForkJoinTask<?> joinMe) {
    ForkJoinTask<?> prevJoin = currentJoin;
    currentJoin = joinMe;
    for (int s, retries = MAX_HELP;;) {
        if ((s = joinMe.status) < 0) {
            currentJoin = prevJoin;
            return s;
        }
        if (retries > 0) {
            if (queueTop != queueBase) {
                if (!localHelpJoinTask(joinMe))
                    retries = 0;           // cannot help
            }
            else if (retries == MAX_HELP >>> 1) {
                --retries;                 // check uncommon case
                if (tryDeqAndExec(joinMe) >= 0)
                    Thread.yield();        // for politeness
            }
            else
                retries = helpJoinTask(joinMe) ? MAX_HELP : retries - 1;
        }
        else {
            retries = MAX_HELP;           // restart if not done
            pool.tryAwaitJoin(joinMe);
        }
    }
}
outer:for (ForkJoinWorkerThread thread = this;;) {
    // Try to find v, the stealer of task, by first using hint
    ForkJoinWorkerThread v = ws[thread.stealHint & m];
    if (v == null || v.currentSteal != task) {
        for (int j = 0; ;) {        // search array
            if ((v = ws[j]) != null && v.currentSteal == task) {
                thread.stealHint = j;
                break;              // save hint for next time
            }
            if (++j > m)
                break outer;        // can't find stealer
        }
    }
    // Try to help v, using specialized form of deqTask
    for (;;) {
        ForkJoinTask<?>[] q; int b, i;
        if (joinMe.status < 0)
            break outer;
        if ((b = v.queueBase) == v.queueTop ||
            (q = v.queue) == null ||
            (i = (q.length-1) & b) < 0)
            break;                  // empty
        long u = (i << ASHIFT) + ABASE;
        ForkJoinTask<?> t = q[i];
        if (task.status < 0)
            break outer;            // stale
        if (t != null && v.queueBase == b &&
            UNSAFE.compareAndSwapObject(q, u, t, null)) {
            v.queueBase = b + 1;
            v.stealHint = poolIndex;
            ForkJoinTask<?> ps = currentSteal;
            currentSteal = t;
            t.doExec();
            currentSteal = ps;
            helped = true;
        }
    }
    // Try to descend to find v's stealer
    ForkJoinTask<?> next = v.currentJoin;
    if (--levels > 0 && task.status >= 0 &&
        next != null && next != task) {
        task = next;
        thread = v;
    }
}
上一篇 下一篇

猜你喜欢

热点阅读