阿里的面试题带你认识ForkJoinPool
我相信大家都用过线程池,比如ExcutorService,比如ThreadPoolExcutor
今天来讲讲ForkJoinPool,它实现于ExcutorService,但又和我们常用的
ThreadPoolExcutor原理不同
前言
随着在硬件上多核处理器的发展和广泛使用,并发编程成为程序员必须掌握的一门技术,在面试中也经常考查面试者并发相关的知识。
今天,我们就从一道阿里的面试题来开始
题目:如何充分利用多核CPU,计算超大数组中所有整数的和?
解析开始
1.单线程相加?
我们最容易想到就是单线程相加,一个for循环搞定。
2.线程池相加?
如果进一步优化,我们会自然而然地想到使用线程池来分段相加,最后再把每个段的结果相加。
3.其它?
Yes,就是我们今天的主角——ForkJoinPool,但是它要怎么实现呢?似乎没怎么用过哈^^
让我们看看上面是那种方法都如何实现
/**
* 计算1亿个整数的和 */
public class ForkJoinPoolTest01 {
public static void main(String[] args) throws ExecutionException, InterruptedException {
// 构造数据
int length = 100000000;
long[] arr = new long[length];
for (int i = 0; i < length; i++) {
arr[i] = ThreadLocalRandom.current().nextInt(Integer.MAX_VALUE);
}
// 单线程
singleThreadSum(arr);
// ThreadPoolExecutor线程池
multiThreadSum(arr);
// ForkJoinPool线程池
forkJoinSum(arr);
}
private static void singleThreadSum(long[] arr) {
long start = System.currentTimeMillis();
long sum = 0;
for (int i = 0; i < arr.length; i++) {
// 模拟耗时,本文由公从号“彤哥读源码”原创
sum += (arr[i]/5*5/5*5/5*5/5*5/5*5);
}
System.out.println("sum: " + sum);
System.out.println("single thread elapse: " + (System.currentTimeMillis() - start));
}
private static void multiThreadSum(long[] arr) throws ExecutionException, InterruptedException {
long start = System.currentTimeMillis();
int count = 8;
ExecutorService threadPool = Executors.newFixedThreadPool(count);
List<Future<Long>> list = new ArrayList<>();
for (int i = 0; i < count; i++) {
int num = i;
// 分段提交任务
Future<Long> future = threadPool.submit(() -> {
long sum = 0;
for (int j = arr.length / count * num; j < (arr.length / count * (num + 1)); j++) {
try {
// 模拟耗时
sum += (arr[j]/5*5/5*5/5*5/5*5/5*5);
} catch (Exception e) {
e.printStackTrace();
}
}
return sum;
});
list.add(future);
}
// 每个段结果相加
long sum = 0;
for (Future<Long> future : list) {
sum += future.get();
}
System.out.println("sum: " + sum);
System.out.println("multi thread elapse: " + (System.currentTimeMillis() - start));
}
private static void forkJoinSum(long[] arr) throws ExecutionException, InterruptedException {
long start = System.currentTimeMillis();
ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
// 提交任务
ForkJoinTask<Long> forkJoinTask = forkJoinPool.submit(new SumTask(arr, 0, arr.length));
// 获取结果
Long sum = forkJoinTask.get();
forkJoinPool.shutdown();
System.out.println("sum: " + sum);
System.out.println("fork join elapse: " + (System.currentTimeMillis() - start));
}
private static class SumTask extends RecursiveTask<Long> {
private long[] arr;
private int from;
private int to;
public SumTask(long[] arr, int from, int to) {
this.arr = arr;
this.from = from;
this.to = to;
}
@Override
protected Long compute() {
// 小于1000的时候直接相加,可灵活调整
if (to - from <= 1000) {
long sum = 0;
for (int i = from; i < to; i++) {
// 模拟耗时
sum += (arr[i]/5*5/5*5/5*5/5*5/5*5);
}
return sum;
}
// 分成两段任务,本文由公从号“彤哥读源码”原创
int middle = (from + to) / 2;
SumTask left = new SumTask(arr, from, middle);
SumTask right = new SumTask(arr, middle, to);
// 提交左边的任务
left.fork();
// 右边的任务直接利用当前线程计算,节约开销
Long rightResult = right.compute();
// 等待左边计算完毕
Long leftResult = left.join();
// 返回结果
return leftResult + rightResult;
}
}
}
~~Garnett偷偷地告诉你,实际上计算1亿个整数相加,单线程是最快的,我的电脑大概是100ms左右,使用线程池反而会变慢。~~
~~所以,为了演示ForkJoinPool的牛逼之处,我把每个数都/5*5/5*5/5*5/5*5/5*5
了一顿操作,用来模拟计算耗时。~~
来看结果:
sum: 107352457433800662
single thread elapse: 789
sum: 107352457433800662
multi thread elapse: 228
sum: 107352457433800662fork join elapse: 189
可以看到,ForkJoinPool相对普通线程池还是有很大提升的。
什么是ForkJoinPool?
谈到线程池,很多人会想到Executors提供的一些预设的线程池,比如单线程线程池SingleThreadExecutor
,固定大小的线程池FixedThreadPool
,但是很少有人会注意到其中还提供了一种特殊的线程池:WorkStealingPool
,我们点进这个方法,会看到和其他方法不同的是,这种线程池并不是通过ThreadPoolExecutor
来创建的,而是ForkJoinPool
来创建的
public static ExecutorService newWorkStealingPool() {
return new ForkJoinPool
(Runtime.getRuntime().availableProcessors(),
ForkJoinPool.defaultForkJoinWorkerThreadFactory,
null, true);
}
ThreadPoolExecutor应该都很了解了,就是一个基本的存储线程的线程池,需要执行任务的时候就从线程池中拿一个线程来执行。而ForkJoinPool则不仅仅是这么简单,同样也不是ThreadPoolExecutor的代替品,这种线程池是为了实现“分治法”这一思想而创建的,通过把大任务拆分成小任务,然后再把小任务的结果汇总起来就是最终的结果,和MapReduce的思想很类似
最核心的思想可以这样描述:
if(任务很小){
直接计算得到结果
}else{
分拆成N个子任务
调用子任务的fork()进行计算
调用子任务的join()合并计算结果
}
1.fork()
fork()方法类似于线程的Thread.start()方法,但是它不是真的启动一个线程,而是将任务放入到工作队列中。
2.join()
join()方法类似于线程的Thread.join()方法,但是它不是简单地阻塞线程,而是利用工作线程运行其它任务。当一个工作线程中调用了join()方法,它将处理其它任务,直到注意到目标子任务已经完成了。
ForkJoinPool内部原理-工作窃取
work-stealing(工作窃取)算法
ForkJoinPool 的另一个特性是它使用了work-stealing(工作窃取)算法:
线程池内的所有工作线程都尝试找到并执行已经提交的任务,或者是被其他活动任务创建的子任务(如果不存在就阻塞等待)。这种特性使得 ForkJoinPool 在运行多个可以产生子任务的任务,或者是提交的许多小任务时效率更高。尤其是构建异步模型的 ForkJoinPool 时,对不需要合并(join)的事件类型任务也非常适用。
在 ForkJoinPool 中,线程池中每个工作线程(ForkJoinWorkerThread)都对应一个任务队列(WorkQueue),工作线程优先处理来自自身队列的任务(LIFO或FIFO顺序,参数 mode 决定),然后以FIFO的顺序随机窃取其他队列中的任务。
ForkJoinPool中的任务
ForkJoinPool 中的任务分为两种:
一种是本地提交的任务(Submission task,如 execute、submit 提交的任务);
另外一种是 fork 出的子任务(Worker task)。
两种任务都会存放在 WorkQueue 数组中,但是这两种任务并不会混合在同一个队列里,ForkJoinPool 内部使用了一种随机哈希算法(有点类似 ConcurrentHashMap 的桶随机算法)将工作队列与对应的工作线程关联起来,Submission 任务存放在 WorkQueue 数组的偶数索引位置,Worker 任务存放在奇数索引位。
实质上,Submission 与 Worker 一样,只不过它被限制只能执行它们提交的本地任务,在后面的源码解析中,我们统一称之为“Worker”。
任务的分布情况如下图:
ForkJoinPool原理
初始化ForkJoinPool
ForkJoinPool pool = ForkJoinPool.commonPool()
public static ForkJoinPool commonPool() {
// assert common != null : "static init error";
return common;
}
获取ForkJoinPool很简单,直接调用commonPool()。注意,这个方法是jdk1.8才加的,也是推荐的方法,满足大部分场景。
static{
//...
common = java.security.AccessController.doPrivileged
(new java.security.PrivilegedAction<ForkJoinPool>() {
public ForkJoinPool run() { return makeCommonPool(); }});
//...
}
private static ForkJoinPool makeCommonPool() {
//...
return new ForkJoinPool(parallelism, factory, handler, LIFO_QUEUE,"ForkJoinPool.commonPool-worker-");
}
common在static{}里创建,调用的是makeCommonPool(),最终调用ForkJoinPool的构造函数。
private ForkJoinPool(int parallelism,
ForkJoinWorkerThreadFactory factory,
UncaughtExceptionHandler handler,
int mode, String workerNamePrefix) {
this.workerNamePrefix = workerNamePrefix;
this.factory = factory;
this.ueh = handler;
this.config = (parallelism & SMASK) | mode;
long np = (long)(-parallelism); // offset ctl counts
this.ctl = ((np << AC_SHIFT) & AC_MASK) | ((np << TC_SHIFT) & TC_MASK);
}
parallelism默认是cpu核心数,ForkJoinPool里线程数量依据于它,但不表示最大线程数,不要等同于ThreadPoolExecutor里的corePoolSize或者maximumPoolSize。
factory是线程工程,不是新东西了,默认实现是
DefaultForkJoinWorkerThreadFactory。
workerNamePrefix是其中线程名称的前缀,默认使用“ForkJoinPool-*”
config保存不变的参数,包括了parallelism和mode,供后续读取。mode可选FIFO_QUEUE和LIFO_QUEUE,默认是LIFO_QUEUE,具体用哪种,就要看业务。
ctl是ForkJoinPool中最重要的控制字段,将下面信息按16bit为一组封装在一个long中。
AC: 活动的worker数量;
TC: 总共的worker数量;
SS: WorkQueue状态,第一位表示active的还是inactive,其余十五位表示版本号(对付ABA);
ID: 这里保存了一个WorkQueue在WorkQueue[]的下标,和其他worker通过字段stackPred组成一个TreiberStack。后文讲的栈顶,指这里下标所在的WorkQueue。
TreiberStack:这个栈的pull和pop使用了CAS,所以支持并发下的无锁操作。
AC和TC初始化时取的是parallelism负数,后续代码可以直接判断正负,为负代表还没有达到目标数量。另外ctl低32位有个技巧可以直接用sp=(int)ctl取得,为负代表存在空闲worker。
线程池缺不了状态的变化,记录字段是runState,具体介绍在后面的“ForkJoinPool状态修改”。
任务ForkJoinTask
ForkJoinPool执行任务的对象是ForkJoinTask,它是一个抽象类,有两个具体实现类RecursiveAction和RecursiveTask。
public abstract class RecursiveAction extends ForkJoinTask<Void> {
protected abstract void compute();
public final Void getRawResult() { return null; }
protected final void setRawResult(Void mustBeNull) { }
protected final boolean exec() {
compute();
return true;
}
}
public abstract class RecursiveTask<V> extends ForkJoinTask<V> {
V result;
protected abstract V compute();
public final V getRawResult() {
return result;
}
protected final void setRawResult(V value) {
result = value;
}
protected final boolean exec() {
result = compute();
return true;
}
}
ForkJoinTask的抽象方法exec由RecursiveAction和RecursiveTask实现,它被定义为final,具体的执行步骤compute延迟到子类实现。很容易看出RecursiveAction和RecursiveTask的区别,前者没有result,getRawResult返回空,它们对应不需要返回结果和需要返回结果两种场景。
ForkJoinTask里很重要的字段是它的状态status,默认是0,当得出结果时变更为负数,有三种结果:
NORMAL
CANCELLED
EXCEPTIONAL
除此之外,在得出结果之前,任务状态能够被设置为SIGNAL,表示有线程等待这个任务的结果,执行完成后需要notify通知,具体看后文的join。
ForkJoinTask在触发执行后,并不支持其他什么特别操作,只能等待任务执行完成。CountedCompleter是ForkJoinTask的子类,它在子任务协作方面扩展了更多操作。我们聚焦ForkJoinPool主线流程,CountedCompleter相关内容另文再介绍。
WorkQueue
WorkQueue是一个双端队列,它定义在ForkJoinPool类里。
scanState描述WorkQueue当前状态:
偶数表示RUNNING
奇数表示SCANNING
负数表示inactive
stackPred是WorkQueue组成TreiberStack时,保存前者的字段。
ForkJoinPool状态修改
STARTED
STOP
TERMINATED
SHUTDOWN
RSLOCK
RSIGNAL
runState记录了ForkJoinPool的运行状态,除了SHUTDOWN是负数,其他都是正数。前面四种不用说了,线程池标准状态流转。在多线程环境修改runState,不能简单想改就改,需要先获取锁,RSLOCK和RSIGNAL就用在这里。
private int lockRunState() {
int rs;
return ((((rs = runState) & RSLOCK) != 0 ||
!U.compareAndSwapInt(this, RUNSTATE, rs, rs |= RSLOCK)) ?
awaitRunStateLock() : rs);
}
修改前调用lockRunState锁定,检查当前状态,尝试一次使用CAS修改runState为RSLOCK。需要状态变化的机会很少,大多数时间一次就能成功,但不能排除少几率的竞争,这时候进入awaitRunStateLock。
private int awaitRunStateLock() {
Object lock;
boolean wasInterrupted = false;
for (int spins = SPINS, r = 0, rs, ns;;) {
//1
if (((rs = runState) & RSLOCK) == 0) {
if (U.compareAndSwapInt(this, RUNSTATE, rs, ns = rs | RSLOCK)) {
if (wasInterrupted) {
try {
Thread.currentThread().interrupt();
} catch (SecurityException ignore) {
}
}
return ns;
}
}
else if (r == 0)
r = ThreadLocalRandom.nextSecondarySeed();
else if (spins > 0) {
r ^= r << 6; r ^= r >>> 21; r ^= r << 7; // xorshift
if (r >= 0)
--spins;
}
//2
else if ((rs & STARTED) == 0 || (lock = stealCounter) == null)
Thread.yield(); // initialization race
//3
else if (U.compareAndSwapInt(this, RUNSTATE, rs, rs | RSIGNAL)) {
synchronized (lock) {
if ((runState & RSIGNAL) != 0) {
try {
lock.wait();
} catch (InterruptedException ie) {
if (!(Thread.currentThread() instanceof
ForkJoinWorkerThread))
wasInterrupted = true;
}
}
else
lock.notifyAll();
}
}
}
}
在自旋中,第一步,mark1再次尝试修改runState为RSLOCK,成功直接返回。
mark2检查ForkJoinPool初始化情况,这里没有额外多写个变量做锁,直接利用了stealCounter这个原子变量。因为初始化时(后文的externalSubmit),才会对stealCounter赋值。所以当状态不是STARTED或者stealCounter为空时,让出线程等待。
mark3处,线程不会无限制自旋尝试,会利用wait/notify进入阻塞等待。RSIGNAL代替原状态,表示有线程进入了等待,解锁时要处理。在高并发下,这不是一个好的设计,但进入这里的几率很低,作为兜底还是可以的。
private void unlockRunState(int oldRunState, int newRunState) {
if (!U.compareAndSwapInt(this, RUNSTATE, oldRunState, newRunState)) {
Object lock = stealCounter;
runState = newRunState; // clears RSIGNAL bit
if (lock != null)
synchronized (lock) { lock.notifyAll(); }
}
}
解锁的逻辑就比较简单,如果顺利将状态修改为目标状态,成功大吉。否则表示有别的线程进入了wait,需要调用notifyAll唤醒,重新尝试竞争。
ForkJoinPool最佳实践
(1)最适合的是计算密集型任务
(2)在需要阻塞工作线程时,可以使用ManagedBlocker;
(3)不应该在RecursiveTask的内部使用ForkJoinPool.invoke()/invokeAll();
总结
(1)ForkJoinPool特别适合于“分而治之”算法的实现;
(2)ForkJoinPool和ThreadPoolExecutor是互补的,不是谁替代谁的关系,二者适用的场景不同;
(3)ForkJoinTask有两个核心方法——fork()和join(),有三个重要子类——RecursiveAction、RecursiveTask和CountedCompleter;
(4)ForkjoinPool内部基于“工作窃取”算法实现;
(5)每个线程有自己的工作队列,它是一个双端队列,自己从队列头存取任务,其它线程从尾部窃取任务;
(6)ForkJoinPool最适合于计算密集型任务,但也可以使用ManagedBlocker以便用于阻塞型任务;
(7)RecursiveTask内部可以少调用一次fork(),利用当前线程处理,这是一种技巧;
Garnett还会不断的分享技术干货的,希望你们是我最好的观众!
乐于输出干货的Java技术公众号:Garnett的Java之路。公众号内有大量的技术文章、海量视频资源、精美脑图,不妨来关注一下!回复【资料】领取大量学习资源和免费书籍!