ForkJoin源码解析
前言
本文通过Forkjoin实现数据累加的demo来进行源码分析,并且基于jdk8环境,因此与jdk7的情况会略有不同。其具体代码实现如下。
任务类
public class ForkJoinSumCalculator extends RecursiveTask<Long> {
private final long[] numbers;
private final int start;
private final int end;
public static final long THRESHOLD = 10000;
public ForkJoinSumCalculator(long[] numbers) {
this(numbers, 0, numbers.length);
}
private ForkJoinSumCalculator(long[] numbers, int start, int end) {
this.numbers = numbers;
this.start = start;
this.end = end;
}
@Override
protected Long compute() {
int length = end - start;
if (length <= THRESHOLD) {
return computeSequentially();
}
ForkJoinSumCalculator leftTask = new ForkJoinSumCalculator(numbers, start, start + length/2);
leftTask.fork();
ForkJoinSumCalculator rightTask = new ForkJoinSumCalculator(numbers, start + length/2, end);
Long rightResult = rightTask.compute();
Long leftResult = leftTask.join();
return leftResult + rightResult;
}
private long computeSequentially() {
long sum = 0;
for (int i = start; i < end; i++) {
sum += numbers[i];
}
return sum;
}
}
定义了ForkJoinSumCalculator
来实现任务分解和子任务的累加计算。
测试类
public class ForkJoinTest {
public static void main(String[] args) {
long[] numbers = LongStream.rangeClosed(1, 1000000).toArray();
ForkJoinTask<Long> task = new ForkJoinSumCalculator(numbers);//1
long result = new ForkJoinPool().invoke(task);//2
System.out.println("result:"+result);
}
}
通过测试类ForkJoinTest
启动了ForkJoinPool
并计算得到结果,从这里的main
方法可以看出实现主要依赖1和2两行,1中ForkJoinSumCalculator
类的初始化先不做过多说明,从2开始进入分析。
源码解析
首先来看一下new ForkJoinPool()
这个线程池初始化操作到底做了什么,源码如下
public ForkJoinPool() {
this(Math.min(MAX_CAP, Runtime.getRuntime().availableProcessors()),
defaultForkJoinWorkerThreadFactory, null, false);
}
public ForkJoinPool(int parallelism,
ForkJoinWorkerThreadFactory factory,
UncaughtExceptionHandler handler,
boolean asyncMode) {
this(checkParallelism(parallelism),
checkFactory(factory),
handler,
asyncMode ? FIFO_QUEUE : LIFO_QUEUE,
"ForkJoinPool-" + nextPoolId() + "-worker-");
checkPermission();
}
private ForkJoinPool(int parallelism,
ForkJoinWorkerThreadFactory factory,
UncaughtExceptionHandler handler,
int mode,
String workerNamePrefix) {
this.workerNamePrefix = workerNamePrefix;
this.factory = factory;
this.ueh = handler;
this.config = (parallelism & SMASK) | mode;
long np = (long)(-parallelism); // offset ctl counts
this.ctl = ((np << AC_SHIFT) & AC_MASK) | ((np << TC_SHIFT) & TC_MASK);
}
这里就是三个ForkJoinPool
的连续调用,最后的作用仅是给workerNamePrefix
、factory
、ueh
、config
和ctl
几个属性赋值。顺带提一下ForkJoinPool
上包含注解sun.misc.Contended
,这个注解jdk8中才引入,是java中避免缓存行伪共享的一种方案,能在并发情况下更好提升性能,此处不展开。
接着来看一下checkParallelism
方法
private static int checkParallelism(int parallelism) {
if (parallelism <= 0 || parallelism > MAX_CAP)
throw new IllegalArgumentException();
return parallelism;
}
这里传入的parallelism
是Math.min(MAX_CAP, Runtime.getRuntime().availableProcessors())
,该值取0x7fff
和当前核心数的最小值,结合checkParallelism
方法可以看出,parallelism
值一般就是CPU核数了。由于SMASK = = 0xffff
,mode
是LIFO_QUEUE = 0
(从名字可以很明显看出这是个后入先出的队列),因此根据表达式config
的值就是核心数。
然后来看一下ctl
这个值,这是一个64位的long
变量,根据注释说明,ctl
的64位被分成4个16位标识,依次称为AC
、TC
、SS
、ID
。
- AC: 运行中线程数与目标值
checkParallelism
的差值,如果ac是负的说明没有足够的活动线程 - TC: 总线程数与目标值
checkParallelism
的差值,如果tc是负的说明没有足够的总线程 - SS: 版本计数和最顶端等待线程的状态
- ID: 栈中最顶端等待线程的索引
ctl
的低32位称为sp
,当sp
非0时说明有等待线程。
然后需要注意的是factory
属性传入的值为defaultForkJoinWorkerThreadFactory
,该值的初始化在ForkJoinPool
类的静态代码块中,源码如下
static {
// initialize field offsets for CAS etc
try {
U = sun.misc.Unsafe.getUnsafe();
Class<?> k = ForkJoinPool.class;
CTL = U.objectFieldOffset
(k.getDeclaredField("ctl"));
RUNSTATE = U.objectFieldOffset
(k.getDeclaredField("runState"));
STEALCOUNTER = U.objectFieldOffset
(k.getDeclaredField("stealCounter"));
Class<?> tk = Thread.class;
PARKBLOCKER = U.objectFieldOffset
(tk.getDeclaredField("parkBlocker"));
Class<?> wk = WorkQueue.class;
QTOP = U.objectFieldOffset
(wk.getDeclaredField("top"));
QLOCK = U.objectFieldOffset
(wk.getDeclaredField("qlock"));
QSCANSTATE = U.objectFieldOffset
(wk.getDeclaredField("scanState"));
QPARKER = U.objectFieldOffset
(wk.getDeclaredField("parker"));
QCURRENTSTEAL = U.objectFieldOffset
(wk.getDeclaredField("currentSteal"));
QCURRENTJOIN = U.objectFieldOffset
(wk.getDeclaredField("currentJoin"));
Class<?> ak = ForkJoinTask[].class;
ABASE = U.arrayBaseOffset(ak);
int scale = U.arrayIndexScale(ak);
if ((scale & (scale - 1)) != 0) //判断scale是否为2的幂次方
throw new Error("data type scale not a power of two");
ASHIFT = 31 - Integer.numberOfLeadingZeros(scale);
} catch (Exception e) {
throw new Error(e);
}
commonMaxSpares = DEFAULT_COMMON_MAX_SPARES;
defaultForkJoinWorkerThreadFactory =
new DefaultForkJoinWorkerThreadFactory();
modifyThreadPermission = new RuntimePermission("modifyThread");
common = java.security.AccessController.doPrivileged
(new java.security.PrivilegedAction<ForkJoinPool>() {
public ForkJoinPool run() { return makeCommonPool(); }});
int par = common.config & SMASK; // report 1 even if threads disabled
commonParallelism = par > 0 ? par : 1;
}
通过defaultForkJoinWorkerThreadFactory = new DefaultForkJoinWorkerThreadFactory();
对该常量进行了初始化,DefaultForkJoinWorkerThreadFactory
是ForkJoinPool
的静态内部类,其具体实现为
static final class DefaultForkJoinWorkerThreadFactory
implements ForkJoinWorkerThreadFactory {
public final ForkJoinWorkerThread newThread(ForkJoinPool pool) {
return new ForkJoinWorkerThread(pool);
}
}
到这里ForkJoinPool
的初始化就算完成了,接着回到main
方法来看一下invoke(task)
方法的实现
public <T> T invoke(ForkJoinTask<T> task) {
if (task == null)
throw new NullPointerException();
externalPush(task);
return task.join();
}
这里调用了externalPush(task)
方法,接着来看一下
final void externalPush(ForkJoinTask<?> task) {
WorkQueue[] ws; WorkQueue q; int m;
int r = ThreadLocalRandom.getProbe();
int rs = runState;
if ((ws = workQueues) != null && (m = (ws.length - 1)) >= 0 &&
(q = ws[m & r & SQMASK]) != null && r != 0 && rs > 0 &&
U.compareAndSwapInt(q, QLOCK, 0, 1)) {
ForkJoinTask<?>[] a; int am, n, s;
if ((a = q.array) != null &&
(am = a.length - 1) > (n = (s = q.top) - q.base)) {
int j = ((am & s) << ASHIFT) + ABASE;
U.putOrderedObject(a, j, task); //task加入q的任务队列中
U.putOrderedInt(q, QTOP, s + 1); //修改top的位置
U.putIntVolatile(q, QLOCK, 0);
if (n <= 1)
signalWork(ws, q);
return;
}
U.compareAndSwapInt(q, QLOCK, 1, 0);
}
externalSubmit(task);
}
首先看到ThreadLocalRandom.getProbe()
可以生成一个随机数,ThreadLocalRandom
类解决了Random
种子竞争的问题,在并发情况下性能更好,这里不做过多分析。
runState
标识pool的运行状态,具体表示如下
// runState bits: SHUTDOWN must be negative, others arbitrary powers of two
private static final int RSLOCK = 1;
private static final int RSIGNAL = 1 << 1;
private static final int STARTED = 1 << 2;
private static final int STOP = 1 << 29;
private static final int TERMINATED = 1 << 30;
private static final int SHUTDOWN = 1 << 31;
看第一个if,需要同时满足5个条件才进入分支,来看一下
- (ws = workQueues) != null //workQueues数组非空
- (m = (ws.length - 1)) >= 0 //workQueues中至少有一个
WorkQueue
对象,并赋值m - (q = ws[m & r & SQMASK]) != null //
m & r & SQMASK
保证随机数为偶数且不大于m,这么做是由于这里有一个隐含的约定,只有线程为空的WorkQueue
对象才能出现在ws
的偶数位 - r != 0 //随机数非0
- rs > 0 //
runState
非0表示线程池没有被关闭 - U.compareAndSwapInt(q, QLOCK, 0, 1) //能够成功将对象q的
qlock
属性从0置为1,这里的qlock
=1说明被锁定, < 0表示终止,所以这里显然是一个加锁操作
上述条件不能全部满足则会跳出if执行externalSubmit(task)
方法,否则就接着进入下一个if语句,又需要满足两个条件
- (a = q.array) != null //q的队列不为空,这里的array类型为
ForkJoinTask<?>[]
- (am = a.length - 1) > (n = (s = q.top) - q.base) //top是当前线程即将处理的队列偏移量,base是可以被其他线程“窃取”的队列偏移量,base是被volatile修饰的,所以这个值显然是会存在并发情况的
可以看到当条件不满足时会通过U.compareAndSwapInt(q, QLOCK, 1, 0)
直接释放锁。
而当满足上述条件时,开始执行int j = ((am & s) << ASHIFT) + ABASE
,这其中ASHIFT
是数组array
中每个元素所占字节长度的二进制位数(去除高位所有0后的位数),ABASE
是第一个元素地址相对于数组起始地址的偏移值,根据计算出的偏移量j
将task
放入array
中,利用QTOP
偏移量将top
值进行+1操作,置qlock
为0以释放锁,并执行以下代码块:
if (n <= 1)
signalWork(ws, q);
return;
这里跟进看一下signalWork(ws, q)
方法
final void signalWork(WorkQueue[] ws, WorkQueue q) {
long c; int sp, i; WorkQueue v; Thread p;
while ((c = ctl) < 0L) { // active线程过少
if ((sp = (int)c) == 0) { // 没有空闲线程
if ((c & ADD_WORKER) != 0L) // 工作线程太少
tryAddWorker(c);
break;
}
if (ws == null) // unstarted/terminated
break;
if (ws.length <= (i = sp & SMASK)) // 已终止
break;
if ((v = ws[i]) == null) // 正在终止
break;
int vs = (sp + SS_SEQ) & ~INACTIVE; // next scanState
int d = sp - v.scanState; // screen CAS
long nc = (UC_MASK & (c + AC_UNIT)) | (SP_MASK & v.stackPred);
if (d == 0 && U.compareAndSwapLong(this, CTL, c, nc)) {
v.scanState = vs;
if ((p = v.parker) != null) // 唤醒v的owner
U.unpark(p);
break;
}
if (q != null && q.base == q.top) // no more work
break;
}
}
(c = ctl) < 0L
判断active线程过少时,会执行while循环,当满足工作线程太少的判断条件时,会执行tryAddWorker(c)
方法增加工作线程,来看看具体代码
private void tryAddWorker(long c) {
boolean add = false;
do {
long nc = ((AC_MASK & (c + AC_UNIT)) |
(TC_MASK & (c + TC_UNIT)));//AC、TC分别进行加1操作,表示增加了worker线程
if (ctl == c) {
int rs, stop; // check if terminating
if ((stop = (rs = lockRunState()) & STOP) == 0)
add = U.compareAndSwapLong(this, CTL, c, nc);
unlockRunState(rs, rs & ~RSLOCK);
if (stop != 0)
break;
if (add) {
createWorker();
break;
}
}
} while (((c = ctl) & ADD_WORKER) != 0L && (int)c == 0);
}
这里用了do-while循环尝试创建worker线程,当CAS地修改ctl
成功时才会执行createWorker()
方法并推出,createWorker()
方法实现如下
private boolean createWorker() {
ForkJoinWorkerThreadFactory fac = factory;
Throwable ex = null;
ForkJoinWorkerThread wt = null;
try {
if (fac != null && (wt = fac.newThread(this)) != null) {
wt.start();
return true;
}
} catch (Throwable rex) {
ex = rex;
}
deregisterWorker(wt, ex);
return false;
}
根据之前的静态代码块可以知道,这里传入的factory
是一个DefaultForkJoinWorkerThreadFactory
类型对象,
static final class DefaultForkJoinWorkerThreadFactory
implements ForkJoinWorkerThreadFactory {
public final ForkJoinWorkerThread newThread(ForkJoinPool pool) {
return new ForkJoinWorkerThread(pool);
}
}
protected ForkJoinWorkerThread(ForkJoinPool pool) {
// Use a placeholder until a useful name can be set in registerWorker
super("aForkJoinWorkerThread");
this.pool = pool;
this.workQueue = pool.registerWorker(this);//将当前ForkJoinWorkerThread线程注册到ForkJoinPool中
}
由此可知这里的createWorker()
方法会创建一个ForkJoinWorkerThread
线程并启动它。pool.registerWorker(this)
会将当前线程注册到pool中,这也就意味着当前线程会成为这个workQueue
的owner
,这里就要说到worker steal算法,大意就是一个线程从自己任务队列的头部取出任务执行,而其他空闲线程可以从其队列的尾部“偷”任务执行,以充分利用空闲的线程资源。这里当线程成为owner
之后,才可以从top
位置取任务,因此WorkQueue
中的top
是非volatile
类型,base
却是volatile
的。
由于在createWorker()
中,创建的线程被启动了,那么我们有必要来看看ForkJoinWorkerThread
的run
方法里都做了些什么。
public void run() {
if (workQueue.array == null) { // only run once
Throwable exception = null;
try {
onStart(); //没有任何操作
pool.runWorker(workQueue);
} catch (Throwable ex) {
exception = ex;
} finally {
try {
onTermination(exception);
} catch (Throwable ex) {
if (exception == null)
exception = ex;
} finally {
pool.deregisterWorker(this, exception);
}
}
}
}
可以看到业务都交由了pool.runWorker(workQueue)
运行,源码如下
final void runWorker(WorkQueue w) {
w.growArray(); // allocate queue
int seed = w.hint; // initially holds randomization hint
int r = (seed == 0) ? 1 : seed; // avoid 0 for xorShift
for (ForkJoinTask<?> t;;) {
if ((t = scan(w, r)) != null)
w.runTask(t);
else if (!awaitWork(w, r))
break;
r ^= r << 13; r ^= r >>> 17; r ^= r << 5; // xorshift
}
}
这里用了一个死循环来执行task,具体涉及scan
、runTask
、awaitWork
几个方法,逐一来看一下。
首先是scan
:
private ForkJoinTask<?> scan(WorkQueue w, int r) {
WorkQueue[] ws; int m;
if ((ws = workQueues) != null && (m = ws.length - 1) > 0 && w != null) {
int ss = w.scanState; // initially non-negative
for (int origin = r & m, k = origin, oldSum = 0, checkSum = 0;;) {
WorkQueue q; ForkJoinTask<?>[] a; ForkJoinTask<?> t;
int b, n; long c;
if ((q = ws[k]) != null) {
if ((n = (b = q.base) - q.top) < 0 &&
(a = q.array) != null) { // non-empty
long i = (((a.length - 1) & b) << ASHIFT) + ABASE; //base对应的地址
if ((t = ((ForkJoinTask<?>)
U.getObjectVolatile(a, i))) != null &&
q.base == b) {
if (ss >= 0) {
if (U.compareAndSwapObject(a, i, t, null)) {
q.base = b + 1;
if (n < -1) // signal others
signalWork(ws, q);
return t;
}
}
else if (oldSum == 0 && // try to activate
w.scanState < 0)
tryRelease(c = ctl, ws[m & (int)c], AC_UNIT);
}
if (ss < 0) // refresh
ss = w.scanState;
r ^= r << 1; r ^= r >>> 3; r ^= r << 10;
origin = k = r & m; // move and rescan
oldSum = checkSum = 0;
continue;
}
checkSum += b;
}
if ((k = (k + 1) & m) == origin) { // 直到遍历完所有队列才停止
if ((ss >= 0 || (ss == (ss = w.scanState))) &&
oldSum == (oldSum = checkSum)) {
if (ss < 0 || w.qlock < 0) // already inactive
break;
int ns = ss | INACTIVE; // try to inactivate
long nc = ((SP_MASK & ns) |
(UC_MASK & ((c = ctl) - AC_UNIT)));
w.stackPred = (int)c; // 记录前一个top值
U.putInt(w, QSCANSTATE, ns);
if (U.compareAndSwapLong(this, CTL, c, nc))
ss = ns;
else
w.scanState = ss; // back out
}
checkSum = 0;
}
}
}
return null;
}
这个方法主要做的一件事就是遍历workQueue
,并窃取一个尾部任务,窃取到则立即返回,并执行w.runTask(t)
,那么接着来看一下runTask
方法
final void runTask(ForkJoinTask<?> task) {
if (task != null) {
scanState &= ~SCANNING; // mark as busy
(currentSteal = task).doExec();
U.putOrderedObject(this, QCURRENTSTEAL, null); // release for GC
execLocalTasks();
ForkJoinWorkerThread thread = owner;
if (++nsteals < 0) // collect on overflow
transferStealCount(pool);
scanState |= SCANNING;
if (thread != null)
thread.afterTopLevelExec();
}
}
可以看到task
被提交给(currentSteal = task).doExec()
进行处理
final int doExec() {
int s; boolean completed;
if ((s = status) >= 0) {
try {
completed = exec();
} catch (Throwable rex) {
return setExceptionalCompletion(rex);
}
if (completed)
s = setCompletion(NORMAL);
}
return s;
}
之后又被交由ForkJoinTask<V>
的子类RecursiveTask<V>
实现的exec()
方法进行处理
protected final boolean exec() {
result = compute();
return true;
}
到这里应该就很清楚了,这个compute()
方法就是我们自定义的任务类ForkJoinSumCalculator
中实现的方法。也就是说一旦窃取到任务就直接执行了,那么execLocalTasks()
方法又是在做什么呢,来看一下
final void execLocalTasks() {
int b = base, m, s;
ForkJoinTask<?>[] a = array;
if (b - (s = top - 1) <= 0 && a != null &&
(m = a.length - 1) >= 0) {
if ((config & FIFO_QUEUE) == 0) {
for (ForkJoinTask<?> t;;) {
if ((t = (ForkJoinTask<?>)U.getAndSetObject
(a, ((m & s) << ASHIFT) + ABASE, null)) == null) //从top位置取出任务
break;
U.putOrderedInt(this, QTOP, s);
t.doExec();
if (base - (s = top - 1) > 0)
break;
}
}
else
pollAndExecAll();
}
}
LIFO
模式会执行pollAndExecAll()
,否则执行另一个分支。两个分支做的事情其实一样,都是循环执行array
中的任务,不同的是一个从top
取,一个从base
取。
最后来看一下awaitWork
方法
private boolean awaitWork(WorkQueue w, int r) {
if (w == null || w.qlock < 0) // w is terminating
return false;
for (int pred = w.stackPred, spins = SPINS, ss;;) {
if ((ss = w.scanState) >= 0)
break;
else if (spins > 0) {
r ^= r << 6; r ^= r >>> 21; r ^= r << 7;
if (r >= 0 && --spins == 0) { // 进行随机自旋
WorkQueue v; WorkQueue[] ws; int s, j; AtomicLong sc;
if (pred != 0 && (ws = workQueues) != null &&
(j = pred & SMASK) < ws.length &&
(v = ws[j]) != null && // see if pred parking
(v.parker == null || v.scanState >= 0))
spins = SPINS; // continue spinning
}
}
else if (w.qlock < 0) // recheck after spins
return false;
else if (!Thread.interrupted()) {
long c, prevctl, parkTime, deadline;
int ac = (int)((c = ctl) >> AC_SHIFT) + (config & SMASK);
if ((ac <= 0 && tryTerminate(false, false)) ||
(runState & STOP) != 0) // pool terminating
return false;
if (ac <= 0 && ss == (int)c) { // is last waiter
prevctl = (UC_MASK & (c + AC_UNIT)) | (SP_MASK & pred);
int t = (short)(c >>> TC_SHIFT); // 收缩过剩的线程
if (t > 2 && U.compareAndSwapLong(this, CTL, c, prevctl))
return false; // else use timed wait
parkTime = IDLE_TIMEOUT * ((t >= 0) ? 1 : 1 - t);
deadline = System.nanoTime() + parkTime - TIMEOUT_SLOP;
}
else
prevctl = parkTime = deadline = 0L;
Thread wt = Thread.currentThread();
U.putObject(wt, PARKBLOCKER, this); // emulate LockSupport
w.parker = wt;
if (w.scanState < 0 && ctl == c) // recheck before park
U.park(false, parkTime);
U.putOrderedObject(w, QPARKER, null);
U.putObject(wt, PARKBLOCKER, null);
if (w.scanState >= 0)
break;
if (parkTime != 0L && ctl == c &&
deadline - System.nanoTime() <= 0L &&
U.compareAndSwapLong(this, CTL, c, prevctl))
return false; // shrink pool
}
}
return true;
}
当scan
方法没有窃取到任务时,会进入到这个方法,根据这个方法的返回值判断是继续去执行scan
还是退出当前线程。同时判断当前线程是否是过剩线程,如果是的话将退出当前线程以收缩线程池。
到这里pool.runWorker(workQueue)
做的事基本了解了,也知道了最后执行任务的步骤调用的都是我们自定义的compute()
方法,那么还是很有必要来具体了解一下这个方法的内容。
@Override
protected Long compute() {
int length = end - start;
if (length <= THRESHOLD) {//小于阈值开始进行累加
return computeSequentially();
}
ForkJoinSumCalculator leftTask = new ForkJoinSumCalculator(numbers, start, start + length/2);
leftTask.fork();
ForkJoinSumCalculator rightTask = new ForkJoinSumCalculator(numbers, start + length/2, end);
Long rightResult = rightTask.compute();
Long leftResult = leftTask.join();
return leftResult + rightResult;
}
这个方法虽然是自定义的,但其实必须遵守一个大概的实现模板,模板里必定有fork()
和join()
方法,我们依次来看一下它们做了什么。
public final ForkJoinTask<V> fork() {
Thread t;
if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
((ForkJoinWorkerThread)t).workQueue.push(this);
else
ForkJoinPool.common.externalPush(this);
return this;
}
fork()
做的事情比较简单,当前线程如果是ForkJoinWorkerThread
线程就通过push
方法将当前任务加入队列top
端,否则执行externalPush
方法,这个方法之前已经出现过了,这里就不重复介绍了。
接着来看join()
方法
public final V join() {
int s;
if ((s = doJoin() & DONE_MASK) != NORMAL)
reportException(s);
return getRawResult();
}
这里会根据doJoin()
方法的返回值来判断是否抛出异常,那么来看一下doJoin()
方法。
private int doJoin() {
int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;
return (s = status) < 0 ? s :
((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
(w = (wt = (ForkJoinWorkerThread)t).workQueue).
tryUnpush(this) && (s = doExec()) < 0 ? s :
wt.pool.awaitJoin(w, this, 0L) :
externalAwaitDone();
}
当前线程如果不是ForkJoinWorkerThread
线程,则执行externalAwaitDone()
方法阻塞当前线程,否则执行另一个判断(w = (wt = (ForkJoinWorkerThread)t).workQueue). tryUnpush(this) && (s = doExec()) < 0 ? s : wt.pool.awaitJoin(w, this, 0L)
。
final boolean tryUnpush(ForkJoinTask<?> t) {
ForkJoinTask<?>[] a; int s;
if ((a = array) != null && (s = top) != base &&
U.compareAndSwapObject
(a, (((a.length - 1) & --s) << ASHIFT) + ABASE, t, null)) {
U.putOrderedInt(this, QTOP, s);
return true;
}
return false;
}
tryUnpush
方法判断top
端的任务取出是否成功,并且调用doExec()
执行,成功则返回状态s
,否则执行wt.pool.awaitJoin(w, this, 0L)
。awaitJoin
方法会在指定任务完成或者超时前尝试帮助或阻塞自身,来具体看一下,
final int awaitJoin(WorkQueue w, ForkJoinTask<?> task, long deadline) {
int s = 0;
if (task != null && w != null) {
ForkJoinTask<?> prevJoin = w.currentJoin;
U.putOrderedObject(w, QCURRENTJOIN, task);//记录当前等待的任务
CountedCompleter<?> cc = (task instanceof CountedCompleter) ?
(CountedCompleter<?>)task : null;
for (;;) {
if ((s = task.status) < 0) //任务完成则直接退出
break;
if (cc != null)
helpComplete(w, cc, 0);
else if (w.base == w.top || w.tryRemoveAndExec(task))
helpStealer(w, task);
if ((s = task.status) < 0)
break;
long ms, ns;
if (deadline == 0L)
ms = 0L;
else if ((ns = deadline - System.nanoTime()) <= 0L)
break;
else if ((ms = TimeUnit.NANOSECONDS.toMillis(ns)) <= 0L)
ms = 1L;
if (tryCompensate(w)) {
task.internalWait(ms);
U.getAndAddLong(this, CTL, AC_UNIT);
}
}
U.putOrderedObject(w, QCURRENTJOIN, prevJoin);
}
return s;
}
这里比较重要的是tryRemoveAndExec
、helpStealer
和tryCompensate
几个方法。
首先来看tryRemoveAndExec
方法,
final boolean tryRemoveAndExec(ForkJoinTask<?> task) {
ForkJoinTask<?>[] a; int m, s, b, n;
if ((a = array) != null && (m = a.length - 1) >= 0 &&
task != null) {
while ((n = (s = top) - (b = base)) > 0) {
for (ForkJoinTask<?> t;;) { // 从s遍历到b
long j = ((--s & m) << ASHIFT) + ABASE;
if ((t = (ForkJoinTask<?>)U.getObject(a, j)) == null)
return s + 1 == top; // shorter than expected
else if (t == task) {
boolean removed = false;
if (s + 1 == top) { // pop
if (U.compareAndSwapObject(a, j, task, null)) {
U.putOrderedInt(this, QTOP, s);
removed = true;
}
}
else if (base == b) // replace with proxy
removed = U.compareAndSwapObject(
a, j, task, new EmptyTask());
if (removed)
task.doExec();
break;
}
else if (t.status < 0 && s + 1 == top) {
if (U.compareAndSwapObject(a, j, t, null))
U.putOrderedInt(this, QTOP, s);
break; // was cancelled
}
if (--n == 0)
return false;
}
if (task.status < 0)
return false;
}
}
return true;
}
该方法主要做的是去自己的队列中进行遍历,看看任务是否在top
位置,在的话直接取出执行,若在队列中间,则用new EmptyTask()
替换之,并取出任务执行。方法返回时若任务未执行完,则不进行后续的help动作。
接着来看一下helpStealer
方法,
private void helpStealer(WorkQueue w, ForkJoinTask<?> task) {
WorkQueue[] ws = workQueues;
int oldSum = 0, checkSum, m;
if (ws != null && (m = ws.length - 1) >= 0 && w != null &&
task != null) {
do { // restart point
checkSum = 0; // for stability check
ForkJoinTask<?> subtask;
WorkQueue j = w, v; // v是子任务的stealer
descent: for (subtask = task; subtask.status >= 0; ) {
for (int h = j.hint | 1, k = 0, i; ; k += 2) {
if (k > m) // can't find stealer
break descent;
if ((v = ws[i = (h + k) & m]) != null) {
if (v.currentSteal == subtask) {
j.hint = i;
break;
}
checkSum += v.base;
}
}
for (;;) { // 帮助v执行任务
ForkJoinTask<?>[] a; int b;
checkSum += (b = v.base);
ForkJoinTask<?> next = v.currentJoin;
if (subtask.status < 0 || j.currentJoin != subtask ||
v.currentSteal != subtask) // stale
break descent;
if (b - v.top >= 0 || (a = v.array) == null) {
if ((subtask = next) == null)
break descent;
j = v;
break;
}
int i = (((a.length - 1) & b) << ASHIFT) + ABASE;
ForkJoinTask<?> t = ((ForkJoinTask<?>)
U.getObjectVolatile(a, i));
if (v.base == b) {
if (t == null) // stale
break descent;
if (U.compareAndSwapObject(a, i, t, null)) {
v.base = b + 1;
ForkJoinTask<?> ps = w.currentSteal;
int top = w.top;
do {
U.putOrderedObject(w, QCURRENTSTEAL, t);
t.doExec(); // 清空本地任务
} while (task.status >= 0 &&
w.top != top &&
(t = w.pop()) != null);
U.putOrderedObject(w, QCURRENTSTEAL, ps);
if (w.base != w.top)
return; // 自己的队列不为空了不再进行help操作
}
}
}
}
} while (task.status >= 0 && oldSum != (oldSum = checkSum));
}
}
该方法很长,做的事情主要是找到偷取自己任务的WorkQueue
,去偷取它的任务执行。直到自己的队列不为空了,则不再进行help操作。
最后来看一下tryCompensate
方法
private boolean tryCompensate(WorkQueue w) {
boolean canBlock;
WorkQueue[] ws; long c; int m, pc, sp;
if (w == null || w.qlock < 0 || // caller terminating
(ws = workQueues) == null || (m = ws.length - 1) <= 0 ||
(pc = config & SMASK) == 0) // parallelism disabled
canBlock = false;
else if ((sp = (int)(c = ctl)) != 0) // 释放空闲线程
canBlock = tryRelease(c, ws[sp & m], 0L);
else {
int ac = (int)(c >> AC_SHIFT) + pc;
int tc = (short)(c >> TC_SHIFT) + pc;
int nbusy = 0; // validate saturation
for (int i = 0; i <= m; ++i) { // two passes of odd indices
WorkQueue v;
if ((v = ws[((i << 1) | 1) & m]) != null) {
if ((v.scanState & SCANNING) != 0)
break;
++nbusy;
}
}
if (nbusy != (tc << 1) || ctl != c)
canBlock = false; // unstable or stale
else if (tc >= pc && ac > 1 && w.isEmpty()) {
long nc = ((AC_MASK & (c - AC_UNIT)) |
(~AC_MASK & c)); // uncompensated
canBlock = U.compareAndSwapLong(this, CTL, c, nc);
}
else if (tc >= MAX_CAP ||
(this == common && tc >= pc + commonMaxSpares))
throw new RejectedExecutionException(
"Thread limit exceeded replacing blocked worker");
else { // similar to tryAddWorker
boolean add = false; int rs; // CAS within lock
long nc = ((AC_MASK & c) |
(TC_MASK & (c + TC_UNIT)));
if (((rs = lockRunState()) & STOP) == 0)
add = U.compareAndSwapLong(this, CTL, c, nc);
unlockRunState(rs, rs & ~RSLOCK);
canBlock = add && createWorker(); // throws on exception
}
}
return canBlock;
}
该方法尝试减少活跃线程,也会由于任务阻塞释放或者创建补偿线程。到此整个流程基本完整了。
总结
本文从数据累加的demo开始,将整个执行流程在源码层面进行了一个大概的串联,由于本人能力有限在许多标志位的使用及位运算的细节方面并没有了解的很深入,仔细去推敲的话其实里面还有很多东西可以去挖,看完源码之后不得不感叹Doug Lea大神的厉害。