Java - ThreadLocal 源码阅读笔记

in Tech Java

简介

顾名思义,ThreadLocal 类提供线程局部变量存储功能。也就是数据与线程进行绑定,每个线程都拥有独立的数据副本,数据变更互不影响。当线程被销毁之后,数据也就不复存在啦。

示例

@Log4j2
public class ThreadLocalTest {
    private String threadName;
    private static ThreadLocal<String> threadLocal = new ThreadLocal<>();

    public static void main(String[] args) {
        for (int i = 0; i < 10; i++) {
            new Thread(() -> {
                setThreadName(Thread.currentThread().getName());
                log.info(getThreadName());
            }).start();
        }
    }

    public static String getThreadName() {
        return threadLocal.get();
    }

    public static void setThreadName(String threadName) {
        threadLocal.set(threadName);
    }
}

class ThreadId {
    private static final AtomicInteger nextId = new AtomicInteger(0);
    private static final ThreadLocal<Integer> threadId = ThreadLocal.withInitial(() -> nextId.getAndIncrement());

    /**
     * @return 当前线程的唯一ID
     */
    public static int get() {
        return threadId.get();
    }
}

ThreadLocal 类

public class ThreadLocal<T> {
    /**
     * 每个线程中多个 ThreadLocal 实例的唯一标识
     */
    private final int threadLocalHashCode = nextHashCode();

    private static AtomicInteger nextHashCode =
        new AtomicInteger();

    /**
     * 使用 0x61c88647 ,是为了让哈希码能均匀的分布在2的N次方的数组里,也就是减少出现哈希碰撞的几率
     */
    private static final int HASH_INCREMENT = 0x61c88647;
    ...
}

set() 方法

    public void set(T value) {
        Thread t = Thread.currentThread();// 获取当前线程对象
        // 获取线程中的全局变量 threadLocals
        // 而该变量又是 ThreadLocal 类的内部类
        ThreadLocalMap map = getMap(t);
        if (map != null) // 已初始化 ThreadLocalMap 
            map.set(this, value);
        else // 未初始化 ThreadLocalMap 
            createMap(t, value);
    }

get() 方法

    public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();// 没找到,设置值为 null 的初始化操作
    }

getMap() 方法

    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }

createMap() 方法

    void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }

ThreadLocalMap 类

  static class ThreadLocalMap {

        /**
         * 默认初始容量16,两倍扩增
         */
        private static final int INITIAL_CAPACITY = 16;

        /**
         * 使用数组来存储数据
         */
        private Entry[] table;

        /**
         * 数据条目
         */
        private int size = 0;

        /**
         * 扩容阀值,容量的三分之二
         */
        private int threshold; // Default to 0
   ...
}

构造方法

        ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
            table = new Entry[INITIAL_CAPACITY];// 初始化
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);// 计算第一个元素的索引
            table[i] = new Entry(firstKey, firstValue);
            size = 1;
            setThreshold(INITIAL_CAPACITY);
        }

set() 方法

        private void set(ThreadLocal<?> key, Object value) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);// 计算索引
            // 出现哈希碰撞时,使用传说中的“开放定址法”解决
            // 也就是,往后查找到为空的槽位
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();

                if (k == key) {
                    e.value = value;
                    return;
                }

                if (k == null) {// 找到了要插入的槽位
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }

            tab[i] = new Entry(key, value);
            int sz = ++size;
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }

getEntry() 方法

        private Entry getEntry(ThreadLocal<?> key) {
            int i = key.threadLocalHashCode & (table.length - 1);// 计算索引
            Entry e = table[i];
            if (e != null && e.get() == key)// 找到了对应的数据
                return e;
            else
                return getEntryAfterMiss(key, i, e);// 没找到就往后找
        }

getEntryAfterMiss() 方法

        private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
            Entry[] tab = table;
            int len = tab.length;

            while (e != null) {
                ThreadLocal<?> k = e.get();
                if (k == key)
                    return e;
                if (k == null)
                    expungeStaleEntry(i);
                else
                    i = nextIndex(i, len);
                e = tab[i];
            }
            return null;
        }