关于 CopyOnWriteArrayList 的一个简单优化

2020-06-07  本文已影响0人  M_lear

一、优化动机

COW 简介:增删改都会加锁并拷贝工作数组,在拷贝数组上做完增删改操作后,会把拷贝数组切换为工作数组,在这个过程中,不会阻塞读操作。

所以 COW 很适合于读多写少的情况。

但是如果写再多一点呢,或者大部分的增删改只发生在工作数组的末尾呢。对于这些情况,在有大量数据的业务场景中,每次写操作都要拷贝整个工作数组,是很不划算的。

二、优化实现

考虑大部分的增删改只发生在工作数组末尾的这种情况。

我们可以使用一个数组来保存初始数据,然后后续发生在数组末尾的增删改数据保存在另一个数组中。

这样对于大量的增删改操作就只需要拷贝后面那个较短的数组即可,而不需要拷贝整个数据。

所以我考虑使用两个数组来保存数据,两个数组在逻辑上当成一个数组的前后两段使用。

三、优化导致的问题和解决方案

3.1 读的效率问题

COW 最显著的优点,就是读很快,任何写操作都不会阻塞读操作。所以对 COW 的优化,也应该具有极高的读效率。

两个数组当成一个数组使用的这个优化,对读操作最明显的影响是,在每次读的时候,都需要判断读操作发生在两个数组中的哪一个上面。

对于这个问题,每次通过判断语句判断,显然是很影响效率的。

对应的解决办法是,利用 Java 底层对数组访问做的安全控制,我们在代码层面不加以判断,等底层判断发现越界并抛出数组越界异常之后,再去处理。
代码示意如下,

    public E get(int index) {
        try{
            return o[index]; // o 为第一段数组
        } catch (ArrayIndexOutOfBoundsException e) {
            return c[index-o.length]; // c 为第二段数组
        }
    }

可以看到,这样做,对于发生在比较长的第一段数组中的读完全没有任何影响。但对于读第二段较短的数组,由于需要捕获数组越界异常,并做一个减法下标映射,所以会稍微影响读效率。

另外,还有一个小问题是,最终抛出的越界异常是针对第二段数组的,给出的越界信息会不正确。不过这个小问题不关乎大局,完全可以忽略。
纠正数组越界信息,代码示意如下:

    public E get(int index) {
        try{
            return o[index]; // o 为第一段数组
        } catch (ArrayIndexOutOfBoundsException e) {
            try{
                return c[index-o.length]; // c 为第二段数组
            } catch (ArrayIndexOutOfBoundsException e1) {
                // 抛出正确的越界信息
                throw new ArrayIndexOutOfBoundsException(index);
            }
        }
    }

3.2 读的正确性问题

上一个问题的解决方案,又会引发出另一个问题。
即,在读第二段数组中的元素时,第一段数组的长度不能变长。

举例说明,假设第一段数组长度为 10,第二段数组长度为 5。
此时用户想要 get(10),读第一段数组时,发现越界转而去读第二段数组。
此时另一个用户在第一段数组中添加元素,并且已经执行完拷贝、增加、切换数组的操作。那么此时第一段数组的长度就变成了 11。前一个用户的 get(10) 执行到 c[index-o.length] 时,index-o.length 为 -1,此时变成想要访问第二段数组下标为 -1 的元素,显然会抛出原本不应该抛出的数组越界异常。

解决方案一:
不允许在第一段数组上增加元素。

解决方案二:
捕获发生在第二段数组上的数组越界异常。
如果发现是第一段数组变长所致,尝试重新读。有点自旋锁的意思。
代码示意如下,

    public E get(int index) {
        try{
            return o[index]; // o 为第一段数组
        } catch (ArrayIndexOutOfBoundsException e) {
            int i = 0;
            try{
                return c[i = index-o.length]; // c 为第二段数组
            } catch (ArrayIndexOutOfBoundsException e1) {
                if(i < 0) return get(index); // 重新读
                else throw new ArrayIndexOutOfBoundsException(index);
            }
        }
    }

对于读多写少的场景,完全可以这样递归读。

如果写操作比较频繁,也可以对第一段数组加锁,然后再次进行读操作。
代码示意如下,

    public E get(int index) {
        try{
            return o[index]; // o 为第一段数组
        } catch (ArrayIndexOutOfBoundsException e) {
            int i = 0;
            try{
                return c[i = index-o.length]; // c 为第二段数组
            } catch (ArrayIndexOutOfBoundsException e1) {
                if(i < 0){
                    synchronized (o) { // 此时,读操作有一定概率被阻塞
                        return get(index);
                    }
                }else throw new ArrayIndexOutOfBoundsException(index);
            }
        }
    }

这样,大多数的读,都和之前一样。只有在读第二段数组时,第一段数组变长,并且 index-o.length 为负这种情况,才会阻塞读操作。

四、示意代码

时间关系,只实现了构造方法和基本的增删改查。

package main;

import java.util.Arrays;
import java.util.Collection;
import java.util.concurrent.CopyOnWriteArrayList;

public class IncrementCOW<E> {

    // 成员变量
    private volatile Object[] o;
    private CopyOnWriteArrayList<E> c = new CopyOnWriteArrayList<>();

    // 构造方法
    public IncrementCOW(){
        o = new Object[0];
    }
    
    public IncrementCOW(E[] toCopyIn){
        o = Arrays.copyOf(toCopyIn, toCopyIn.length, Object[].class);
    }

    public IncrementCOW(Collection<? extends E> c){
        Object[] elements = c.toArray();
        // c.toArray might (incorrectly) not return Object[] (see 6260652)
        if (elements.getClass() != Object[].class)
            elements = Arrays.copyOf(elements, elements.length, Object[].class);
        o = elements;
    }

    // 普通方法
    public int size(){
        return o.length + c.size();
    }

    @SuppressWarnings("unchecked")
    private E get(Object[] a, int index) {
        return (E) a[index];
    }
    
    public E get(int index) {
        try{
            return get(o, index);
        }catch (ArrayIndexOutOfBoundsException e) {
            int i = 0;
            try{
                return c.get(i = index-o.length);
            }catch (ArrayIndexOutOfBoundsException e1) {
                if(i < 0) return get(index); // 不加锁的方式
                else throw new ArrayIndexOutOfBoundsException(index);
            }
        }
    }

    /**
     * 没有同时给 o 和 c 加锁
     * 所以在 set c 的时候,index 可能不是逻辑数组实时的下标
     * 只保证 index 是用户调用 set 方法后逻辑数组一个快照的下标
     * */
    public E set(int index, E element) {
        if(index < 0) throw new IndexOutOfBoundsException("Index: "+index);
        int i;
        synchronized (o) {
            i = index-o.length;
            if(i < 0){
                E oldValue = get(o, index);
                if(oldValue != element){
                    Object[] newElements = Arrays.copyOf(o, o.length);
                    newElements[index] = element;
                    o = newElements;
                }
                return oldValue;
            }
        }
        return c.set(i, element);
    }

    public void add(E e) {
        c.add(e);
    }

    // 同 set 方法,没有同时给 o 和 c 加锁
    public void add(int index, E element) {
        if(index < 0) throw new IndexOutOfBoundsException("Index: "+index);
        int i;
        synchronized (o) {
            i = index-o.length;
            if(i < 0){
                Object[] newElements = new Object[o.length + 1];
                System.arraycopy(o, 0, newElements, 0, index);
                System.arraycopy(o, index, newElements, index + 1,
                        -i);
                newElements[index] = element;
                o = newElements;
            }
        }
        c.add(i, element);
    }

    // 同 set 方法,没有同时给 o 和 c 加锁
    public E remove(int index) {
        if(index < 0) throw new IndexOutOfBoundsException("Index: "+index);
        int i;
        synchronized (o) {
            i = index-o.length;
            if(i < 0){
                E oldValue = get(o, index);
                int numMoved = -i - 1;
                if(numMoved == 0)
                    o = Arrays.copyOf(o, o.length - 1);
                else{
                    Object[] newElements = new Object[o.length - 1];
                    System.arraycopy(o, 0, newElements, 0, index);
                    System.arraycopy(o, index + 1, newElements, index,
                                     numMoved);
                    o = newElements;
                }
                return oldValue;
            }
        }
        return c.remove(i);
    }

}
上一篇下一篇

猜你喜欢

热点阅读