ForkJoin
ForkJoin 是一个并行计算框架,用于支持分治任务模型:Fork 对应任务分解,Join 对应任务合并。PS:先把大任务拆分成小任务,最后再合并各个小任务的结果。
ForkJoinPool
ForkJoinPool 跟 ThreadPoolExecutor 类似,都继承自 AbstractExecutorService 类。
但是 ForkJoinPool 有多个任务队列,而且这个任务队列是双向的,内部使用的是“工作窃取”算法,就是当前工作线程 完成工作队列里面的任务后,可以帮助其他线程完成任务队列里面的任务,任务类是ForkJoinTask。
/**
 * Abridged excerpt of a ForkJoinPool implementation showing how a task submitted from
 * outside the pool is placed on a work queue. Fields, constants and helpers referenced
 * here (workQueues, runState, ctl, SQMASK, U, signalWork, ...) are declared outside this
 * excerpt.
 */
public class ForkJoinPool extends AbstractExecutorService {
/**
 * Creates a pool whose parallelism is the number of available processors capped at
 * MAX_CAP, using defaultForkJoinWorkerThreadFactory and a null exception handler.
 * Delegates to a four-argument constructor that is not shown in this excerpt.
 */
public ForkJoinPool() {
this(Math.min(MAX_CAP, Runtime.getRuntime().availableProcessors()),
defaultForkJoinWorkerThreadFactory, null, false);
}
/**
 * Stores the pool configuration and initializes the control word.
 *
 * @param parallelism the requested parallelism level
 * @param factory used to create the ForkJoinWorkerThread workers
 * @param handler handles exceptions thrown in worker threads
 * @param mode working mode: 1<<16 means FIFO queues, 0 means LIFO queues
 * @param workerNamePrefix prefix for worker thread names
 **/
private ForkJoinPool(int parallelism,
ForkJoinWorkerThreadFactory factory,
UncaughtExceptionHandler handler,
int mode,
String workerNamePrefix) {
this.workerNamePrefix = workerNamePrefix;
this.factory = factory;
this.ueh = handler;
// config packs the parallelism (masked by SMASK) together with the mode bits
this.config = (parallelism & SMASK) | mode;
long np = (long)(-parallelism); // offset ctl counts
// the negated parallelism is stored in both the AC and the TC field of ctl
this.ctl = ((np << AC_SHIFT) & AC_MASK) | ((np << TC_SHIFT) & TC_MASK);
}
/**
 * Executes the task and returns its result. The task is handed to externalPush and the
 * caller then waits for it with join().
 *
 * @param task the task to run
 * @return the value returned by task.join()
 * @throws NullPointerException if task is null
 **/
public <T> T invoke(ForkJoinTask<T> task) {
if (task == null)
throw new NullPointerException();
externalPush(task);
return task.join();
}
/**
 * Fast path for submitting a task: tries to push it onto the queue selected by the
 * caller's probe value, and falls back to externalSubmit whenever that is not possible.
 *
 * @param task the task to enqueue
 */
final void externalPush(ForkJoinTask<?> task) {
WorkQueue[] ws; WorkQueue q; int m;
int r = ThreadLocalRandom.getProbe();
int rs = runState;
// Fast path requires: a non-empty queue array, an existing queue in the probed slot,
// a non-zero probe, rs > 0, and winning the CAS that takes the queue lock (0 -> 1).
if ((ws = workQueues) != null && (m = (ws.length - 1)) >= 0 &&
(q = ws[m & r & SQMASK]) != null && r != 0 && rs > 0 &&
U.compareAndSwapInt(q, QLOCK, 0, 1)) {
ForkJoinTask<?>[] a; int am, n, s;
// n = top - base is the number of entries currently queued; push only if the
// array exists and n is still below length - 1
if ((a = q.array) != null &&
(am = a.length - 1) > (n = (s = q.top) - q.base)) {
// offset of slot (top mod length) inside the task array
int j = ((am & s) << ASHIFT) + ABASE;
U.putOrderedObject(a, j, task);
U.putOrderedInt(q, QTOP, s + 1);
// release the queue lock
U.putIntVolatile(q, QLOCK, 0);
// the queue held at most one entry before this push: signal for work
if (n <= 1)
signalWork(ws, q);
return;
}
// no room (or no array): release the lock and take the full path below
U.compareAndSwapInt(q, QLOCK, 1, 0);
}
externalSubmit(task);
}
/**
 * Full submission path: creates workQueues if needed, selects a queue by the caller's
 * probe, and enqueues the task. Loops until the task has been submitted or the
 * submission is rejected.
 *
 * @param task the task to enqueue
 * @throws RejectedExecutionException if runState is negative
 */
private void externalSubmit(ForkJoinTask<?> task) {
int r; // initialize caller's probe
if ((r = ThreadLocalRandom.getProbe()) == 0) {
ThreadLocalRandom.localInit();
r = ThreadLocalRandom.getProbe();
}
for (;;) {
WorkQueue[] ws; WorkQueue q; int rs, m, k;
boolean move = false;
// case 1: negative runState -> reject the task
if ((rs = runState) < 0) {
tryTerminate(false, false); // help terminate
throw new RejectedExecutionException();
}
// case 2: pool not STARTED yet, or the queue array is missing/empty
else if ((rs & STARTED) == 0 || // initialize
((ws = workQueues) == null || (m = ws.length - 1) < 0)) {
int ns = 0;
rs = lockRunState();
try {
// re-check under the lock before initializing
if ((rs & STARTED) == 0) {
U.compareAndSwapObject(this, STEALCOUNTER, null,
new AtomicLong());
// create workQueues array with size a power of two
int p = config & SMASK; // ensure at least 2 slots
int n = (p > 1) ? p - 1 : 1;
n |= n >>> 1; n |= n >>> 2; n |= n >>> 4;
n |= n >>> 8; n |= n >>> 16; n = (n + 1) << 1;
workQueues = new WorkQueue[n];
ns = STARTED;
}
} finally {
// release the runState lock, adding STARTED if the array was created here
unlockRunState(rs, (rs & ~RSLOCK) | ns);
}
}
// case 3: the probed slot already holds a queue -> try to lock it and push
else if ((q = ws[k = r & m & SQMASK]) != null) {
if (q.qlock == 0 && U.compareAndSwapInt(q, QLOCK, 0, 1)) {
ForkJoinTask<?>[] a = q.array;
int s = q.top;
boolean submitted = false; // initial submission or resizing
try { // locked version of push
// use the current array if it has room, otherwise grow it
if ((a != null && a.length > s + 1 - q.base) ||
(a = q.growArray()) != null) {
int j = (((a.length - 1) & s) << ASHIFT) + ABASE;
U.putOrderedObject(a, j, task);
U.putOrderedInt(q, QTOP, s + 1);
submitted = true;
}
} finally {
// release the queue lock
U.compareAndSwapInt(q, QLOCK, 1, 0);
}
if (submitted) {
signalWork(ws, q);
return;
}
}
move = true; // move on failure
}
// case 4: the slot is empty and runState is not locked -> create a queue for slot k
else if (((rs = runState) & RSLOCK) == 0) { // create new queue
q = new WorkQueue(this, null);
q.hint = r;
q.config = k | SHARED_QUEUE;
q.scanState = INACTIVE;
rs = lockRunState(); // publish index
// install the queue only if rs > 0 and slot k is still empty
if (rs > 0 && (ws = workQueues) != null &&
k < ws.length && ws[k] == null)
ws[k] = q; // else terminated
unlockRunState(rs, rs & ~RSLOCK);
}
else
move = true; // move if busy
// pick a different slot on the next iteration
if (move)
r = ThreadLocalRandom.advanceProbe(r);
}
}
}
ForkJoinTask
任务类,即提交给 ForkJoinPool 执行的任务。ForkJoinTask 有两个常用子类:RecursiveTask 表示有返回结果,RecursiveAction 表示没有返回结果。所以使用的时候需要写一个任务类继承 RecursiveTask 或者 RecursiveAction,然后重写 compute 方法。
/**
 * Abridged excerpt of the task base class; only the execution entry point is shown.
 * The members it uses (status, exec, setCompletion, setExceptionalCompletion, NORMAL)
 * are declared outside this excerpt.
 */
public abstract class ForkJoinTask<V> implements Future<V>, Serializable {
    /**
     * Entry point for running this task. The task body is executed through exec() only
     * when the current status is non-negative.
     *
     * @return the unchanged status when it was negative or when exec() returned false,
     *         the value of setCompletion(NORMAL) when exec() returned true, or the
     *         value of setExceptionalCompletion(rex) when exec() threw
     */
    final int doExec() {
        int current = status;
        if (current < 0) {
            // negative status: the body is not run
            return current;
        }
        boolean finished;
        try {
            finished = exec(); // run the task body
        } catch (Throwable rex) {
            return setExceptionalCompletion(rex);
        }
        return finished ? setCompletion(NORMAL) : current;
    }
}
/**
 * A task that yields a result: subclasses implement compute() and the value it returns
 * is kept in the result field.
 */
public abstract class RecursiveTask<V> extends ForkJoinTask<V> {
    /** The stored result; written by exec() and by setRawResult(). */
    V result;

    /**
     * The computation performed by this task.
     *
     * @return the value to store as this task's result
     */
    protected abstract V compute();

    /**
     * Runs compute() and stores its return value in result.
     *
     * @return always true
     */
    protected final boolean exec() {
        V computed = compute();
        this.result = computed;
        return true;
    }

    /**
     * Overwrites the stored result.
     *
     * @param value the new result value
     */
    protected final void setRawResult(V value) {
        this.result = value;
    }
}
/**
 * A task that yields no result: subclasses implement compute(), which returns nothing.
 */
public abstract class RecursiveAction extends ForkJoinTask<Void> {
    /** The computation performed by this task. */
    protected abstract void compute();

    /**
     * Runs compute().
     *
     * @return always true
     */
    protected final boolean exec() {
        this.compute();
        return true;
    }
}
DEMO:多线程协作计算累和
/**
* @author Jinhanlai
* @date 2021/8/30 9:23 下午
*/
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
/**
 * Fork/Join is a parallel computing framework for the divide-and-conquer task model:
 * fork corresponds to splitting a task, join corresponds to merging the sub-results.
 * This demo lets several threads cooperate to compute the sum of an array.
 */
public class ForkJoinTest {
    /**
     * Sums the numbers 1..10 with a SumTask on a new ForkJoinPool and prints the result.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        int[] numbers = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
        int lastIndex = numbers.length - 1;
        int splitThreshold = numbers.length / 2;
        ForkJoinPool forkJoinPool = new ForkJoinPool();
        SumTask rootTask = new SumTask(numbers, 0, lastIndex, splitThreshold);
        forkJoinPool.invoke(rootTask);
        System.out.println("The answer is " + rootTask.join());
    }
}
/**
 * Fork/join task that sums the elements {@code arr[from..to]} (both bounds inclusive).
 *
 * <p>A range whose width {@code to - from} is below {@code threshold}, or that holds at
 * most one element, is summed sequentially. Any other range is split in half and the
 * two halves are computed as subtasks.
 */
class SumTask extends RecursiveTask<Integer> {
    int[] arr; // array whose elements are summed
    int from, to; // inclusive index range handled by this task
    int threshold; // ranges with (to - from) below this value are not split

    /**
     * @param arr the array to sum over
     * @param from first index of the range (inclusive)
     * @param to last index of the range (inclusive); {@code to < from} denotes an empty range
     * @param threshold ranges with {@code to - from < threshold} are summed directly
     */
    public SumTask(int[] arr, int from, int to, int threshold) {
        this.arr = arr;
        this.from = from;
        this.to = to;
        this.threshold = threshold;
    }

    /**
     * Computes the sum of the range, either directly or by combining two subtasks.
     *
     * @return the sum of {@code arr[from..to]}; 0 for an empty range
     */
    @Override
    protected Integer compute() {
        // A range of at most one element cannot be split any further. Without the
        // second condition a non-positive threshold (for example arr.length / 2 for a
        // one-element array) would split [from, from] into itself plus an empty range
        // and recurse forever.
        if (to - from < threshold || to <= from) {
            int sum = 0;
            for (int i = from; i <= to; i++) {
                sum += arr[i];
            }
            return sum;
        }
        // Split into two halves; the midpoint is written so that from + to cannot overflow.
        int mid = from + (to - from) / 2;
        SumTask left = new SumTask(arr, from, mid, threshold);
        SumTask right = new SumTask(arr, mid + 1, to, threshold);
        invokeAll(left, right);
        return left.join() + right.join();
    }
}