ThreadLocal分析

2021-07-09  本文已影响0人  雨之都

上一次看ThreadLocal的源代码已经是很久之前的事情了,今早突然想起发现自己连ThreadLocal的原理一点也想不起了,因此重新再读一次源码,分析一下ThreadLocal的原理

ThreadLocal正如其名(线程本地)这是指对象设置或者获取的值都是当前线程访问的,其他线程设置和访问的不是同一个对象(当前前提是initialValue和setValue使用姿势正确).诸如数据库的连接对象就可以使用ThreadLocal来保存.下面就可以展开分析了

对于ThreadLocal来说,公开的函数就是

通常ThreadLocal的使用姿势有,直接构造一个ThreadLocal对象,然后调用set 设置值, 调用get获取值,这种情况是对于ThreadLocal没有初值的情况,因此如果我们在调用get之前没有调用set.那么第一次获取的值就是空的,对于这种情况,ThreadLocal提供了一个保护方法

子类通过覆写这个方法,使得ThreadLocal在get的时候第一次能够获取到初始值,ThreadLocal的一个静态函数

构造函数

public ThreadLocal() {
    }

可以看到ThreadLocal的构造函数是空的,平平无奇,ThreadLocal的魔法应该是在set和get的时候会发生的,构造函数应该也不需要做什么特别的工作

get函数

public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
    }

根据源代码,查找主要分为如下几个步骤

  1. 获取当前Thread
  2. 根据当前Thread关联的ThreadLocalMap
  3. 如果不为空,调用其查找函数
  4. 如果为空的化,那么调用setInitialValue设置初值并且返回,setInitialValue和set函数的流程基本查不到,所以这里不赘述,后文分析完set函数之后,基本也就明白了它的功能了
  5. 因此我们对不为空的查找函数,深入去了解一下
private Entry getEntry(ThreadLocal<?> key) {
                        // 通过当前ThreadLocal的hashKey获取目标位置
            int i = key.threadLocalHashCode & (table.length - 1);
            Entry e = table[i];
                        // 如果目标位置的元素不为空且key相同,那么就查找完毕,返回该Entry
            if (e != null && e.get() == key)
                return e;
            else
                return getEntryAfterMiss(key, i, e);
        }

可以看到当第一次在目标位置没有找到的时候,会调用getEntryAfterMiss函数,我们看一下该函数的实现;可以看到就是往后线性遍历,一直到Entry为空,未找到则直接返回为空,我们注意到中间有一步是当获取的key为空的情况(当然可能! 因为key是ThreadLocal通过弱引用的方式保存的,如果ThreadLocal被销毁了,那么key就是为空了)

private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
            Entry[] tab = table;
            int len = tab.length;

            while (e != null) {
                ThreadLocal<?> k = e.get();
                if (k == key)
                    return e;
                if (k == null)
                    expungeStaleEntry(i);
                else
                    i = nextIndex(i, len);
                e = tab[i];
            }
            return null;
        }

垃圾回收可能发生在任何时间,所以当key无效的时候,我们应该做清理工作,我个人理解的清理工作是遍历从i到后面所有的已经过期了的,将这些移除,并且对于之后的元素rehash.重新插入队列,那么我们看一下代码观察看是否做了这样的事情

private int expungeStaleEntry(int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;

            // expunge entry at staleSlot
            tab[staleSlot].value = null;
            tab[staleSlot] = null;
            size--;

            // Rehash until we encounter null
            Entry e;
            int i;
            for (i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();
                if (k == null) {
                    e.value = null;
                    tab[i] = null;
                    size--;
                } else {
                                        // 不为空的元素就重新插入
                    int h = k.threadLocalHashCode & (len - 1);
                    if (h != i) {
                        tab[i] = null;

                        // Unlike Knuth 6.4 Algorithm R, we must scan until
                        // null because multiple entries could have been stale.
                        while (tab[h] != null)
                            h = nextIndex(h, len);
                        tab[h] = e;
                    }
                }
            }
                        // 下一个可以插入的为null的slot
            return i;
        }

看代码确实就是做了这样的事情,所以插入实际上就是查找+线性遍历

set函数

public void set(T value) {
        // 获取当前的线程
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}

看到第二步调用了getMap,并且把当前线程传入了,那么这里做了什么?展开看看

ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }

发现只是返回了线程的这个成员,那如果我们根本就没有设置过值的化,那么这个值理所当然是空的,因此会走到下面的createMap函数,继续往下跟

void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }

可以看到直接创建了ThreadLocalMap对象,并且把this和firstValue当成构造函数的参数传入了

那在分析流程继续之前,有必要看一下ThreadLocalMap的源码,看ThreadLocalMap的注释说,ThreadLocalMap是一个hashmap,是专门为了维护线程本地数据而造出的一个数据结构,因此它没有暴露出任何方法,但为了让Thread能够访问,所以ThreadLocalMap本身是包访问权限的

之前研究HashMap的时候发现,HashMap无非是Entry的数组+链表,那ThreadLocalMap肯定也不例外,看一下它的Entry长什么样子

static class Entry extends WeakReference<ThreadLocal<?>> {
            /** The value associated with this ThreadLocal. */
            Object value;

            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }

将ThreadLocal作为k,并且key传给了父类的构造函数,且因为父类是WeakReference.所以Entry的key是弱引用的,接着来看一下ThreadLocalMap的构造函数

ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
            table = new Entry[INITIAL_CAPACITY];
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
            table[i] = new Entry(firstKey, firstValue);
            size = 1;
            setThreshold(INITIAL_CAPACITY);
        }

  1. 根据INITIAL_CAPACITY构造了一个Entry的数组table

  2. 第二行代码计算了将要插入表中的位置

    • 里面访问了threadLocalHashCode的属性,看一下threadLocalHashCode长什么样子
    private final int threadLocalHashCode = nextHashCode();
    
    • 那nextHashCode做了什么呢,继续往下翻
        private static final int HASH_INCREMENT = 0x61c88647;
        /**
         * Returns the next hash code.
         */
        private static int nextHashCode() {
            return nextHashCode.getAndAdd(HASH_INCREMENT);
        }
    

    nextHashCode是一个原子类型的数据,每次调用这个方法都加上了一个HASH_INCREMENT,这个数字的具体原理没有深入研究,谷歌了一下发现通过这样的方式能够减少碰撞,暂且不表,

    • 通过第二步获取的nextHashCode和位的大小减1 进行位于,找到了元素被防止的防止,构造一个Entry.将其放入数组

    最后一步调用了了setThreshold方法设置了一下阈值,和hashmap的阈值是等同的

    private void setThreshold(int len) {
                threshold = len * 2 / 3;
            }
    

    可以看到阈值是长度的2/3。

    如果我们创建了第二个ThreadLocal.同样调用设值。

    假设这个时候相应的ThreadLocalMap已经创建好了,那么就会走到If中的map.set(this, value)中去,看一下set方法长什么样子

    private void set(ThreadLocal<?> key, Object value) {
    
                // We don't use a fast path as with get() because it is at
                // least as common to use set() to create new entries as
                // it is to replace existing ones, in which case, a fast
                // path would fail more often than not.
    
                Entry[] tab = table;
                int len = tab.length;
                            // 通过threadLocalHashCode获取要插的下一个点,每一个ThreadLocal对象的
                            // threadLocalHashCode都不一致
                int i = key.threadLocalHashCode & (len-1);
                            // 线性探测法去避免冲突
                for (Entry e = tab[i];
                     e != null;
                     e = tab[i = nextIndex(i, len)]) {
                                    // 获取当前entry对应的key
                    ThreadLocal<?> k = e.get();
                                    // 如果相等就直接替换
                    if (k == key) {
                        e.value = value;
                        return;
                    }
                                    // 如果为空;说明当前的ThreadLocal对象被回收了;那么执行替换
                    if (k == null) {
                        replaceStaleEntry(key, value, i);
                        return;
                    }
                }
    
                tab[i] = new Entry(key, value);
                int sz = ++size;
                            // 如果cleanSomwSlots没有清理移除元素,并且下面已经超过threshold了;那么需要执行rehash
                if (!cleanSomeSlots(i, sz) && sz >= threshold)
                    rehash();
            }
    
                    private static int nextIndex(int i, int len) {
                return ((i + 1 < len) ? i + 1 : 0);
            }
    
                    private static int prevIndex(int i, int len) {
                return ((i - 1 >= 0) ? i - 1 : len - 1);
            }
    
                    
    private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                           int staleSlot) {
                Entry[] tab = table;
                int len = tab.length;
                Entry e;
    
                // Back up to check for prior stale entry in current run.
                // We clean out whole runs at a time to avoid continual
                // incremental rehashing due to garbage collector freeing
                // up refs in bunches (i.e., whenever the collector runs).
                int slotToExpunge = staleSlot;
                for (int i = prevIndex(staleSlot, len);
                     (e = tab[i]) != null;
                     i = prevIndex(i, len))
                    if (e.get() == null)
                        slotToExpunge = i;
    
                // Find either the key or trailing null slot of run, whichever
                // occurs first
                for (int i = nextIndex(staleSlot, len);
                     (e = tab[i]) != null;
                     i = nextIndex(i, len)) {
                    ThreadLocal<?> k = e.get();
    
                    // If we find key, then we need to swap it
                    // with the stale entry to maintain hash table order.
                    // The newly stale slot, or any other stale slot
                    // encountered above it, can then be sent to expungeStaleEntry
                    // to remove or rehash all of the other entries in run.
                    if (k == key) {
                        e.value = value;
                                            // 把当前的数值和过期的slot交换;这里必须要交换;否则就破坏了插入的原则;可能会导致之后查找失败
                        tab[i] = tab[staleSlot];
                        tab[staleSlot] = e;
                                            // 这个时候i的slot就是过期的
                        // Start expunge at preceding stale entry if it exists
                        if (slotToExpunge == staleSlot)
                            slotToExpunge = i;
                                            // slotToExpunge是要擦除的起点
                        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                        return;
                    }
    
                    // If we didn't find stale entry on backward scan, the
                    // first stale entry seen while scanning for key is the
                    // first still present in the run.
                    if (k == null && slotToExpunge == staleSlot)
                        slotToExpunge = i;
                }
    
                          // key并不在map里面;staleSlot是可以插入的,直接插入
                tab[staleSlot].value = null;
                tab[staleSlot] = new Entry(key, value);
                            
                // slotToExpunge是前项已经过期了的,做一些清理工作
                if (slotToExpunge != staleSlot)
                    cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            }
    
    // 启发式地搜索,最多可能搜索o(n),通常情况应该是o(lgn)
    // 如果找到了一个过期的,就把这个过期的元素重新清理,并且把没有过期的重新hash重新插入
    // 关于expungeStaleEntry上文已经分析过了详细的流程
    private boolean cleanSomeSlots(int i, int n) {
                boolean removed = false;
                Entry[] tab = table;
                int len = tab.length;
                do {
                    i = nextIndex(i, len);
                    Entry e = tab[i];
                    if (e != null && e.get() == null) {
                        n = len;
                        removed = true;
                        i = expungeStaleEntry(i);
                    }
                } while ( (n >>>= 1) != 0);
                return removed;
            }
    

    在分析完成ThreadLocal后之后,我提出了我自己的几个 Q && A

    1. 为什么ThreadLocal是线程安全的

    因为ThreadLocal操作的是当前线程的一个threadLocals变量,不同线程操作的是不同的变量,同一时间,一个线程只可能有一个代码序列访问threadLocals.因此ThreadLocal是线程安全的

    1. 在一个线程里面创建无数个ThreadLocal,有没有可能有两个ThreadLocal的key完全一致?

    有可能,因为ThreadLocal的key的hashcode就是从0一直叠加魔法数字,所以创建大量的ThreadLocal可能导致两个key完全一致,但这个场景在实际中实际上不可能,我相信正常的开发同学也不会new异常数量的ThreadLocal的

    1. ThreadLocal的大致原理?

    实际上ThreadLocal就是散列+开放地址法(解决冲突),之所以看ThreadLocal的代码感觉有点复杂,是因为ThreadLocal还处理了每次插入的时候以及获取的时候去删除已经过期了的元素,所以这也是我们将ThreadLocal的key封装弱引用的原因

上一篇下一篇

猜你喜欢

热点阅读