Java并发

ForkJoin分析

2019-05-21  本文已影响0人  barry_di

一、ForkJoin

ForkJoin是由JDK1.7后提供多线并发处理框架。ForkJoin的框架的基本思想是分而治之。什么是分而治之?分而治之就是将一个复杂的计算,按照设定的阈值进行分解成多个计算,然后将各个计算结果进行汇总。相应的ForkJoin将复杂的计算当做一个任务。而分解的多个计算则是当做一个子任务。


image.png

二、ForkJoin的使用

private static class SumTask extends RecursiveTask<Integer> {

        private  int threshold ;
        private static final int segmentation = 10;

        private int[] src;

        private int fromIndex;
        private int toIndex;

        public SumTask(int formIndex,int toIndex,int[] src){
            this.fromIndex = formIndex;
            this.toIndex = toIndex;
            this.src = src;
            this.threshold = src.length/segmentation;
        }

        @Override
        protected Integer compute() {
            if((toIndex - fromIndex)<threshold ){
                int count = 0;
                System.out.println(" from index = "+fromIndex
                        +" toIndex="+toIndex);
                for(int i = fromIndex;i<=toIndex;i++){
                  count+=src[i];
                }
                return count;
            }else{
                int mid = (fromIndex+toIndex)/2;
                SumTask left =  new SumTask(fromIndex,mid,src);
                SumTask right = new SumTask(mid+1,toIndex,src);
                invokeAll(left,right);
                return left.join()+right.join();
            }
        }
    }
 public static void main(String[] args) {
            int[]  array = MakeArray.createIntArray();
        ForkJoinPool forkJoinPool= new ForkJoinPool();
        SumTask sumTask  = new SumTask(0,array.length-1,array);

        long start = System.currentTimeMillis();

        forkJoinPool.invoke(sumTask);
        System.out.println("The count is "+sumTask.join()
                +" spend time:"+(System.currentTimeMillis()-start)+"ms");

    }

三、RecursiveTask和RecursiveAction区别

public abstract class RecursiveTask<V> extends ForkJoinTask<V> {
    private static final long serialVersionUID = 5232453952276485270L;

    /**
     * The result of the computation.
     */
    V result;

    /**
     * The main computation performed by this task.
     * @return the result of the computation
     */
    protected abstract V compute();

    public final V getRawResult() {
        return result;
    }

    protected final void setRawResult(V value) {
        result = value;
    }

    /**
     * Implements execution conventions for RecursiveTask.
     */
    protected final boolean exec() {
        result = compute();
        return true;
    }

}
public abstract class RecursiveAction extends ForkJoinTask<Void> {
    private static final long serialVersionUID = 5232453952276485070L;

    /**
     * The main computation performed by this task.
     */
    protected abstract void compute();

    /**
     * Always returns {@code null}.
     *
     * @return {@code null} always
     */
    public final Void getRawResult() { return null; }

    /**
     * Requires null completion value.
     */
    protected final void setRawResult(Void mustBeNull) { }

    /**
     * Implements execution conventions for RecursiveActions.
     */
    protected final boolean exec() {
        compute();
        return true;
    }

}

ForkJoinTask是RecursiveAction与RecursiveTask的父类, ForkJoinTask中使用了模板模式进行设计
,将ForkJoinTask的执行相关的代码进行隐藏,通过提供抽象类暴露用户的实际业务处理。

三、ForJoin注意点

使用ForkJoin将相同的计算任务通过多线程的进行执行。从而能提高数据的计算速度。在google的中的大数据处理框架mapreduce就通过类似ForkJoin的思想。通过多线程提高大数据的处理。但是我们需要注意:

四、ForkJoin工作窃取(work-stealing)

为什么ForkJoin会存在工作窃取呢?因为我们将任务进行分解成多个子任务的时候。每个子任务的处理时间都不一样。例如分别有子任务A\B。如果子任务A的1ms的时候已经执行,子任务B还在执行。那么如果我们子任务A的线程等待子任务B完毕后在进行汇总,那么子任务A线程就会在浪费执行时间,最终的执行时间就以最耗时的子任务为准。而如果我们的子任务A执行完毕后,处理子任务B的任务,并且执行完毕后将任务归还给子任务B。这样就可以提高执行效率。而这种就是工作窃取。

五、ForkJoin排序

public class SortForkJoin {
    /**
     * 数组排序
     *
     * @param arry
     * @return
     */
    public static int[] sort(int[] arry) {
        if (arry.length == 0) return arry;
        for (int index = 0; index < arry.length - 1; index++) {
            int pre_index = index;
            int currentValue = arry[index + 1];
            while (pre_index >= 0 && arry[pre_index] > currentValue) {
                arry[pre_index + 1] = arry[pre_index];
                pre_index--;
            }
            arry[pre_index + 1] = currentValue;
        }
        return arry;
    }

    /**
     * 组合
     *
     * @param left
     * @param right
     * @return
     */
    public static int[] merge(int[] left, int[] right) {
        int[] result = new int[left.length + right.length];
        for (int resultIndex = 0, leftIndex = 0, rightIndex = 0; resultIndex < result.length; resultIndex++) {
            if (leftIndex >= left.length) {
                result[resultIndex] = right[rightIndex++];
            } else if (rightIndex >= right.length) {
                result[resultIndex] = left[leftIndex++];
            } else if (left[leftIndex] > right[rightIndex]) {
                result[resultIndex] = right[rightIndex++];
            } else {
                result[resultIndex] = left[leftIndex++];
            }
        }
        return result;
    }


     static  class SortTask extends RecursiveTask<int[]> {
        private int threshold;
        private int start;
        private int end;
        private int segmentation ;
        private int[] src;

        public SortTask(int[] src,int start,int end,int segmentation){
            this.src = src;
            this.start = start;
            this.end = end;
            this.threshold = src.length/segmentation;
            this.segmentation = segmentation;
        }
        @Override
        protected int[] compute() {
            if((end - start) <threshold){
               int mid =  (end-start)/2;
               SortTask leftTask = new SortTask(src,start,mid,segmentation);
               SortTask rightTask = new SortTask(src,mid+1,end,segmentation);
               invokeAll(leftTask,rightTask);
               return SortForkJoin.merge(leftTask.join(),rightTask.join());
            }else{
               return  SortForkJoin.sort(src);
            }
        }
    }

    @Test
    public void test() {
        int[]  array = MakeArray.createIntArray();
        ForkJoinPool forkJoinPool= new ForkJoinPool();
        SortTask sortTask =new SortTask(array,0,array.length-1,1000);
        long start = System.currentTimeMillis();
        forkJoinPool.execute(sortTask);
        System.out.println(
                " spend time:"+(System.currentTimeMillis()-start)+"ms");
    }

}


上一篇下一篇

猜你喜欢

热点阅读