java技术专栏

ForkJoinPool源码解析

2019-03-30  本文已影响0人  liujianhuiouc

初识ForkJoinPool

java中的Executor相比大家都很熟悉,它是一种执行器。日常工作中比较容易见到的就是ThreadPoolExecutor,提供了线程池模型,程序开发者只需要通过相关接口就可以开发任务的执行和调度。此处我们不在详细介绍ThreadPoolExecutor的使用方式,感兴趣的可以查看相关的资料。既然已经有了这么好用的执行器,为何又要引用ForkJoinPool,ForkJoinPool从名字来看就是将任务Fork(拆分)为小任务,再通过Join(合并)来汇总小任务的结果,总体思想和MapReduce很像,是一个单机版的MapReduce实现方式。

ForkJoin例子
图片截取原文地址

ForkJoinPool的例子

ThreadPoolExecutor可以执行Runnable或者是Callable的任务,ForkJoinPool接收ForkJoinTask子类的任务,通常我们不需要直接实现ForkJoinTask这个抽象类,JDK为我们提供了RecursiveActionRecursiveTask两个类,分别对应没有返回值和有返回值的场景。接下来我们通过一个例子来说明下如何使用ForkJoinPool。

给定一段区间,计算出这段区间内所有数的累加和?

public class CountTask extends RecursiveTask<Integer> {
    private int start;

    private int end;

    private int threshold;

    public CountTask(int start, int end, int threshold) {
        this.start = start;
        this.end = end;
        this.threshold = threshold;
    }

    @Override
    protected Integer compute() {
        System.out.println(Thread.currentThread().getName());
        if (end - start <= threshold) {

            int sum = 0;
            for (int i = start; i < end; ++i){
                sum += i;
            }
            return sum;
        } else {
            int median = (start + end) / 2;
            CountTask left = new CountTask(start, median, threshold);
            CountTask right = new CountTask(median, end, threshold);

            left.fork();
            right.fork();

            int leftValue = left.join();
            int rightValue = right.join();

            return leftValue + rightValue;
        }
    }
}

CountTask将一段区间细分为两个子区间分别计算,然后在累加这两段区间分别的求和。我们通过ForkJoinPool来执行这段代码。

    private static void testRecurisiveTask() throws InterruptedException, java.util.concurrent.ExecutionException {
        CountTask countTask = new CountTask(0, 1000000, 100);
        ForkJoinPool forkJoinPool = new ForkJoinPool(1);

        forkJoinPool.submit(countTask);
        System.out.println(countTask.get());
    }

代码很简单,相信各位读者立马就知道怎么用ForkJoinPool,笔者在参考写完这段代码后至少产生了以下几个疑问?

  1. 能否通过ThreadPoolExecutor来实现?
  2. 如果只有一个线程的情况下,该线程已经在执行CountTask这个任务,又是通过什么方式来执行子任务的呢?
  3. fork和join方法具体是在完成了哪些操作?
    大家肯定也有自己的问题,记录下来,看下面的内容能否回答大家的问题,不能回答的话,大家可以留言给我。

ForkJoinPool深度剖析

我们先来看一下ForkJoinPool这个框架涉及的几个重要部分。
WorkeQueue: Task队列。
ForkJoinThread: 处理线程
ForkJoinPool: 包括处理线程和线程队列。

submit()方法详解

    public <T> ForkJoinTask<T> submit(ForkJoinTask<T> task) {
        if (task == null)
            throw new NullPointerException();
        externalPush(task);
        return task;
    }

// 1. 如果workqueues未初始化,初始化WorkQueue[]数组
// 2. 根据当前线程的探针值获取在workequeue[]数组中的位置,如果对应的WorkQueue是空,则new一个新的WorkQueue,然后将task放置在WorkQueue中ForkJoinTask[]数组中
// 3. signalWork开启启动相关线程开启任务的处理
Tips: Doug Lea大神大部分都是通过Unsafe的相关方法来进行赋值操作,这种方式直接操作内存,效率很高,但是很容易出错,不建议大家平时使用

 private void externalSubmit(ForkJoinTask<?> task) {
        int r;                                    // initialize caller's probe
        if ((r = ThreadLocalRandom.getProbe()) == 0) {
            ThreadLocalRandom.localInit();
            r = ThreadLocalRandom.getProbe();
        }
        for (;;) {
            WorkQueue[] ws; WorkQueue q; int rs, m, k;
            boolean move = false;
            if ((rs = runState) < 0) {
                tryTerminate(false, false);     // help terminate
                throw new RejectedExecutionException();
            }
            else if ((rs & STARTED) == 0 ||     // initialize
                     ((ws = workQueues) == null || (m = ws.length - 1) < 0)) {
                int ns = 0;
                rs = lockRunState();
                try {
                    if ((rs & STARTED) == 0) {
                        U.compareAndSwapObject(this, STEALCOUNTER, null,
                                               new AtomicLong());
                        // create workQueues array with size a power of two
                        int p = config & SMASK; // ensure at least 2 slots
                        int n = (p > 1) ? p - 1 : 1;
                        n |= n >>> 1; n |= n >>> 2;  n |= n >>> 4;
                        n |= n >>> 8; n |= n >>> 16; n = (n + 1) << 1;
                        workQueues = new WorkQueue[n];
                        ns = STARTED;
                    }
                } finally {
                    unlockRunState(rs, (rs & ~RSLOCK) | ns);
                }
            }
            else if ((q = ws[k = r & m & SQMASK]) != null) {
                if (q.qlock == 0 && U.compareAndSwapInt(q, QLOCK, 0, 1)) {
                    ForkJoinTask<?>[] a = q.array;
                    int s = q.top;
                    boolean submitted = false; // initial submission or resizing
                    try {                      // locked version of push
                        if ((a != null && a.length > s + 1 - q.base) ||
                            (a = q.growArray()) != null) {
                            int j = (((a.length - 1) & s) << ASHIFT) + ABASE;
                            U.putOrderedObject(a, j, task);
                            U.putOrderedInt(q, QTOP, s + 1);
                            submitted = true;
                        }
                    } finally {
                        U.compareAndSwapInt(q, QLOCK, 1, 0);
                    }
                    if (submitted) {
                        signalWork(ws, q);
                        return;
                    }
                }
                move = true;                   // move on failure
            }
            else if (((rs = runState) & RSLOCK) == 0) { // create new queue
                q = new WorkQueue(this, null);
                q.hint = r;
                q.config = k | SHARED_QUEUE;
                q.scanState = INACTIVE;
                rs = lockRunState();           // publish index
                if (rs > 0 &&  (ws = workQueues) != null &&
                    k < ws.length && ws[k] == null)
                    ws[k] = q;                 // else terminated
                unlockRunState(rs, rs & ~RSLOCK);
            }
            else
                move = true;                   // move if busy
            if (move)
                r = ThreadLocalRandom.advanceProbe(r);
        }
    }

ForkJoinThread获取任务进行处理

    /**
     * This method is required to be public, but should never be
     * called explicitly. It performs the main run loop to execute
     * {@link ForkJoinTask}s.
     */
    public void run() {
        if (workQueue.array == null) { // only run once
            Throwable exception = null;
            try {
                onStart();
                pool.runWorker(workQueue);
            } catch (Throwable ex) {
                exception = ex;
            } finally {
                try {
                    onTermination(exception);
                } catch (Throwable ex) {
                    if (exception == null)
                        exception = ex;
                } finally {
                    pool.deregisterWorker(this, exception);
                }
            }
        }
    }

    final void runWorker(WorkQueue w) {
        w.growArray();                   // allocate queue
        int seed = w.hint;               // initially holds randomization hint
        int r = (seed == 0) ? 1 : seed;  // avoid 0 for xorShift
        for (ForkJoinTask<?> t;;) {
            if ((t = scan(w, r)) != null)
                w.runTask(t);
            else if (!awaitWork(w, r))
                break;
            r ^= r << 13; r ^= r >>> 17; r ^= r << 5; // xorshift
        }
    }

// 1. 从workqueue[]数组中获取一个ForkJoinTask,获取的规则是根据当前线程的随机数和数组的长度通过与操作来获取下标(就是一种hash方式)
// 2. 找到一个有ForkJoinTask的WorkQueue来获取相关的task,可以通过WorkQueue中的base和top两个变量来判断任务的个数
// 3. 如果遍历一圈过后依然没有可用的task,直接返回null,遍历一圈的逻辑是通过oldsum和checksum来判断的,判断过一次,则新增一次sum,如果遍历的过程中有新增的task,则会再判断一圈

private ForkJoinTask<?> scan(WorkQueue w, int r) {
        WorkQueue[] ws; int m;
        if ((ws = workQueues) != null && (m = ws.length - 1) > 0 && w != null) {
            int ss = w.scanState;                     // initially non-negative
            for (int origin = r & m, k = origin, oldSum = 0, checkSum = 0;;) {
                WorkQueue q; ForkJoinTask<?>[] a; ForkJoinTask<?> t;
                int b, n; long c;
                if ((q = ws[k]) != null) {
                    if ((n = (b = q.base) - q.top) < 0 &&
                        (a = q.array) != null) {      // non-empty
                        long i = (((a.length - 1) & b) << ASHIFT) + ABASE;
                        if ((t = ((ForkJoinTask<?>)
                                  U.getObjectVolatile(a, i))) != null &&
                            q.base == b) {
                            if (ss >= 0) {
                                if (U.compareAndSwapObject(a, i, t, null)) {
                                    q.base = b + 1;
                                    if (n < -1)       // signal others
                                        signalWork(ws, q);
                                    return t;
                                }
                            }
                            else if (oldSum == 0 &&   // try to activate
                                     w.scanState < 0)
                                tryRelease(c = ctl, ws[m & (int)c], AC_UNIT);
                        }
                        if (ss < 0)                   // refresh
                            ss = w.scanState;
                        r ^= r << 1; r ^= r >>> 3; r ^= r << 10;
                        origin = k = r & m;           // move and rescan
                        oldSum = checkSum = 0;
                        continue;
                    }
                    checkSum += b;
                }
                if ((k = (k + 1) & m) == origin) {    // continue until stable
                    if ((ss >= 0 || (ss == (ss = w.scanState))) &&
                        oldSum == (oldSum = checkSum)) {
                        if (ss < 0 || w.qlock < 0)    // already inactive
                            break;
                        int ns = ss | INACTIVE;       // try to inactivate
                        long nc = ((SP_MASK & ns) |
                                   (UC_MASK & ((c = ctl) - AC_UNIT)));
                        w.stackPred = (int)c;         // hold prev stack top
                        U.putInt(w, QSCANSTATE, ns);
                        if (U.compareAndSwapLong(this, CTL, c, nc))
                            ss = ns;
                        else
                            w.scanState = ss;         // back out
                    }
                    checkSum = 0;
                }
            }
        }
        return null;

fork方法详解

    public final ForkJoinTask<V> fork() {
        Thread t;
        if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
            ((ForkJoinWorkerThread)t).workQueue.push(this);
        else
            ForkJoinPool.common.externalPush(this);
        return this;
    }

fork方法比较简单,如果当前的线程是ForkJoinWokerThread,则直接放在这个线程对应的workqueue中,否则跟submit的方式一样,为什么要放在自己workqueue中,在后面的join中再详解

join方法详解

    public final V join() {
        int s;
        if ((s = doJoin() & DONE_MASK) != NORMAL)
            reportException(s);
        return getRawResult();
    }

    private int doJoin() {
        int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;
        return (s = status) < 0 ? s :
            ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
            (w = (wt = (ForkJoinWorkerThread)t).workQueue).
            tryUnpush(this) && (s = doExec()) < 0 ? s :
            wt.pool.awaitJoin(w, this, 0L) :
            externalAwaitDone();
    }

final int awaitJoin(WorkQueue w, ForkJoinTask<?> task, long deadline) {
        int s = 0;
        if (task != null && w != null) {
            ForkJoinTask<?> prevJoin = w.currentJoin;
            U.putOrderedObject(w, QCURRENTJOIN, task);
            CountedCompleter<?> cc = (task instanceof CountedCompleter) ?
                (CountedCompleter<?>)task : null;
            for (;;) {
                if ((s = task.status) < 0)
                    break;
                if (cc != null)
                    helpComplete(w, cc, 0);
                //此处是join的核心
                else if (w.base == w.top || w.tryRemoveAndExec(task))
                    helpStealer(w, task);
                if ((s = task.status) < 0)
                    break;
                long ms, ns;
                if (deadline == 0L)
                    ms = 0L;
                else if ((ns = deadline - System.nanoTime()) <= 0L)
                    break;
                else if ((ms = TimeUnit.NANOSECONDS.toMillis(ns)) <= 0L)
                    ms = 1L;
                if (tryCompensate(w)) {
                    task.internalWait(ms);
                    U.getAndAddLong(this, CTL, AC_UNIT);
                }
            }
            U.putOrderedObject(w, QCURRENTJOIN, prevJoin);
        }
        return s;
    }

join会发生什么呢,它会去遍历workqueue中的ForkJoinTask数组的列表,找到在队列中之前通过fork方法放入的ForkJoinTask,用EmptyTask替换它,然后在当前线程中直接通过doExec方法直接执行,我想此处是work-steal的精髓所在

final boolean tryRemoveAndExec(ForkJoinTask<?> task) {
            ForkJoinTask<?>[] a; int m, s, b, n;
            if ((a = array) != null && (m = a.length - 1) >= 0 &&
                task != null) {
                while ((n = (s = top) - (b = base)) > 0) {
                    for (ForkJoinTask<?> t;;) {      // traverse from s to b
                        long j = ((--s & m) << ASHIFT) + ABASE;
                        if ((t = (ForkJoinTask<?>)U.getObject(a, j)) == null)
                            return s + 1 == top;     // shorter than expected
                        else if (t == task) {
                            boolean removed = false;
                            if (s + 1 == top) {      // pop
                                if (U.compareAndSwapObject(a, j, task, null)) {
                                    U.putOrderedInt(this, QTOP, s);
                                    removed = true;
                                }
                            }
                            else if (base == b)      // replace with proxy
                                removed = U.compareAndSwapObject(
                                    a, j, task, new EmptyTask());
                            if (removed)
                              //直接执行
                                task.doExec();
                            break;
                        }
                        else if (t.status < 0 && s + 1 == top) {
                            if (U.compareAndSwapObject(a, j, t, null))
                                U.putOrderedInt(this, QTOP, s);
                            break;                  // was cancelled
                        }
                        if (--n == 0)
                            return false;
                    }
                    if (task.status < 0)
                        return false;
                }
            }
            return true;
        }

我想join的方法应该是work-stealing的关键,当前线程会窃取执行调用join的的task,使得当前线程从一个task切换到了另外一个task的执行。再回到我们之前提的几个疑问。

  1. 利用forkjoinpool处理的任务不太方便通过ThreadPoolExecutor来实现,原因就在于如果我们在一个任务中通过submit提交两个子任务,然后通过future.get的方式来获取结果的方式不会释放当前线程,get方法会阻塞整个线程的执行,使得这个线程无法执行其他的任务。

其他两个问题我想上文中已经做了足够多的回答了。

附录

ForkJoin论文翻译版

喜欢请关注,欢迎大家转载!

上一篇下一篇

猜你喜欢

热点阅读