写一个 ConcurrentHashMap

2019-06-18  本文已影响0人  无聊之园

仅为随手写的学习练习,请勿在实际项目中使用。

1、ConcurrentHashMap 是一个极其复杂的并发类,这里只是学着实现其中的一小部分。

2、目前只实现了两点:put 时只锁住单个桶的头节点(而不是整张表),以及 put 与 get 可以并发执行。

3、由于没有采用 HashMap 那种"容量为 2 的幂、用 (n - 1) & hash 计算下标"的方式(在 HashMap 中,扩容后节点只可能落在原下标 i 或 i + 旧容量),resize 时无法直接确定节点在扩容后所属的桶,因此扩容时也无法只锁单个节点。

等有时间再把这些机制加进来。

目前的做法是:put 通过 CAS 操作写入空桶;resize 的并发问题则借助读写锁解决(put/get 持有读锁,resize 持有写锁)。

/**
 * A small, learning-oriented concurrent hash map modelled on
 * {@link java.util.concurrent.ConcurrentHashMap}. Not intended for production use.
 *
 * <p>Concurrency scheme:
 * <ul>
 *   <li>An empty bin is claimed with a CAS on its slot; a non-empty bin is modified while holding
 *       the monitor of its head node, so puts to different bins do not block each other.</li>
 *   <li>{@link #put} and {@link #get} run under the read side of a fair read-write lock and
 *       {@link #resize} under the write side, so a resize never overlaps a put or a get.</li>
 * </ul>
 *
 * <p>Bin slots and counters use {@code java.util.concurrent.atomic} types, which provide the same
 * volatile-read / compare-and-set operations that ConcurrentHashMap implements with
 * {@code sun.misc.Unsafe}, without relying on that unsupported internal API.
 *
 * <p>Keys must be non-null; removal is not supported.
 *
 * @param <K> key type
 * @param <V> value type
 */
public class DiyConcurrentHashMap<K, V> {

    /** Bucket array; {@code null} until {@link #initTable} creates it on the first put. */
    private volatile java.util.concurrent.atomic.AtomicReferenceArray<Node<K, V>> nodes;

    /** Number of distinct keys stored; incremented atomically by concurrent puts. */
    private final java.util.concurrent.atomic.AtomicInteger size =
            new java.util.concurrent.atomic.AtomicInteger();

    /** A put that brings the size above this value triggers {@link #resize}. */
    private volatile int threshold;

    /** Load factor: {@code threshold == (int) (length * rate)}. */
    private final float rate;

    /** Current number of bins; changed only by {@link #resize} while holding the write lock. */
    private volatile int length;

    /** Table-initialisation guard: 0 = idle, -1 = some thread is allocating the table. */
    private final java.util.concurrent.atomic.AtomicInteger sizeCtl =
            new java.util.concurrent.atomic.AtomicInteger();

    /** Read side: put/get. Write side: resize. Constructed in fair mode. */
    private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock(true);

    /**
     * Creates an empty map; the bucket array itself is allocated lazily on the first put.
     *
     * @param length initial number of bins; must be positive
     * @param rate   load factor; must be positive
     * @throws IllegalArgumentException if either argument is null or not positive
     */
    public DiyConcurrentHashMap(Integer length, Float rate) {
        if (length == null || length <= 0) {
            throw new IllegalArgumentException("length must be positive: " + length);
        }
        // !(rate > 0) also rejects NaN. A non-positive rate gives threshold <= 0, so every
        // insertion would double the table until the capacity overflows.
        if (rate == null || !(rate > 0)) {
            throw new IllegalArgumentException("rate must be positive: " + rate);
        }
        this.length = length;
        this.rate = rate;
        this.threshold = (int) (length * rate);
    }

    /** A bin entry. Static so nodes do not hold a reference to the enclosing map. */
    private static final class Node<K, V> {
        /** Cached {@code key.hashCode()}, so lookups and resizes do not recompute it. */
        final int hash;
        final K key;
        /** Volatile so a value replaced by put is visible to concurrent gets. */
        volatile V value;
        /** Volatile so a node appended by put is safely published to concurrent gets. */
        volatile Node<K, V> next;

        Node(int hash, K key, V value) {
            this.hash = hash;
            this.key = key;
            this.value = value;
        }
    }

    /**
     * Creates the bucket array if it does not exist yet; on return {@code nodes} is non-null.
     *
     * <p>Only the thread that moves {@code sizeCtl} from 0 to -1 allocates the table; other
     * callers yield until the table has been published.
     */
    public void initTable() {
        while (nodes == null) {
            if (sizeCtl.get() < 0) {
                Thread.yield();
            } else if (sizeCtl.compareAndSet(0, -1)) {
                try {
                    // Re-check: another thread may have published the table and reset
                    // sizeCtl between our null check and the CAS.
                    if (nodes == null) {
                        nodes = new java.util.concurrent.atomic.AtomicReferenceArray<>(length);
                    }
                } finally {
                    sizeCtl.set(0);
                }
            }
        }
    }

    /** Volatile read of bin {@code i}. */
    private static <K, V> Node<K, V> tabAt(
            java.util.concurrent.atomic.AtomicReferenceArray<Node<K, V>> tab, int i) {
        return tab.get(i);
    }

    /** Atomically sets bin {@code i} to {@code v} if it currently holds {@code c}. */
    private static <K, V> boolean casTabAt(
            java.util.concurrent.atomic.AtomicReferenceArray<Node<K, V>> tab, int i,
            Node<K, V> c, Node<K, V> v) {
        return tab.compareAndSet(i, c, v);
    }

    /**
     * Associates {@code v} with {@code k}, replacing any existing value. If a new key brings the
     * size above the threshold, the table is doubled via {@link #resize}.
     *
     * @throws NullPointerException if {@code k} is null
     */
    public void put(K k, V v) {
        // Computed before locking: a null key throws here instead of leaking the read lock.
        int hash = k.hashCode();
        boolean inserted;
        lock.readLock().lock();
        try {
            inserted = putVal(hash, k, v);
        } finally {
            lock.readLock().unlock();
        }
        // Resize only after releasing the read lock: ReentrantReadWriteLock cannot upgrade a
        // held read lock to the write lock, so resizing while still reading would deadlock.
        if (inserted && size.get() > threshold) {
            resize();
        }
    }

    /**
     * Inserts or replaces the mapping for {@code k}. Caller must hold the read lock.
     *
     * @return true if a new node was added, false if an existing value was replaced
     */
    private boolean putVal(int hash, K k, V v) {
        initTable();
        java.util.concurrent.atomic.AtomicReferenceArray<Node<K, V>> tab = nodes;
        int index = indexFor(hash, tab.length());
        Node<K, V> node = new Node<>(hash, k, v);
        while (true) {
            Node<K, V> head = tabAt(tab, index);
            if (head == null) {
                // Empty bin: claim it lock-free; if the CAS fails another put filled it, so retry.
                if (casTabAt(tab, index, null, node)) {
                    break;
                }
            } else {
                synchronized (head) {
                    // Re-check that head is still the first node before walking the chain.
                    if (tabAt(tab, index) == head) {
                        Node<K, V> last = head;
                        for (Node<K, V> e = head; e != null; e = e.next) {
                            if (e.hash == hash && e.key.equals(k)) {
                                e.value = v;
                                return false;
                            }
                            last = e;
                        }
                        last.next = node;
                        break;
                    }
                }
            }
        }
        size.incrementAndGet();
        return true;
    }

    /**
     * Returns the value mapped to {@code k}, or {@code null} if there is none.
     *
     * @throws NullPointerException if {@code k} is null
     */
    public V get(K k) {
        int hash = k.hashCode();
        lock.readLock().lock();
        try {
            java.util.concurrent.atomic.AtomicReferenceArray<Node<K, V>> tab = nodes;
            if (tab == null) {
                // Nothing has been put yet, so there is no table to search.
                return null;
            }
            for (Node<K, V> e = tabAt(tab, indexFor(hash, tab.length())); e != null; e = e.next) {
                if (e.hash == hash && e.key.equals(k)) {
                    return e.value;
                }
            }
            return null;
        } finally {
            lock.readLock().unlock();
        }
    }

    /**
     * Doubles the number of bins and rehashes every node, provided the size still exceeds the
     * threshold once the write lock is held. Blocks all puts and gets while it runs.
     */
    public void resize() {
        lock.writeLock().lock();
        try {
            java.util.concurrent.atomic.AtomicReferenceArray<Node<K, V>> oldTab = nodes;
            // Another thread may already have grown the table after our caller's size check;
            // the boundary matches put's trigger condition (size > threshold).
            if (oldTab == null || size.get() <= threshold) {
                return;
            }
            int newLength = length << 1;
            java.util.concurrent.atomic.AtomicReferenceArray<Node<K, V>> newTab =
                    new java.util.concurrent.atomic.AtomicReferenceArray<>(newLength);
            for (int i = 0; i < oldTab.length(); i++) {
                Node<K, V> e = oldTab.get(i);
                while (e != null) {
                    Node<K, V> next = e.next;
                    int idx = indexFor(e.hash, newLength);
                    // Push e onto the front of its new bin. Relinking the existing nodes in
                    // place is safe because the write lock excludes every put and get.
                    e.next = newTab.get(idx);
                    newTab.set(idx, e);
                    e = next;
                }
            }
            length = newLength;
            threshold = (int) (newLength * rate);
            nodes = newTab;
        } finally {
            lock.writeLock().unlock();
        }
    }

    /** Returns {@code |hash % n|}, a bin index in {@code [0, n)}. */
    private static int indexFor(int hash, int n) {
        int i = hash % n;
        return i < 0 ? -i : i;
    }

    /** Returns the bin index of {@code hashcode} for the current number of bins: |hashcode % length|. */
    public int indexOfArray(int hashcode) {
        return indexFor(hashcode, length);
    }

    /** Returns the number of distinct keys stored. */
    public int size() {
        return size.get();
    }
}

测试:
1000 个线程同时并发读写。
初始容量足够大(2000),不触发扩容。
第一个(自己写的)输出(第一行为 size,第二行为耗时,单位纳秒):
1000
573798991
第二个(ConcurrentHashMap)输出:
1000
261345303

可以看到,自己写的版本耗时是 ConcurrentHashMap 的 2 倍多。

/**
 * Benchmark without resizing: THREADS threads each put KEYS keys and read them back, against
 * DiyConcurrentHashMap ({@link #main}) and java.util.concurrent.ConcurrentHashMap ({@link #test}).
 * The initial capacity of 2000 for 1000 keys keeps either map from resizing.
 * Each run prints the final size, then the elapsed time in nanoseconds.
 */
public class TestConcurrentHashMap {

    /** Number of concurrent worker threads. */
    private static final int THREADS = 1000;

    /** Number of distinct keys each worker puts and then reads back. */
    private static final int KEYS = 1000;

    public static void main(String[] args) throws InterruptedException {
        long start = System.nanoTime();
        DiyConcurrentHashMap<String, String> s = new DiyConcurrentHashMap<>(2000, 0.75f);
        runWorkload(s::put, s::get);
        System.out.println(s.size());
        long end = System.nanoTime();
        System.out.println(end - start);
    }

    @Test
    public void test() throws InterruptedException {
        long start = System.nanoTime();
        ConcurrentHashMap<String, String> s = new ConcurrentHashMap<>(2000, 0.75f);
        runWorkload(s::put, s::get);
        System.out.println(s.size());
        long end = System.nanoTime();
        System.out.println(end - start);
    }

    /**
     * Starts THREADS workers that each put KEYS entries ("hello" + i -> "world" + i) and read
     * them back, printing "fail" for every mismatch, then waits until all workers have finished.
     */
    private static void runWorkload(java.util.function.BiConsumer<String, String> put,
                                    java.util.function.Function<String, String> get)
            throws InterruptedException {
        CountDownLatch done = new CountDownLatch(THREADS);
        for (int t = 0; t < THREADS; t++) {
            new Thread(() -> {
                try {
                    for (int i = 0; i < KEYS; i++) {
                        put.accept("hello" + i, "world" + i);
                    }
                    for (int i = 0; i < KEYS; i++) {
                        // Literal on the left: a missing (null) value is reported, not thrown.
                        if (!("world" + i).equals(get.apply("hello" + i))) {
                            System.out.println("fail");
                        }
                    }
                } finally {
                    // Count down even if a worker throws; otherwise await() would block forever.
                    done.countDown();
                }
            }).start();
        }
        done.await();
    }

}

如果让其扩容(初始容量改为 16),可以看到:自己写的版本在扩容时直接锁住整张表,性能很差(耗时相差约 40 倍);而 ConcurrentHashMap 的性能几乎没受影响。
第一个:
1000
12767916688
第二个:
1000
298357664

/**
 * Benchmark with resizing: THREADS threads each put KEYS keys and read them back, against
 * DiyConcurrentHashMap ({@link #main}) and java.util.concurrent.ConcurrentHashMap ({@link #test}).
 * The initial capacity of 16 for 1000 keys forces both maps to resize repeatedly.
 * Each run prints the final size, then the elapsed time in nanoseconds.
 */
public class TestConcurrentHashMap {

    /** Number of concurrent worker threads. */
    private static final int THREADS = 1000;

    /** Number of distinct keys each worker puts and then reads back. */
    private static final int KEYS = 1000;

    public static void main(String[] args) throws InterruptedException {
        long start = System.nanoTime();
        DiyConcurrentHashMap<String, String> s = new DiyConcurrentHashMap<>(16, 0.75f);
        runWorkload(s::put, s::get);
        System.out.println(s.size());
        long end = System.nanoTime();
        System.out.println(end - start);
    }

    @Test
    public void test() throws InterruptedException {
        long start = System.nanoTime();
        ConcurrentHashMap<String, String> s = new ConcurrentHashMap<>(16, 0.75f);
        runWorkload(s::put, s::get);
        System.out.println(s.size());
        long end = System.nanoTime();
        System.out.println(end - start);
    }

    /**
     * Starts THREADS workers that each put KEYS entries ("hello" + i -> "world" + i) and read
     * them back, printing "fail" for every mismatch, then waits until all workers have finished.
     */
    private static void runWorkload(java.util.function.BiConsumer<String, String> put,
                                    java.util.function.Function<String, String> get)
            throws InterruptedException {
        CountDownLatch done = new CountDownLatch(THREADS);
        for (int t = 0; t < THREADS; t++) {
            new Thread(() -> {
                try {
                    for (int i = 0; i < KEYS; i++) {
                        put.accept("hello" + i, "world" + i);
                    }
                    for (int i = 0; i < KEYS; i++) {
                        // Literal on the left: a missing (null) value is reported, not thrown.
                        if (!("world" + i).equals(get.apply("hello" + i))) {
                            System.out.println("fail");
                        }
                    }
                } finally {
                    // Count down even if a worker throws; otherwise await() would block forever.
                    done.countDown();
                }
            }).start();
        }
        done.await();
    }

}

再测试一下用 ReentrantLock 加锁的 HashMap。在不扩容的情况下,自己写的版本至少比它快一些(约 0.57 秒对比约 1.07 秒)。
HashMap 扩容对其耗时几乎没有影响。
1000
1069096183

@Test
    public void test2() throws InterruptedException {
        // Baseline: a plain HashMap guarded by one ReentrantLock; 1000 threads each put 1000
        // keys and read them back. Prints the final size, then the elapsed time in nanoseconds.
        long start = System.nanoTime();
        HashMap<String, String> s = new HashMap<>(2000, 0.75f);
        ReentrantLock lock = new ReentrantLock();
        CountDownLatch c = new CountDownLatch(1000);
        for (int t = 0; t < 1000; t++) {
            new Thread(() -> {
                try {
                    for (int i = 0; i < 1000; i++) {
                        lock.lock();
                        try {
                            s.put("hello" + i, "world" + i);
                        } finally {
                            // Always release: a lock left held would block every other worker.
                            lock.unlock();
                        }
                    }
                    for (int i = 0; i < 1000; i++) {
                        String value;
                        lock.lock();
                        try {
                            value = s.get("hello" + i);
                        } finally {
                            lock.unlock();
                        }
                        // Literal on the left: a missing (null) value is reported, not thrown.
                        if (!("world" + i).equals(value)) {
                            System.out.println("fail");
                        }
                    }
                } finally {
                    // Count down even if a worker throws; otherwise await() would block forever.
                    c.countDown();
                }
            }).start();
        }
        c.await();
        System.out.println(s.size());
        long end = System.nanoTime();
        System.out.println(end - start);
    }
上一篇下一篇

猜你喜欢

热点阅读