ForkJoin学习
组成结构
Fork-Join 框架有三个核心类:ForkJoinPool,ForkJoinWorkerThread,ForkJoinTask。
和ForkJoinPool类似的有之前的ThreadPoolExecutor线程池,线程池是针对每个任务一个线程,如果有某个任务过大,其他线程是帮不上忙的,ForkJoinPool就是解决这种大任务带来的问题,将一个大任务拆分成多个子任务,使用fork将小任务分发,使用join对结果进行汇总,其实是分而治之的并行版本。
ForkJoinWorkerThread就是pool中的工作线程,在Thread的基础上加了双端队列,fork/join并不是给每个子任务分配线程,而是每个线程都有自己的双端队列(为了给工作窃取用的)。
public class ForkJoinWorkerThread extends Thread {
final ForkJoinPool pool; // 工作线程所在的线程池
final ForkJoinPool.WorkQueue workQueue; // 线程的工作队列(这个双端队列是work-stealing机制的核心)
}
ForkJoinTask:
在实现分治编程时,主要就是调用 ForkJoinTask 的 fork() 和 join() 方法。fork() 方法用于提交子任务,而 join() 方法则用于等待子任务的完成。
fork()方法先判断当前线程(调用fork()来提交任务的线程)是不是一个 ForkJoinWorkerThread 的工作线程,如果是,则将任务加入到内部队列中,否则,由 ForkJoinPool 提供的内部公用的线程池 common 线程池 来执行这个任务。
工作窃取算法
- 每个线程都有自己的一个WorkQueue,该工作队列是一个双端队列。
- 队列支持三个功能push、pop、poll
- push/pop只能被队列的所有者线程调用,而poll可以被其他线程调用。
- 划分的子任务调用fork时,都会被push到自己的队列中。
- 默认情况下,工作线程从自己的双端队列获出任务并执行。
- 当自己的队列为空时,线程随机从另一个线程的队列末尾调用poll方法窃取任务。
构造
ForkJoinPool的构造参数:
parallelism:并行级别,通常默认为JVM可用的处理器个数Runtime.getRuntime().availableProcessors()
factory:用于创建ForkJoinPool中使用的线程。
handler:用于处理工作线程未处理的异常,默认为null。
asyncMode:用于控制WorkQueue的工作模式,效果是工作线程在处理本地任务时使用 FIFO 顺序
使用
执行任务有三种方法,直接累加、创建线程池手动分配任务、Fork/Join框架
1、 创建接口
public interface Calculator {
long sum(long [] nums);
}
public interface ClosablePool {
void close();
}
2、创建三种执行子类
public class PlainCalculator implements Calculator{
@Override
public long sum(long[] nums) {
long total = 0l;
for (long i : nums){
total += i;
}
return total;
}
}
public class ExecutorCalculator implements Calculator, ClosablePool{
// 并发数量,和cpu核一致
private int parallism;
// 线程池
private ExecutorService pool;
public ExecutorCalculator(){
parallism = Runtime.getRuntime().availableProcessors();
pool = Executors.newFixedThreadPool(parallism);
}
private static class SumTask implements Callable<Long> {
private long[] numbers;
private int from;
private int to;
public SumTask(long[] numbers, int from, int to) {
this.numbers = numbers;
this.from = from;
this.to = to;
}
@Override
public Long call() throws Exception {
long total = 0;
for (int i = from; i <= to; i++) {
total += numbers[i];
}
return total;
}
}
@Override
public long sum(long[] nums) {
List<Future<Long>> results = new ArrayList<>();
// 把任务分解为 n 份,交给 n 个线程处理
int part = nums.length / parallism;
for (int i = 0; i < parallism; i++) {
int from = i * part;
int to = (i == parallism - 1) ? nums.length - 1 : (i + 1) * part - 1;
results.add(pool.submit(new SumTask(nums, from, to)));
}
// 把每个线程的结果相加,得到最终结果
long total = 0L;
for (Future<Long> f : results) {
try {
total += f.get();
} catch (InterruptedException e) {
e.printStackTrace();
} catch (ExecutionException e) {
e.printStackTrace();
}
}
return total;
}
@Override
public void close() {
pool.shutdownNow();
}
}
public class ForkJoinCalculator implements Calculator,ClosablePool{
// 并发数量,和cpu核一致
private int parallism;
private ForkJoinPool pool;
public ForkJoinCalculator(){
super();
parallism = Runtime.getRuntime().availableProcessors();
pool = new ForkJoinPool(parallism, ForkJoinPool.defaultForkJoinWorkerThreadFactory, (t , e) -> {
System.out.println("抛出异常");
}, true);
}
private static class SumTask extends RecursiveTask<Long> {
private long[] numbers;
private int from;
private int to;
public SumTask(long[] numbers, int from, int to) {
this.numbers = numbers;
this.from = from;
this.to = to;
}
@Override
protected Long compute() {
// 当需要计算的数字小于n时,直接计算结果
if (to - from < 100) {
long total = 0;
for (int i = from; i <= to; i++) {
total += numbers[i];
}
return total;
// 否则,把任务一分为二,递归计算
} else {
int middle = (from + to) / 2;
SumTask left = new SumTask(numbers, from, middle);
SumTask right = new SumTask(numbers, middle+1, to);
left.fork();
right.fork();
return left.join() + right.join();
}
}
}
@Override
public long sum(long[] numbers) {
return pool.invoke(new SumTask(numbers, 0, numbers.length-1));
}
@Override
public void close() {
pool.shutdownNow();
}
}
3、创建上下文
public class Context {
private Calculator calculator;
public Context(Calculator calculator) {
this.calculator = calculator;
}
public void sumWithTime(long [] nums){
long start = System.currentTimeMillis();
calculator.sum(nums);
long end = System.currentTimeMillis();
System.out.println(calculator.getClass().getName() + ":" + (end - start));
if (calculator instanceof ClosablePool) {
( (ClosablePool) calculator).close();
}
}
}
4、执行
public class Main {
public static void main(String[] args) {
long[] numbers = LongStream.rangeClosed(1, 1000).toArray();
// 第一种,简单单线程累加
Calculator calculator1 = new PlainCalculator();
Context context1 = new Context(calculator1);
context1.sumWithTime(numbers);
// 第三种,使用Fork/Join执行任务
Calculator calculator3 = new ForkJoinCalculator();
Context context2 = new Context(calculator3);
context2.sumWithTime(numbers);
// 第二种,使用线程池手动分割任务
Calculator calculator2 = new ExecutorCalculator();
Context context3 = new Context(calculator2);
context3.sumWithTime(numbers);
}
}
image.png
只有当任务足够大的时候,并且是IO密集型的时候,使用Fork/Join才有明显的效果,否则线程之间的竞争会导致效率较低,理论上Fork/Join和手动拆分是差不多的
绝世好文:http://blog.dyngr.com/blog/2016/09/15/java-forkjoinpool-internals/