手写实现自定义线程池

2020-04-04  本文已影响0人  雨夜都行

从今天开始养成写文章的习惯，同时想把自己知道的、学到的 Java 知识分享给大家，共同进步。

动手实现一个简化版的线程池，通过这个例子可以了解线程池大致的工作原理。
已实现的功能：
1. 阻塞等待队列
2. 自定义线程池
3. 拒绝策略
4. 线程池测试

package threadpool;

import java.time.LocalDateTime;

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;

/**
 * Demo for {@link MyThreadPool}: 2 core workers, a waiting queue of capacity 1,
 * and the "run on the submitting thread" rejection policy.
 *
 * @author
 * @date 2020.4.3
 */
public class MyThreadPoolTest {
    public static void main(String[] args) {
        // Other rejection policies that can be passed instead of the one below:
        //   policy 1 - only log:  r -> System.out.println("拒绝执行")
        //   policy 2 - throw:     r -> { throw new RejectedExecutionException("阻塞等待队列已满"); }
        MyThreadPool threadPool = new MyThreadPool(2, 1, (r) -> {
            // Policy 3: the submitting (current) thread runs the rejected task itself.
            r.run();
        });
        // Starts worker0; simulates a long-running task (2 s).
        threadPool.execute(() -> {
            System.out.println("----->hello1 start");
            System.out.println(LocalDateTime.now());
            sleep(2);
            System.out.println(LocalDateTime.now());
            System.out.println("----->hello1 end");
        });
        // Starts worker1 (3 s).
        threadPool.execute(() -> {
            System.out.println("----->hello2 start");
            System.out.println(LocalDateTime.now());
            sleep(3);
            System.out.println(LocalDateTime.now());
            System.out.println("----->hello2 end");
        });
        // Both workers are busy, so this task is queued; worker0 picks it up when
        // hello1 finishes (after ~2 s).
        threadPool.execute(() -> {
            System.out.println("----->hello3 start");
            System.out.println(LocalDateTime.now());
            sleep(1);
            System.out.println(LocalDateTime.now());
            System.out.println("----->hello3 end");
        });
        sleep(0.2);
        // Both workers are busy and the queue (capacity 1) is full, so policy 3 runs
        // this task right away on the main thread.
        threadPool.execute(() -> {
            System.out.println(Thread.currentThread().getName() + "----->hello4 start");
            System.out.println(LocalDateTime.now());
            System.out.println("----->hello4 end");
        });
        // The workers are non-daemon threads waiting for more tasks; the JVM can only
        // exit once they stop, so ask the pool to shut down.
        threadPool.shutdown();
    }

    /**
     * Sleeps for the given number of seconds (fractions allowed, millisecond precision).
     *
     * @param sleepTime time to sleep, in seconds
     */
    public static void sleep(double sleepTime) {
        try {
            TimeUnit.MILLISECONDS.sleep((long) (sleepTime * 1000L));
        } catch (InterruptedException e) {
            // Restore the flag so the caller (e.g. a worker loop) can still see the interrupt.
            Thread.currentThread().interrupt();
            e.printStackTrace();
        }
    }
}

/**
 * A minimal fixed-size thread pool that illustrates how a pool works.
 *
 * <p>Each of the first {@code corePoolSize} submitted tasks starts a new {@link Worker}
 * thread. Later tasks go into a bounded {@link BlockingQueue} that the workers drain.
 * When the queue is full, {@link #execute} either blocks (no rejection policy) or hands
 * the task to the configured {@link RejectedPolicy}.
 *
 * <p>{@link #shutdown()} is not meant to race with {@link #execute}: a task submitted
 * concurrently with shutdown may be queued behind the shutdown sentinels and never run.
 */
class MyThreadPool {

    // Sentinel enqueued by shutdown(); a worker that takes it leaves its loop.
    private static final Runnable POISON_PILL = () -> { };

    // Maximum number of worker threads.
    private final int corePoolSize;

    // Worker threads started so far. It never exceeds corePoolSize, so it cannot
    // overflow no matter how many tasks are submitted.
    private final AtomicInteger workerCount;

    // Tasks waiting for a free worker.
    private final BlockingQueue<Runnable> blockingQueue;

    // Handler for tasks that cannot be accepted; null makes a full queue block the submitter.
    RejectedPolicy rejectedPolicy;

    // Set once by shutdown(); volatile so execute() sees it without locking.
    private volatile boolean isShutdown;

    /**
     * Creates a pool whose submitters block while the waiting queue is full.
     *
     * @param corePoolSize number of worker threads
     * @param queueSize    capacity of the waiting queue
     */
    public MyThreadPool(int corePoolSize, int queueSize) {
        this(corePoolSize, queueSize, null);
    }

    /**
     * @param corePoolSize   number of worker threads
     * @param queueSize      capacity of the waiting queue
     * @param rejectedPolicy called when the queue is full or a task arrives after shutdown;
     *                       {@code null} makes a full queue block the submitter
     */
    public MyThreadPool(int corePoolSize, int queueSize, RejectedPolicy rejectedPolicy) {
        this.corePoolSize = corePoolSize;
        this.workerCount = new AtomicInteger();
        this.blockingQueue = new BlockingQueue<>(queueSize);
        this.rejectedPolicy = rejectedPolicy;
    }

    /**
     * Submits a task. While fewer than {@code corePoolSize} workers exist, a new worker
     * is started with the task. Otherwise the task is queued: blocking when the queue is
     * full and no rejection policy is set, or handing it to the policy.
     *
     * @param task the task to run
     * @throws RejectedExecutionException if the pool is shut down and no rejection policy is set
     */
    public void execute(Runnable task) {
        if (isShutdown) {
            reject(task);
            return;
        }
        // Claim a worker slot only while one is free; the counter stops at corePoolSize.
        int started = workerCount.getAndUpdate(n -> n < corePoolSize ? n + 1 : n);
        if (started < corePoolSize) {
            new Worker(task, "worker" + started).start();
        } else if (rejectedPolicy != null) {
            blockingQueue.put(task, rejectedPolicy);
        } else {
            blockingQueue.put(task);
        }
    }

    // Rejects a task submitted after shutdown.
    private void reject(Runnable task) {
        if (rejectedPolicy != null) {
            rejectedPolicy.rejectedExecution(task);
        } else {
            throw new RejectedExecutionException("线程池已关闭");
        }
    }

    /**
     * Shuts the pool down after every already-queued task has been taken by a worker.
     *
     * <p>Enqueues one sentinel per started worker. Because the queue is FIFO, tasks queued
     * earlier still run first. This call blocks while the queue is full, i.e. until the
     * workers make room for the sentinels; with a queue of capacity 0 it would block forever.
     * Afterwards {@link #execute} rejects new tasks.
     *
     * @return {@code true} if this call shut the pool down, {@code false} if it already was
     */
    public synchronized boolean shutdown() {
        if (isShutdown) {
            return false;
        }
        isShutdown = true;
        // Freeze the counter at its maximum so no new worker can start, and learn how
        // many workers need a sentinel.
        int workers = workerCount.getAndSet(corePoolSize);
        for (int i = 0; i < workers; i++) {
            blockingQueue.put(POISON_PILL);
        }
        return true;
    }

    /**
     * Worker thread: runs its first task (if any), then keeps taking tasks from the
     * queue until it takes the shutdown sentinel or {@code take()} returns {@code null}.
     */
    class Worker extends Thread {

        // Task currently being handled; starts as the task handed over at creation.
        private Runnable task;
        // Number of tasks this worker has run.
        private int completed;
        private String workerName;

        public Worker(Runnable task, String workerName) {
            // run() is overridden, so the task is kept in a field rather than passed
            // to Thread as its target.
            super(workerName);
            this.workerName = workerName;
            this.task = task;
        }

        @Override
        public void run() {
            while (task != null || (task = blockingQueue.take()) != null) {
                if (task == POISON_PILL) {
                    return;
                }
                System.out.println(workerName + "执行任务");
                try {
                    task.run();
                } catch (RuntimeException e) {
                    // Report it like an uncaught exception, but keep this worker alive so
                    // one failing task does not permanently shrink the pool.
                    getUncaughtExceptionHandler().uncaughtException(this, e);
                }
                task = null;
                System.out.println(workerName + "执行已执行任务数:" + ++completed);
            }
        }
    }
}

/**
 * Strategy invoked when a task cannot be accepted, for example because the
 * waiting queue is full. Usually supplied as a lambda.
 */
@FunctionalInterface
interface RejectedPolicy {

    /**
     * Handles a task that was not accepted.
     *
     * @param runnable the rejected task
     */
    void rejectedExecution(Runnable runnable);
}

/**
 * Bounded FIFO queue guarded by one {@link ReentrantLock} with two conditions:
 * producers wait on {@code notFull}, consumers wait on {@code notEmpty}.
 *
 * @param <T> element type; {@link #put(Object, RejectedPolicy)} additionally
 *            requires elements to be {@link Runnable}
 */
class BlockingQueue<T> {

    // Stored elements, oldest first.
    private final Deque<T> items;

    // Maximum number of stored elements.
    private final int capacity;

    // Single lock shared by producers and consumers.
    private final ReentrantLock lock = new ReentrantLock();

    // Signalled when an element is removed, i.e. space may be available.
    private final Condition notFull = lock.newCondition();
    // Signalled when an element is added.
    private final Condition notEmpty = lock.newCondition();

    public BlockingQueue(int capacity) {
        this.capacity = capacity;
        this.items = new ArrayDeque<>(capacity);
    }

    /**
     * Appends {@code ele}, waiting while the queue is full. A {@code null} element is ignored.
     * If the thread is interrupted while waiting, the element is not added and the
     * thread's interrupt status is restored.
     *
     * @param ele element to add
     */
    public void put(T ele) {
        if (ele == null) {
            return;
        }
        lock.lock();
        try {
            while (items.size() == capacity) {
                System.out.println("进入等待");
                System.out.println(LocalDateTime.now());
                notFull.await();
            }
            System.out.println("加入队列");
            items.addLast(ele);
            notEmpty.signalAll();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            e.printStackTrace();
        } finally {
            lock.unlock();
        }
    }

    /**
     * Appends {@code ele} if there is room; otherwise hands it to {@code rejectedPolicy}.
     * Does nothing if either argument is {@code null}.
     *
     * <p>The policy is invoked after the lock is released. A "caller runs" policy may
     * execute a long task, and holding the lock meanwhile would block every consumer.
     *
     * @param ele            element to add; must be a {@link Runnable} if it gets rejected
     * @param rejectedPolicy handler for an element that does not fit
     * @throws ClassCastException if {@code ele} is rejected and is not a {@link Runnable}
     */
    public void put(T ele, RejectedPolicy rejectedPolicy) {
        if (ele == null || rejectedPolicy == null) {
            return;
        }
        boolean accepted;
        lock.lock();
        try {
            accepted = items.size() != capacity;
            if (accepted) {
                System.out.println("加入队列");
                items.addLast(ele);
                notEmpty.signalAll();
            }
        } finally {
            lock.unlock();
        }
        if (!accepted) {
            rejectedPolicy.rejectedExecution((Runnable) ele);
        }
    }

    /**
     * Removes and returns the head, waiting up to {@code timeOut} nanoseconds for one.
     *
     * @param timeOut maximum wait in nanoseconds; values {@code <= 0} do not wait
     * @return the head element, or {@code null} on timeout or interruption
     *         (the interrupt status is restored)
     */
    public T poll(long timeOut) {
        lock.lock();
        try {
            long nanosLeft = timeOut;
            while (items.isEmpty()) {
                if (nanosLeft <= 0) {
                    return null;
                }
                // awaitNanos returns the time still remaining, so waiting again after a
                // spurious wake-up never extends past the original deadline.
                nanosLeft = notEmpty.awaitNanos(nanosLeft);
            }
            T ele = items.pollFirst();
            notFull.signalAll();
            return ele;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            e.printStackTrace();
            return null;
        } finally {
            lock.unlock();
        }
    }

    /**
     * Same as {@link #poll(long)} with an explicit time unit.
     *
     * @param timeout maximum wait; values {@code <= 0} do not wait
     * @param unit    unit of {@code timeout}
     * @return the head element, or {@code null} on timeout or interruption
     */
    public T poll(long timeout, TimeUnit unit) {
        return poll(unit.toNanos(timeout));
    }

    /**
     * Removes and returns the head, waiting as long as necessary for one.
     *
     * @return the head element, or {@code null} if the thread was interrupted while
     *         waiting (the interrupt status is restored)
     */
    public T take() {
        lock.lock();
        try {
            while (items.isEmpty()) {
                System.out.println("队列中的数据为空,"+ Thread.currentThread().getName() +"进入等待");
                notEmpty.await();
            }
            T ele = items.pollFirst();
            notFull.signalAll();
            return ele;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            e.printStackTrace();
            return null;
        } finally {
            lock.unlock();
        }
    }
}

以上就是简化版线程池的实现。下一篇文章将和大家分享 AQS 同步器的工作原理。

上一篇下一篇

猜你喜欢

热点阅读