ThreadLocal原理

2017-06-19  本文已影响0人  丑星星

我们通常用ThreadLocal来实现线程局部变量的存储。在许多开源框架中ThreadLocal被广泛使用。这篇文章来讨论一下ThreadLocal的实现原理

一、ThreadLocal的实现原理

ThreadLocal的实现原理其实很简单,我们先来看一下ThreadLocal常用的几个方法:

public void set(T value)
public void remove()
public T get()

通过set、get、remove这个三个方法,可以实现线程局部变量的添加、获取、删除。
首先我们先来看一下set方法的实现:

  public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }

    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }

我们可以看到set方法是将对象value保存在当前线程的threadLocals这个ThreadLocalMap中,以当前的ThreadLocal对象作为map的键值。ThreadLoccaMap是ThreadLocal类的一个内部类,我们先不管它的实现,先来看一下ThreadLocal的get方法和remove方法的实现。

get方法:
    public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
    }

    private T setInitialValue() {
        T value = initialValue();
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
        return value;
    }

我们可以看到,当当前线程的ThreadLocalMap对象存在的时候,返回值从这个对象中获取,这个和set方法保存value是相对应的,都是从当前线程保存的ThreadLocalMap对象中存储和获取。当ThreadLocalMap对象不存在的时候返回 setInitialValue() 返回的对象。这个方法我们可以看到:通过 initialValue()方法获得一个value对象,这个方法默认返回一个null,留给子类实现,用来初始化一个用来保存的对象的默认值。然后将这个对象放在当前线程的ThreadLocalMap中(如果不存在ThreadLocalMap对象就创建一个)。

remove方法:
     public void remove() {
         ThreadLocalMap m = getMap(Thread.currentThread());
         if (m != null)
             m.remove(this);
     }

依然是调用ThreadLocalMap的remove方法。
好了,究竟ThreadLocalMap是什么?我们接下来看一下:
其实ThreadLocalMap是一个Map,它的实现和HashMap类似,我们先来看一下用来保存K-V的节点:

        static class Entry extends WeakReference<ThreadLocal<?>> {
            /** The value associated with this ThreadLocal. */
            Object value;

            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }

我们先不用管WeakReference是什么,我们只需要知道,Entry 保存了k和v。
接下来看一下ThreadLocalMap的set方法的实现:

        private Entry[] table;

        private void set(ThreadLocal<?> key, Object value) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1)
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();
                if (k == key) {
                    e.value = value;
                    return;
                }
                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }
            tab[i] = new Entry(key, value);
            int sz = ++size;
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }

ThreadLocalMap的set方法和早期HashMap的实现类似,都是先计算哈希,然后确定hash槽的位置,不同的是,ThreadLocalMap通过数组存储K-V对象(Entry),而HashMap是通过散列表存储K-V对象。ThreadLocalMap首先获得存储数组的长度,然后通过hash算法计算要设置的节点所在的哈希槽的位置,如果哈希槽的位置没有元素,就新创建一个Entry对象放在这里。如果有元素,就判断该元素的k是否和当前要设置的k相等,如果是就将这个哈希槽存储的entry对象的value重新赋值;如果k是空的话,说明这个ThreadLocal对象被手动设置为null了,是无效的。就把这个节点替换掉,具体怎么实现看一下replaceStaleEntry这个方法,这里不赘述。如果当前哈希槽位置有合法元素,并且k不和要保存的k相等,就去下一个哈希槽的位置重复检查,下一个哈希槽的位置是这个方法计算的:

  private static int nextIndex(int i, int len) {
            return ((i + 1 < len) ? i + 1 : 0);
        }

添加完元素之后,会去判断当前存储数组内元素的数量是否超过了threshold,我们可以叫threshold为扩容因子,threshold = len * 2 / 3。当超过扩容因子的时候就去检查并且移除坏节点。移除坏节后,如果size >= threshold - threshold / 4,就要真正的扩容。我们看一下这个方法:

private void rehash() {
    expungeStaleEntries();
    // Use lower threshold for doubling to avoid hysteresis
    if (size >= threshold - threshold / 4)
        resize();
}

expungeStaleEntries()这个方法,从方法名上可以看出这个方法的作用:去除坏掉的Entries。什么是坏掉的Entries呢?我们可以看一下这个方法的实现:

        private void expungeStaleEntries() {
            Entry[] tab = table;
            int len = tab.length;
            for (int j = 0; j < len; j++) {
                Entry e = tab[j];
                if (e != null && e.get() == null)
                    expungeStaleEntry(j);
            }
        }

当一个节点不是null时,调用节点的get()方法,如果说得的结果是null,这个节点就是坏的节点。Entry的get方法其实是他父类WeakReference<T>的父类Reference<T>的方法。这两个类是什么呢?对JVM有了解的小伙伴应该对这两个类不陌生,我们知道,在java中,对象的引用分为四种:强引用、软引用、弱引用、虚引用。引用强度逐渐减弱。
强引用是我们常见的对象引用,比如:Object o = new Object();
只要一个对象被强引用所引用,就不会被垃圾收集器回收。当内存不足时,jvm会抛出OOM异常。
软引用对应着Reference<T>的实现类SoftReference<T>,这种引用引用的对象不会立刻被回收,但是当内存空间不足的时候,垃圾收集器就会回收软引用所引用的对象。
弱引用对应着Reference<T>的实现类WeakReference<T>,当软引用指向的对象被垃圾收集器发现后,就会回收这个对象(只有软引用引用这个对象,如果强引用同时也引用这个对象时,这个对象并不会被回收)。
虚引用:也叫幽灵引用,虚引用主要用来跟踪对象被垃圾回收器回收的活动。不会影响任何垃圾回收的过程。
回到前面的所说的判断是否为坏的节点,e.get()所获得的其实是e存储的key,也就是ThreadLocal对象。所以我们可以看出,Thread中的ThreadLocalMap并不会影响ThreadLocal在jvm中的生命周期。当一个节点被判定为坏节点后,这个节点就会被移除,具体实现我们看一下expungeStaleEntry这个方法:

        private int expungeStaleEntry(int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;

            // expunge entry at staleSlot
            tab[staleSlot].value = null;
            tab[staleSlot] = null;
            size--;

            // Rehash until we encounter null
            Entry e;
            int i;
            for (i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();
                if (k == null) {
                    e.value = null;
                    tab[i] = null;
                    size--;
                } else {
                    int h = k.threadLocalHashCode & (len - 1);
                    if (h != i) {
                        tab[i] = null;

                        // Unlike Knuth 6.4 Algorithm R, we must scan until
                        // null because multiple entries could have been stale.
                        while (tab[h] != null)
                            h = nextIndex(h, len);
                        tab[h] = e;
                    }
                }
            }
            return i;
        }

首先将位置为staleSlot处存储的Entry的对象的value也设置为null,然后将这个对象从存储的数组中移除,并将size减1。然后一次检查下一个位置(nextIndex(staleSlot, len)确定下一个位置)是否为坏节点,是的话就移除,否则重新计算这个节点所在的位置,将这个节点移动到计算后的新位置。这样做的原因是因为在set节点的时候,如果存在hash冲突,并且key不相等时,会将调用nextIndex(staleSlot, len)方法重新确定hash槽的位置。
真正的扩容方法是由resize()方法实现的,实现过程是很简单。我们看一下具体实现:

        private void resize() {
            Entry[] oldTab = table;
            int oldLen = oldTab.length;
            int newLen = oldLen * 2;
            Entry[] newTab = new Entry[newLen];
            int count = 0;

            for (int j = 0; j < oldLen; ++j) {
                Entry e = oldTab[j];
                if (e != null) {
                    ThreadLocal<?> k = e.get();
                    if (k == null) {
                        e.value = null; // Help the GC
                    } else {
                        int h = k.threadLocalHashCode & (newLen - 1);
                        while (newTab[h] != null)
                            h = nextIndex(h, newLen);
                        newTab[h] = e;
                        count++;
                    }
                }
            }

            setThreshold(newLen);
            size = count;
            table = newTab;
        }

我们可以看到,每次扩容大小都是原来大小的2倍,扩容的过程就是新建一个大小为原来2倍的数组,将原来数组内的元素放到新数据中。过程很简单这里就不在赘述。

二、疑问,线程池中,大量任务使用ThreadLocal会不会造成OOM

根据上面的分析,ThreadLocal实现线程局部存储是通过每个线程Thread中的ThreadLocalMap存储以ThreadLocal对象为key,以要存储的对象为value来实现的。而ThreadLocalMap中,存储K-V是通过Entry实现,Entry继承了WeakReference。所以ThreadLocalMap不会影响ThreadLocal对象的在内存中的回收。通过之前《java线程池浅析》这篇我们可以知道线程池实现的原理其实是多个(或一个)线程执行提交的runnable任务。runnable任务中使用ThreadLocal,当runnable任务结束(其实是run方法结束),runnable任务中使用的ThreadLocal就会失去和GCroot的连接,这个时候只有ThreadLocalMap中的Entry会引用该ThreadLocal对像,所以当内存不足的时候,ThreadLocal对像会被回收。所以在线程池中,这种ThreadLocal的使用是不会造成OOM的。

上一篇 下一篇

猜你喜欢

热点阅读